#! /usr/bin/env python
#
# dipPot.py
#
# An XPLOR-NIH implementation of the dipolar energy potential for
# solid state NMR as described in
#
# "Atomic Refinement Using Orientational Restraints from Solid-State NMR"
# R. Bertram, J. Quine, M. Chapman, T. Cross
# Journal of Magnetic Resonance, 147, p.9-16, 2000.
#
#
#

import sys
import trace
import string
import re
import xplorNIH
import xplor
import math
from derivList import DerivList
from pyPot import PyPot
import vec3
from atomSel import AtomSel
import copy

class DipolarPot(PyPot):
  """Dipolar-coupling restraint potential for solid-state NMR.

  For each restraint between atoms A and B the predicted splitting is

      calcDipole = nuParallel * (3 cos^2(theta) - 1) / 2

  where theta is the angle between the bond vector (B - A) and the field
  direction s.B.  The deviation from the observed splitting is converted to
  an energy according to s.energyType (harmonic, flat-well, or linear) and
  multiplied by s.weight and the per-restraint 'scale'.

  Restraints are loaded with readData(); optional chemical-shift data read
  with readChemShiftData() is handed to nmr.getCorrelatedError to resolve
  the sign of the observed dipole.
  """

  def __init__( s, name ):
    PyPot.__init__( s, name, s )
    # default field direction is the +z axis
    s.B = vec3.Vec3( 0., 0., 1. )
    s.data = {}     # restraint dicts keyed by 'resA$A$resB$B'
    s.nupar = 0.
    s.sim = xplor.simulation
    s.cached = 0    # set to 1 once cacheIndices() has built atom selections
    # codes accepted by s.energyType
    s.harmonic = 1
    s.flatWell = 2
    s.energyType = s.harmonic
    s.defaultScale = 1.
    s.weight = 1.
    # correlation mode passed to nmr.getCorrelatedError (see setCorrelation).
    # NOTE(review): the original comment labelled this default "2D", but
    # setCorrelation documents 1 as "1D" correlation - confirm the intent.
    s.correlate = 1
    s.csdata = {}   # chemical-shift dicts, keyed like s.data
    s.betaRad = 0.
    s.alphaRad = 0.
    s.cBeta = s.sBeta = s.c2Beta = s.s2Beta = 0.
    s.setBetaRad( 17. * math.pi / 180. )
    s.setNuParallel( 11.335 )
    s.expError = 0.5  # flat-well half-width, in kHz

  def setBetaRad( s, betaRad ):
    """Set beta (radians) and cache its cos/sin and cos/sin of 2*beta."""
    s.betaRad = betaRad
    s.cBeta = math.cos( betaRad )
    s.sBeta = math.sin( betaRad )
    s.c2Beta = math.cos( 2. * betaRad )
    s.s2Beta = math.sin( 2. * betaRad )

  def setWeighting( s, w ):
    """Set the global weight multiplying every restraint energy."""
    s.weight = w

  def setNuParallel( s, nup ):
    """Set the current nuParallel used by calcDipoleEnergy/calcDerivs."""
    s.nupar = nup

  def setCorrelation( s, correlateType ):
    """Set the correlation mode passed to nmr.getCorrelatedError.

    Settings:
      0 - no correlation
      1 - "1D" correlation - all values above nuParallel/2. are assumed +
      2 - "2D" correlation - uses sigma values to determine where in the
          frequency-plane ellipse the point falls; better than 1D, but
          still not perfect
    """
    s.correlate = correlateType

  def _parseOptionalFields( s, text, intKeys = () ):
    """Parse trailing 'name value' pairs from the remainder of a data line.

    Tokens are read up to the first one starting a comment ('#' or '!').
    'id' values stay strings, names in *intKeys* are cast to int, and every
    other value is cast to float.  An unpaired trailing token is ignored.
    """
    tokens = []
    for tok in text.split():
      if tok[0] in '#!':
        break
      tokens.append( tok )
    fields = {}
    for name, value in zip( tokens[0::2], tokens[1::2] ):
      if name == 'id':
        fields[name] = value
      elif name in intKeys:
        fields[name] = int( value )
      else:
        fields[name] = float( value )
    return fields

  def readData( s, dataFile ):
    """Read dipolar restraints from *dataFile* into s.data.

    Format (lines starting with '#' or '!' and blank lines are skipped):

      resA atomA resB atomB nuParallel dipole [name value ...] [# comment]

    Optional name/value pairs: 'scale' (float), 'sign' (int), 'id' (str);
    any other name is stored as a float.  Example:

      5 N 6 H 11.335 4.3 sign 1 id domain1

    Exits the program with status 1 on a line that does not parse.
    """
    dataPattern = re.compile(
      r'\s*(\d+) \s*(\w+) \s*(\d+) \s*(\w+) '
      r'\s*(-*\d+\.*\d*) \s*(-*\d+\.*\d*)((\s*\#.*$)|(?:\s+\S+\s+\S+))*' )
    with open( dataFile, 'r' ) as dataIn:
      for line in dataIn:
        stripped = line.strip()
        if not stripped or stripped[0] in '#!':
          continue
        match = dataPattern.match( line )
        if not match:
          print( "Error - bad line in dip data file: " + line.rstrip() )
          sys.exit( 1 )
        entry = {
          'resA': match.group( 1 ),
          'A': match.group( 2 ),
          'resB': match.group( 3 ),
          'B': match.group( 4 ),
          'nuParallel': float( match.group( 5 ) ),
          'dipole': float( match.group( 6 ) ),
          'scale': s.defaultScale,
        }
        # optional fields follow the dipole value (group 6)
        entry.update( s._parseOptionalFields( line[match.end( 6 ):], ( 'sign', ) ) )
        key = '$'.join( ( match.group( 1 ), match.group( 2 ),
                          match.group( 3 ), match.group( 4 ) ) )
        s.data[key] = entry
    print( "dipolar data read." )

  def readChemShiftData( s, dataFile ):
    """Read chemical-shift data into s.csdata (used to pick the dipole sign).

    Format (lines starting with '#' or '!' and blank lines are skipped):

      resA atomA resB atomB resC atomC s11 s22 s33 cs [name value ...]

    Optional name/value pairs: 'id' (str); anything else (e.g. 'scale') is
    stored as a float.  Example:

      5 N 5 H 5 CA 30.0 60.0 205.0 169.4 id domain1

    Entries are keyed 'resA$A$resB$B' so cacheIndices() can attach the
    'cs' value to the matching dipolar restraint.  Exits the program with
    status 1 on a line that does not parse.
    """
    dataPattern = re.compile(
      r'\s*(\d+) \s*(\w+) \s*(\d+) \s*(\w+) \s*(\d+) \s*(\w+) '
      r'\s*(-*\d+\.*\d*) \s*(-*\d+\.*\d*) \s*(-*\d+\.*\d*) '
      r'\s*(-*\d+\.*\d*)((\s*\#.*$)|(?:\s+\S+\s+\S+))*' )
    with open( dataFile, 'r' ) as dataIn:
      for line in dataIn:
        stripped = line.strip()
        if not stripped or stripped[0] in '#!':
          continue
        match = dataPattern.match( line )
        if not match:
          print( "Error - bad line in chemical shift data file: " + line.rstrip() )
          sys.exit( 1 )
        entry = {
          'resA': match.group( 1 ),
          'A': match.group( 2 ),
          'resB': match.group( 3 ),
          'B': match.group( 4 ),
          'resC': match.group( 5 ),
          'C': match.group( 6 ),
          's11': float( match.group( 7 ) ),
          's22': float( match.group( 8 ) ),
          's33': float( match.group( 9 ) ),
          'cs': float( match.group( 10 ) ),
          'scale': s.defaultScale,
        }
        # optional fields follow the cs value (group 10)
        entry.update( s._parseOptionalFields( line[match.end( 10 ):] ) )
        key = '$'.join( ( match.group( 1 ), match.group( 2 ),
                          match.group( 3 ), match.group( 4 ) ) )
        s.csdata[key] = entry
    print( "chemshift data read." )

  def pprintData( s ):
    """Print every restraint key followed by its sorted fields."""
    for key in sorted( s.data ):
      print( key )
      entry = s.data[key]
      for name in sorted( entry ):
        print( "\t" + name + " = " + str( entry[name] ) )

  def reset( s ):
    """Drop cached atom selections; the next energy call rebuilds them."""
    s.cached = 0

  def cacheIndices( s ):
    """Build atom selections for each restraint and attach any CS values."""
    for key, entry in s.data.items():
      selString = '( resid %s and name %s)' % ( entry['resA'], entry['A'] )
      entry['selA'] = AtomSel( selString )
      selString = '( resid %s and name %s)' % ( entry['resB'], entry['B'] )
      entry['selB'] = AtomSel( selString )
      # sanity check
      if len( entry['selA'].indices() ) == 0:
        print( "Error in residue " + key + ": no A atom" )
      if len( entry['selB'].indices() ) == 0:
        print( "Error in residue " + key + ": no B atom" )

    # grab CS data if there
    for csEntry in s.csdata.values():
      dataKey = '$'.join( ( csEntry['resA'], csEntry['A'],
                            csEntry['resB'], csEntry['B'] ) )
      if dataKey in s.data:
        s.data[dataKey]['cs'] = csEntry['cs']
      else:
        print( "Warning: chemical shift for " + dataKey +
               " has no matching dipolar restraint; ignored" )

    s.cached = 1

  def _restraintArgs( s, entry ):
    """Return (obs, scale, cs, userSign) for one restraint entry.

    cs defaults to 0. (no chemical shift) and userSign to 0 (sign unknown).
    """
    return ( entry['dipole'], entry['scale'],
             entry.get( 'cs', 0. ), entry.get( 'sign', 0 ) )

  def _deviation( s, calcDipole, obs, cs, userSign ):
    """Return (delta, sgn) for a computed vs. observed dipole.

    delta is the deviation used for the energy; sgn multiplies
    dcalcDipole when differentiating delta.
    """
    if userSign != 0:
      # user override: compare signed values directly, d(delta)/d(calc) = 1
      return calcDipole - userSign * obs, 1
    if cs > 0.:
      # NOTE(review): `nmr` is not imported in this file; this raises
      # NameError unless the XPLOR-NIH environment provides it - confirm.
      delta, sgn = nmr.getCorrelatedError( calcDipole, cs, obs, s.correlate )
    else:
      # sign unknown: compare magnitudes
      delta, sgn = math.fabs( calcDipole ) - obs, 0
    if sgn == 0:
      # d|calc|/dcalc
      sgn = 1 if calcDipole > 0. else -1
    return delta, sgn

  def _energyAndSlope( s, delta ):
    """Return (E, dE/ddelta) for one restraint before weight and scale."""
    if s.energyType == s.harmonic:
      return delta * delta, 2. * delta
    if s.energyType == s.flatWell:
      # zero inside |delta| < expError, harmonic outside, on both sides
      excess = math.fabs( delta ) - s.expError
      if excess <= 0.:
        return 0., 0.
      slope = 2. * excess
      if delta < 0.:
        slope = -slope
      return excess * excess, slope
    # any other energyType: linear in delta
    return delta, 1.

  def calcDipoleEnergy( s, a, b, obs, cs, scale, userSign ):
    """Return (energy, calcDipole) for atoms at positions a and b.

    Uses the current s.nupar; energy includes s.weight and *scale*.
    """
    A = b - a
    cosTheta = vec3.dot( A, s.B ) / ( vec3.norm( A ) * vec3.norm( s.B ) )
    calcDipole = s.nupar * ( 3. * cosTheta * cosTheta - 1. ) / 2.
    delta = s._deviation( calcDipole, obs, cs, userSign )[0]
    energy = s._energyAndSlope( delta )[0]
    return ( s.weight * scale * energy, calcDipole )

  def calcEnergy( s, printOut = 0 ):
    """Return the total restraint energy (called by XPLOR-NIH).

    If *printOut* is true, print 'key calc obs calc-obs' per restraint.
    Iteration order over restraints is unspecified.
    """
    if not s.cached:
      s.cacheIndices()

    totE = 0.
    for key, entry in s.data.items():
      a = entry['selA'][0].pos()
      b = entry['selB'][0].pos()
      obs, scale, cs, userSign = s._restraintArgs( entry )
      s.setNuParallel( entry['nuParallel'] )
      E, calc = s.calcDipoleEnergy( a, b, obs, cs, scale, userSign )
      if printOut:
        print( "%s %s %s %s" % ( key, calc, obs, calc - obs ) )
      totE = totE + E
    return totE

  def calcEnergyAndDerivs( s, derivs ):
    """Accumulate gradients into *derivs* and return the total energy.

    Called by XPLOR-NIH.  The returned energy equals calcEnergy(0).
    """
    if not s.cached:
      s.cacheIndices()

    totE = 0.
    for entry in s.data.values():
      atomA = entry['selA'][0]
      atomB = entry['selB'][0]
      obs, scale, cs, userSign = s._restraintArgs( entry )
      s.setNuParallel( entry['nuParallel'] )
      da = vec3.Vec3( 0., 0., 0. )
      db = vec3.Vec3( 0., 0., 0. )
      totE = totE + s.calcDerivs( atomA.pos(), atomB.pos(), da, db,
                                  obs, scale, userSign, cs )
      aIndex = atomA.index()
      bIndex = atomB.index()
      derivs[aIndex] = derivs[aIndex] + da
      derivs[bIndex] = derivs[bIndex] + db
    return totE

  def calcDerivs( s, a, b, da, db, obs, scale, userSign, cs ):
    """Fill da/db (in place) with dE/da and dE/db; return the energy.

    With A = b - a and B the field vector,

      cos(theta)      = (A . B) / (|A| |B|)
      dcos/dA_i       = B_i / (|A||B|) - (A . B) A_i / (|A|^3 |B|)
      dcalcDipole/dcos = 3 * nupar * cos(theta)
      dE/dcos         = weight * scale * dE/ddelta * sgn * dcalcDipole/dcos

    and dE/db = +dE/dA, dE/da = -dE/dA.  Working in cos(theta) avoids
    acos() and the division by sin(theta), both singular when the bond is
    parallel to the field.  The returned energy matches calcDipoleEnergy.
    """
    B = s.B
    A = b - a
    Anorm = vec3.norm( A )
    Bnorm = vec3.norm( B )
    ABdot = vec3.dot( A, B )
    c1 = 1. / ( Anorm * Bnorm )
    c2 = c1 * ABdot / ( Anorm * Anorm )
    cosTheta = ABdot * c1
    calcDipole = s.nupar * ( 3. * cosTheta * cosTheta - 1. ) / 2.

    delta, sgn = s._deviation( calcDipole, obs, cs, userSign )
    energy, slope = s._energyAndSlope( delta )
    factor = s.weight * scale
    dEdCos = factor * slope * sgn * 3. * s.nupar * cosTheta

    for i in range( 3 ):
      # the two bond atoms move in opposite directions
      dCos = c1 * B[i] - c2 * A[i]
      db[i] = dEdCos * dCos
      da[i] = -dEdCos * dCos

    return factor * energy

if __name__ == '__main__' :
  # Command-line check: parse the dipolar data file given as the first
  # argument and print its contents.
  if len ( sys . argv ) < 2 :
    sys . stderr . write ( "usage: %s dipolarDataFile\n" % sys . argv [ 0 ] )
    sys . exit ( 1 )
  dp = DipolarPot ( "dipolarPot" )
  dp . readData ( sys . argv [ 1 ] )
  dp . pprintData ( )
  print ( "---------------> data set has been parsed." )
  
