#! /usr/bin/env python
#
#  computes a membrane energy
#  potential between helices
#  with the scoring method provided
#  by Fleishman and Ben-Tal
#  2002, JMB, 321, pp. 363-378.
#
#

import sys
import trace
import string
import re
import xplorNIH
import xplor
import math
from derivList import DerivList
from pyPot import PyPot
import vec3
from atomSel import AtomSel
import copy

class MembranePot(PyPot):
  """Inter-helix membrane packing energy term.

  Implements the burial / pair scoring of Fleishman and Ben-Tal
  (2002, JMB 321, 363-378) on the CA atoms of the current structure.
  Every distinct segment ID is treated as one helix.  Per-residue lists
  kept in s.data ('segID', 'resID', 'resName', 'CA', 'helAxis', 'p',
  'Burial', 'oppHel') are indexed by a CA's position in the
  '( name CA)' selection.  calcEnergy() returns s.weight * s.score;
  no derivatives are produced.
  """

  def __init__( s, name ):
    PyPot.__init__( s, name, s )
    s.sim = xplor.simulation
    s.harmonic = 1            # not referenced by this class
    s.defaultScale = 1.       # not referenced by this class
    s.weight = 1.             # multiplies the raw score in calcEnergy
    # a residue's Burial must exceed this to take part in pair scoring
    s.BurialCutoff = 0.2
    # large side chains whose Burial exceeds this are penalized
    s.LargeBurialCutoff = 0.9
    # penalty per unit Burial for a buried large side chain
    s.LargeSideChainPenalty = 10.
    s.data = {}               # cached internal data, refilled by calcEnergy
    s.score = 0.              # unweighted score of the last evaluation
    # when true, diagnostic lines are written via s.brain.write(), so
    # s.brain must then be set to a writable file-like object
    s.brainDump = 0
    s.brain = 0
    # below is taken from Table 1 p.366; rows and columns are indexed
    # by getPairIndex()
    s.MaxTable = [ [ -1., -1., -.5, -.25, 0.],
                   [ -1.,  0., -.5,   0., 0.],
                   [ -.5, -.5, -.5,   0., 0.],
                   [-.25,  0.,  0.,   0., 0.],
                   [  0.,  0.,  0.,   0., 0.] ]

  def setWeighting ( s, w ) :
    """Set the factor applied to the raw score by calcEnergy()."""
    s.weight = w

  def dumpListasPDB ( s, l ) :
    """Print every point of l as a pseudo-PDB CA ATOM record (debug aid).

    l is a sequence of 3-component points; the running index is used
    for both the atom and the residue number.
    """
    for i, pt in enumerate(l):
      print("ATOM    %3d  CA  ALA   %3d     %7.3f %7.3f %7.3f    1.0 0   0.         G" %
            (i, i, pt[0], pt[1], pt[2]))

  def buildBackboneList ( s ) :
    """Cache the CA atoms and group them into helices by segment ID.

    Fills s.data with:
      'segID', 'resID', 'resName', 'CA' -- one entry per CA atom, in
          selection order;
      'helices'    -- distinct segment IDs in order of first appearance;
      'helIndices' -- for each helix, the indices (into the per-CA lists)
          of its residues, in selection order.
    """
    segIDs = []
    resIDs = []
    resNames = []
    positions = []
    for atom in AtomSel('( name CA)'):
      segIDs.append(atom.segmentName())
      resIDs.append(atom.residueNum())
      resNames.append(atom.residueName())
      positions.append(atom.pos())
    s.data['segID'] = segIDs
    s.data['resID'] = resIDs
    s.data['resName'] = resNames
    s.data['CA'] = positions

    # split up separate helices (one per segment ID)
    helices = []
    helIndices = []
    helixOf = {}
    for i, seg in enumerate(segIDs):
      if seg not in helixOf:
        helixOf[seg] = len(helices)
        helices.append(seg)
        helIndices.append([])
      helIndices[helixOf[seg]].append(i)
    s.data['helices'] = helices
    s.data['helIndices'] = helIndices

  def cross ( s, a, b ) :
    """Return the cross product a x b as a vec3.Vec3."""
    # not in vec3 class for some reason
    i = a[1] * b[2] - a[2] * b[1]
    j = a[0] * b[2] - a[2] * b[0]
    k = a[0] * b[1] - a[1] * b[0]
    return vec3.Vec3 ( i, -j, k )

  def findPj ( s, C_i, j ) :
    """Find the axis point of helix j nearest to position C_i.

    Scans s.data['p'] over the residues of helix j and returns
    (p_j, distance).  Requires computeNearestPointOnHelAxis() first.
    """
    p_j = None
    p_jlen = 99999.
    for k in s.data['helIndices'][j]:
      p_k = s.data['p'][k]
      dist = vec3.norm(C_i - p_k)
      if dist < p_jlen:
        p_jlen = dist
        p_j = p_k
    return ( p_j, p_jlen )

  def getAngleandDistanceToHelix ( s, i, j ) :
    """Return (angle, distance) of residue i with respect to helix j.

    The angle (radians) is between the vector from residue i's own axis
    point to its CA and the vector from that axis point to the nearest
    axis point of helix j.  The distance is from CA i to that nearest
    axis point.
    """
    C_i = s.data['CA'][i]
    p_i = s.data['p'][i]
    p_j, D_i = s.findPj(C_i, j)
    ivec = C_i - p_i
    ijvec = p_j - p_i
    ivec = (1.0 / vec3.norm(ivec)) * ivec
    ijvec = (1.0 / vec3.norm(ijvec)) * ijvec
    cosA = vec3.dot(ivec, ijvec)
    # rounding can push the dot product of two unit vectors slightly
    # outside [-1, 1], where math.acos raises a domain error
    cosA = max(-1.0, min(1.0, cosA))
    A_i = math.acos(cosA)
    return A_i, D_i

  def computeBurialScore ( s ) :
    """Compute each residue's burial score and its opposing helix.

    Uses Eqs [4] and [5] of the paper.  For residue i, every other helix
    j gives a distance score and an angle score.  The helix with the
    largest product is kept.  That product is stored in
    s.data['Burial'][i], and the helix index in s.data['oppHel'][i].
    oppHel is -1 when no helix gives a positive product.
    """
    burial = []
    oppHel = []
    s.data['Burial'] = burial
    s.data['oppHel'] = oppHel
    helixOf = dict((seg, j) for j, seg in enumerate(s.data['helices']))
    nHelices = len(s.data['helices'])
    for i, seg in enumerate(s.data['segID']):
      # make sure in helix first
      helIndx = helixOf.get(seg, -1)
      if helIndx == -1:
        burial.append(0.)
        oppHel.append(-1)
        continue
      AScore_i = 0.
      DScore_i = 0.
      opposing = -1
      for j in range(nHelices):
        if j == helIndx:
          continue   # only score against other helices
        A_i, D_i = s.getAngleandDistanceToHelix(i, j)
        # these formulas are taken directly from paper, Eqs [4] and [5]
        delta = ( D_i - 4.3 ) / 2.5
        delta6 = delta * delta * delta * delta * delta * delta
        DScore = 1. / ( delta6 + 1. )
        delta = ( A_i / ( math.pi/3.) )
        delta4 = delta * delta * delta * delta
        AScore = 1./ ( delta4 + 1. )
        if ( AScore * DScore ) > ( AScore_i * DScore_i ):
          AScore_i = AScore
          DScore_i = DScore
          opposing = j
      burial.append(DScore_i * AScore_i)
      if s.brainDump and burial[i] > s.BurialCutoff:
        s.brain.write(str(i) + ' ' + str(DScore_i) + ' ' + str(AScore_i) +
                      ' ' + str(burial[i]) + '\n')
      if ( DScore_i * AScore_i ) > 0.:
        oppHel.append(opposing)
      else:
        oppHel.append(-1)

  def computeLocalHelicalAxis ( s ):
    """Rebuild the backbone cache and compute a unit local axis per CA.

    Uses the Chothia construction.  For CAs i..i+3 of a helix, the
    vectors Q_i = C_i + C_{i+2} - 2 C_{i+1} and
    Q_{i+1} = C_{i+1} + C_{i+3} - 2 C_{i+2} are formed.  The normalized
    Q_i x Q_{i+1} is stored for residue i.  The last three residues of
    each helix reuse the final computed axis.  s.data['helAxis'] is
    indexed like s.data['CA'].

    Raises ValueError if a helix has fewer than 4 CA atoms, since no
    local axis can then be defined.
    """
    s.buildBackboneList()
    CA = s.data['CA']
    helAxis = [None] * len(CA)
    for j, hx in enumerate(s.data['helIndices']):
      if len(hx) < 4:
        raise ValueError("MembranePot: helix %s has %d CA atoms; at least 4 "
                         "are needed for a local helical axis"
                         % (s.data['helices'][j], len(hx)))
      for i in range(len(hx) - 3):
        C_i = CA[hx[i]]
        C_ip1 = CA[hx[i + 1]]
        C_ip2 = CA[hx[i + 2]]
        C_ip3 = CA[hx[i + 3]]
        Q_i =   C_i   +  C_ip2 - 2. * C_ip1
        Q_ip1 = C_ip1 +  C_ip3 - 2. * C_ip2
        v_i = s.cross(Q_i, Q_ip1)
        v_i = (1.0 / vec3.norm(v_i)) * v_i
        helAxis[hx[i]] = v_i
      # extend the local helical axis to the end per definition
      lastAxis = helAxis[hx[-4]]
      for k in hx[-3:]:
        helAxis[k] = lastAxis
    s.data['helAxis'] = helAxis

  def computeNearestPointOnHelAxis ( s ) :
    """Compute an approximate axis point p for every CA.

    p is the mean of CAs i-1, i, i+1 and i+2 of the same helix.
    Neighbours beyond a helix end are faked along the local axis: i-1 at
    -1.5 A, i+1 at +1.5 A and i+2 at +3.0 A from CA i.  s.data['p'] is
    indexed like s.data['CA'].

    NOTE(review): an older comment described this as the average of
    residues i-1..i+3 excluding i.  The code averages i-1..i+2 including
    i; confirm against the paper.
    """
    CA = s.data['CA']
    axes = s.data['helAxis']
    points = [None] * len(CA)
    for hx in s.data['helIndices']:
      last = len(hx) - 1
      for i, k in enumerate(hx):
        C_i = CA[k]
        if i == 0:
          # fake a Ca at 1.5 ang in direction of helical axis
          C_im1 = C_i - 1.5 * axes[k]
        else:
          C_im1 = CA[hx[i - 1]]
        if i >= last:
          C_ip1 = C_i + 1.5 * axes[k]
        else:
          C_ip1 = CA[hx[i + 1]]
        if i >= last - 1:
          C_ip2 = C_i + 3.0 * axes[k]
        else:
          C_ip2 = CA[hx[i + 2]]
        points[k] = 0.25 * ( C_im1 + C_i + C_ip1 + C_ip2 )
    s.data['p'] = points

  def getPairIndex ( s, res ) :
    """Map a residue name to its MaxTable row/column (4 = any other)."""
    if res == 'GLY':
      return 0
    if res in ('ILE', 'THR', 'VAL'):
      return 1
    if res in ('ALA', 'CYS', 'SER'):
      return 2
    if res in ('LEU', 'ASN', 'PRO'):
      return 3
    return 4

  def getInteractingPairs ( s ) :
    """Find interacting residue pairs between helices.

    For each helix, the 10-residue stretch is chosen whose summed Burial
    (over residues above BurialCutoff) is maximal.  Its first CA index is
    stored in s.data['wcon'], or -1 if nothing is buried, and its
    residues in s.data['wconRes'].  A pair (a, b) with a < b is recorded
    in s.data['pairs'] once when all of the following hold:
      - the two residues lie in each other's windows;
      - each names the other's helix as its opposing helix;
      - both exceed BurialCutoff.
    """
    burial = s.data['Burial']
    oppHel = s.data['oppHel']
    cutoff = s.BurialCutoff

    wcon = []
    windows = []
    for hx in s.data['helIndices']:
      # find the stretch of 10 consecutive residues where
      # burial is maximal
      maxScore = 0.
      maxPos = 0
      for pos in range(len(hx) - 9):
        score = 0.
        for k in hx[pos:pos + 10]:
          if burial[k] > cutoff:
            score = score + burial[k]
        if score > maxScore:
          maxScore = score
          maxPos = pos
      if maxScore > 0.:
        wcon.append(hx[maxPos])
        windows.append(hx[maxPos:maxPos + 10])
      else:
        # no buried helix found
        wcon.append(-1)
        windows.append([])
    s.data['wcon'] = wcon
    s.data['wconRes'] = windows

    # now form pairs in wcon sets
    pairs = []
    seen = set()
    for i, window in enumerate(windows):
      for ir in window:
        opp = oppHel[ir]
        if opp < 0:
          continue
        for jr in windows[opp]:
          if oppHel[jr] == i and burial[ir] > cutoff and burial[jr] > cutoff:
            pair = (min(ir, jr), max(ir, jr))
            if pair not in seen:
              seen.add(pair)
              pairs.append(pair)
    s.data['pairs'] = pairs

  def largeSideChain ( s, res ) :
    """Return 1 if res is one of the large side-chain residues, else 0."""
    if res in ('ARG', 'HIS', 'LYS', 'MET', 'PHE', 'TRP', 'TYR'):
      return 1
    return 0

  def computeFinalWeighting ( s ) :
    """Combine pair scores and large side-chain penalties into s.score.

    Each interacting pair (i, j) contributes
    (Burial[i] + Burial[j]) * MaxTable[type_i][type_j].  Each large side
    chain whose Burial exceeds LargeBurialCutoff adds
    LargeSideChainPenalty * Burial.
    """
    # find interacting pairs
    s.getInteractingPairs()
    burial = s.data['Burial']
    resName = s.data['resName']

    s.score = 0.
    for i, j in s.data['pairs']:
      iIndex = s.getPairIndex(resName[i])
      jIndex = s.getPairIndex(resName[j])
      score = ( burial[i] + burial[j] ) * s.MaxTable[iIndex][jIndex]
      if s.brainDump:
        s.brain.write(resName[i] + str(i) + ' ' +
                      resName[j] + str(j) + ' ' +
                      str(j) + ' [' + str(iIndex) + ' , ' + str(jIndex) +
                      '] ' + str(s.MaxTable[iIndex][jIndex]) + ' ' + str(score) + '\n')
      s.score = s.score + score

    # Each buried large side chain adds only its own penalty.  The old
    # code added a running total that started from the last pair score,
    # and raised UnboundLocalError when there were no pairs.
    for i in range(len(s.data['CA'])):
      if s.largeSideChain(resName[i]) and burial[i] > s.LargeBurialCutoff:
        penalty = s.LargeSideChainPenalty * burial[i]
        if s.brainDump:
          s.brain.write(resName[i] + str(i) + ' ' + str(penalty) + '\n')
        s.score = s.score + penalty
    if s.brainDump:
      s.brain.write("unweighted membrane potential: " + str(s.score) + '\n')

  def calcEnergy ( s ) :
    """Recompute every cached quantity and return s.weight * s.score."""
    s.computeLocalHelicalAxis()
    s.computeNearestPointOnHelAxis()
    s.computeBurialScore()
    s.computeFinalWeighting()
    return s.weight * s.score

  def calcEnergyAndDerivs(s,derivs):
    """Return the energy; derivs is not touched (no forces are produced)."""
    return s.calcEnergy()
  

if __name__ == '__main__' :

  # smoke test: construct the potential (the old name MemPot did not
  # exist); construction reads xplor.simulation
  mb = MembranePot ( "membranePot" )
  
