#!  /usr/bin/env python
#
# csPot.py
#
# An XPLOR-NIH implementation of the chemical shift energy potential for
# solid state NMR as described in
#
# "Atomic Refinement Using Orientational Restraints from Solid-State NMR"
# R. Bertram, J. Quine, M. Chapman, T. Cross
# Journal of Magnetic Resonance, 147, p.9-16, 2000.
#

import sys
import trace
import re
import xplorNIH
import xplor
import math
from derivList import DerivList
from pyPot import PyPot
import vec3
from FSUVector import Vector
from FSUMatrix import Matrix
from FSUNMRBase import NMRBase
from atomSel import AtomSel
import copy

class ChemShiftPot(PyPot):

  def __init__( s, name ):
    """Create the potential term called `name` with its default settings.

    Defaults set here: field direction (0, 0, 1), harmonic energy type,
    weight and per-restraint scale of 1, principal values (35, 60, 205),
    beta of 17 degrees, alpha of 0 and an experimental error of 5 ppm.
    """
    PyPot.__init__(s, name, s)

    # Field direction; the default reference frame is the XYZ axes.  The
    # same direction is kept once as a vec3.Vec3 and once as an FSU Vector.
    s.B = vec3.Vec3(0., 0., 1.)
    s.myB = Vector(0., 0., 1.)

    # Restraint table (filled by readData) and the flag telling the energy
    # routines whether cacheIndices has been run.
    s.data = {}
    s.cached = 0
    s.sim = xplor.simulation

    # Energy function selection and weighting.
    s.harmonic = 1
    s.flatWell = 2
    s.energyType = s.harmonic
    s.defaultScale = 1.
    s.weight = 1.
    s.expError = 5.0  # ppm

    # Angles (radians) read by cacheData; they must exist before the
    # setter calls at the end of this method.
    s.alphaRad = 0.
    s.betaRad = 0.

    # Work matrices rebuilt by cacheData, and the NMR helper object.
    s.D = Matrix()
    s.M = Matrix()
    s.N = Matrix()
    s.Sbar = Matrix()
    s.nmr = NMRBase()

    # Both setters call cacheData, so they come last.
    s.setSigmaPrincipalValues(35., 60., 205.)
    s.setBetaRad(17. * math.pi / 180.)
      
  def setSigmaPrincipalValues ( s, sigma11, sigma22, sigma33 ) :
    s . sigma11 = sigma11
    s . sigma22 = sigma22
    s . sigma33 = sigma33
    s . nmr . sigma11 = s . sigma11
    s . nmr . sigma22 = s . sigma22
    s . nmr . sigma33 = s . sigma33
    s . cacheData ( ) 

  def setBetaRad ( s, betaRad ) :
    s . betaRad = betaRad
    s . nmr . NHbetaRad = s . betaRad
    s . cacheData ( ) 

  def cacheData ( s ):
    """Rebuild s.Sbar from the current angles and principal values.

    Reads s.betaRad, s.alphaRad and s.sigma11/22/33.  Overwrites s.M,
    s.N, the diagonal of s.D, and s.Sbar.  The product is accumulated by
    successive mult calls in the order M, N, D, N-transposed,
    M-transposed (assuming X.mult(Y) returns the product X*Y as a new
    matrix -- TODO confirm against FSUMatrix).
    """
    # Rotation matrices are reset from the negated angles (presumably
    # rotations about Z and X respectively -- confirm in FSUMatrix).
    s . M . setRotZ ( -s . betaRad )
    s . N . setRotX ( -s . alphaRad )
    # Only the diagonal of D is written here.  Note the placement:
    # sigma11 -> (1,1), sigma22 -> (2,2), sigma33 -> (0,0).
    # NOTE(review): presumably this matches the axis convention of the
    # frame used elsewhere -- confirm.
    s . D . setItem ( 1, 1, s . sigma11 )
    s . D . setItem ( 2, 2, s . sigma22 )
    s . D . setItem ( 0, 0, s . sigma33 )
    # The identity set here is replaced by the assignment on the next line.
    s . Sbar . setIdentity ( )
    s . Sbar = s . M . mult ( s . N )
    s . Sbar = s . Sbar . mult ( s . D )
    # M and N are left transposed when this method returns; the setRot
    # calls at the top reset them on the next invocation.
    s . M . transpose ( ) # in place
    s . N . transpose ( ) # in place
    s . Sbar = s . Sbar . mult ( s . N )
    s . Sbar = s . Sbar . mult ( s . M )

  def setWeighting ( s, w ) :
    s . weight = w

  def readData ( s, dataFile ) :
    # Format of data file  [ #,! are comments ]
    # 
    # resNum atomA resNum atomB resNum atomC sigma11 sigma22 sigma33 csData  [OPTIONAL fields] 
    # OPTIONAL fields - assign any or all:
    #                        scale %f
    #                        id %s
    # e.g.:
    # 5 N 5 H 5 CA 30.0 60.0 205.0 169.4 id domain1

    dataPattern = re . compile \
    ('\s*(\d+) \s*(\w+) \s*(\d+) \s*(\w+) \s*(\d+) \s*(\w+) '\
     '\s*(-*\d+\.*\d*) \s*(-*\d+\.*\d*) \s*(-*\d+\.*\d*) '\
     '\s*(-*\d+\.*\d*)((\s*\#.*$)|(?:\s+\S+\s+\S+))*')
    #  resA     atomA    resB     atomB   resC      atomC          cs
    #   #comment | name value pairs
    file = open ( dataFile, 'r' )
    for line in file . readlines ( ) :

      # check for comment
      i = 0 
      while line [ i ] == ' ':
        i = i + 1
      if line [ i ] == '\n' or line [ i ] == '#' or line [ i ] == '!' :
        continue
      
      match = dataPattern . search ( line, i )

      if match :

        valueString = match . group ( 5 )
        vp = valueString . split ( ' ' )
        valuePairs = [] # make clean copy
        comment = 0
        for v in vp :
          if v == '#' :
            # rest of line is a comment
            comment = 1
          if not comment and len ( v ) :
            valuePairs . append ( v ) 
        # build data struct
        dataDict = {}
        dataDict [ 'resA' ] = match . group ( 1 )
        dataDict [ 'A' ] = match . group ( 2 )
        dataDict [ 'resB' ] = match . group ( 3 ) 
        dataDict [ 'B' ] = match . group ( 4 )
        dataDict [ 'resC' ] = match . group ( 5 )
        dataDict [ 'C' ] = match . group ( 6 )
        dataDict [ 'scale' ] = s . defaultScale
        dataDict [ 'sigma11' ] = float( match . group ( 7 ) )
        dataDict [ 'sigma22' ] = float( match . group ( 8 ) )
        dataDict [ 'sigma33' ] = float( match . group ( 9 ) )
        dataDict [ 'cs' ] = float ( match . group ( 10 ) ) 
        for i in range ( len ( valuePairs ) / 2 ) :
          if valuePairs [ 2*i ] == 'id' :
            dataDict [ 'id' ] = valuePairs [ 2*i + 1 ]
          else:
            # default is to cast value as float
            dataDict [ valuePairs [2*i] ] = float ( valuePairs [ 2*i + 1 ] )

        uniqueKey = dataDict [ 'resA' ] + '$' + dataDict ['A'] + '$' + \
                    dataDict [ 'resB' ] + '$' + dataDict ['B'] + '$' + \
                    dataDict [ 'resC' ] + '$' + dataDict ['C']
        s . data [ uniqueKey ] = dataDict
      else:
        print "Error - bad line in data file: ", line
        sys . exit ( 0 )
    print "chemshift data read."
          

  def pprintData (s):
    sorted_keys = s . data . keys ()
    sorted_keys . sort ()
    for key in sorted_keys:
      print key
      sk = s . data [ key ] . keys()
      sk . sort ( )
      for k in range ( len ( sk ) ) :
        print "\t" + sk [ k ] + " = " + str ( s . data [ key ] [ sk [ k ] ] )

  def reset ( s ):
    s . cachedIndices = 0

  def cacheIndices ( s ) :
    # run through data and store atoms of interest in for faster access
    alist = blist = clist = []
    for k in s . data . keys ( ) :
      resA = s . data [ k ][ 'resA' ]
      resB = s . data [ k ][ 'resB' ]
      resC = s . data [ k ][ 'resC' ]
      selString = '( resid %s and name %s)' % ( resA, s . data [k][ 'A' ] )
      s . data [k][ 'selA' ] = AtomSel ( selString ) 
      selString = '( resid %s and name %s)' % ( resB, s . data [k][ 'B' ] )
      s . data [k][ 'selB' ] = AtomSel ( selString )
      selString = '( resid %s and name %s)' % ( resC, s . data [k][ 'C' ] )
      s . data [k][ 'selC' ] = AtomSel ( selString ) 
      # sanity check
      if len ( s . data [k]['selA'].indices()) == 0:
        print "Error in residue " + k+ ": no A atom"
      # sanity check
      if len ( s . data [k]['selB'].indices()) == 0:
        print "Error in residue " + k +": no B atom"
      # sanity check
      if len ( s . data [k]['selA'].indices()) == 0:
        print "Error in residue " + k +": no C atom"
    s . cached = 1

  def calcChemShiftEnergy ( s, a, b, c, obs, scale ) :
    """Return (energy, calculated shift) for one restraint.

    a, b, c -- positions of the three atoms, indexable by 0, 1, 2
    obs     -- observed chemical shift
    scale   -- per-restraint scale factor

    The calculated shift comes from s.nmr (computePAF followed by
    computeChemShift with the field direction s.myB).  With
    delta = calculated - obs the unweighted energy is
      harmonic : delta**2
      flatWell : 0 while |delta| <= s.expError, else (|delta| - s.expError)**2
      other    : delta
    and the returned energy is s.weight * scale * that value.
    """
    # convert to my data structures
    A = Vector ( a[0], a[1], a[2] )
    B = Vector ( b[0], b[1], b[2] )
    C = Vector ( c[0], c[1], c[2] )
    PAF = s . nmr . computePAF ( A, B, C, s . betaRad, s . alphaRad )
    calcsscs = s . nmr . computeChemShift ( PAF, s . myB )

    delta = calcsscs - obs
    if s . energyType == s . harmonic :
      energy = delta * delta
    elif s . energyType == s . flatWell :
      # The well is symmetric about the observed value: a calculated
      # shift far below obs is penalised just like one far above it.
      excess = abs ( delta ) - s . expError
      if excess > 0. :
        energy = excess * excess
      else :
        energy = 0.
    else :
      energy = delta
    return ( s . weight * scale * energy, calcsscs )
    

  def calcEnergy ( s, printOut = 0 ) :

    #
    # called by XPLOR-NIH
    #

    if not s . cached:
      s . cacheIndices ( )

    totE = 0.
    # run through all residues with experimental data
    for k in s . data. keys ( ) :
      # Note: list is not sorted ( doesn't really matter )
      a = s . data[k]['selA'][0] . pos()
      b = s . data[k]['selB'][0] . pos()
      c = s . data[k]['selC'][0] . pos()
      obs = s . data[k]['cs']
      scale = s . data[k]['scale']
      s11 = s . data[k]['sigma11']
      s22 = s . data[k]['sigma22']
      s33 = s . data[k]['sigma33']
      s . setSigmaPrincipalValues( s11, s22, s33 )
      E, calc = s . calcChemShiftEnergy ( a, b, c, obs, scale )
      if printOut :
        print s . data[k]['resA'], \
              s . data[k]['A'], \
              s . data[k]['B'], \
              s . data[k]['C'], \
              calc, obs, calc - obs
      totE = totE + E
    return totE
      

  def calcEnergyAndDerivs(s,derivs):

    #
    # called by XPLOR-NIH
    #

    if not s . cached:
      s . cacheIndices ( )

    # loop over data points to compute energy
    for k in s . data. keys ( ) :   # Note: list is not sorted ( doesn't really matter )
        a = s . data[k]['selA'][0] . pos()
        b = s . data[k]['selB'][0] . pos()
        c = s . data[k]['selC'][0] . pos()
        obs = s . data[k]['cs']
        scale = s . data[k]['scale']
        s11 = s . data[k]['sigma11']
        s22 = s . data[k]['sigma22']
        s33 = s . data[k]['sigma33']
        s . setSigmaPrincipalValues( s11, s22, s33 )
        aIndex = s . data[k]['selA'][0].index()
        bIndex = s . data[k]['selB'][0].index()
        cIndex = s . data[k]['selC'][0].index()
        
        da = derivs[aIndex]
        db = derivs[bIndex]
        dc = derivs[cIndex]
        s . calcDerivs ( a, b, c, da, db, dc, obs, scale )
        derivs[aIndex] += s.weight*da
        derivs[bIndex] += s.weight*db
        derivs[cIndex] += s.weight*dc 

    return s. calcEnergy()
  

  def calcDerivs ( s, a, b, c, da, db, dc, obs, scale ) :
    """Add the gradient of one restraint to da, db and dc in place.

    a, b, c    -- positions of atoms A, B, C, indexable by 0, 1, 2
    da, db, dc -- per-atom gradient vectors; components 0..2 are
                  incremented, nothing is returned
    obs        -- observed chemical shift
    scale      -- per-restraint scale factor

    The increment is 2 * scale * (calculated - obs) * d(sigma)/d(atom).
    s.weight is not applied here.
    NOTE(review): this is the gradient of the harmonic energy; s.energyType
    is not consulted, so it does not match the flat-well energy.
    """

    #
    # for a complete explanation see section 3.2.4 of reference paper
    #

    # convert to my data structures
    A = Vector ( a[0], a[1], a[2] )
    B = Vector ( b[0], b[1], b[2] )
    C = Vector ( c[0], c[1], c[2] )

    # Bond vectors out of atom A (assuming X.subtract(Y) yields X - Y --
    # TODO confirm in FSUVector); u2 is a normalized copy of v2.
    v1 = B . subtract ( A )
    v2 = C . subtract ( A )
    u2 = Vector ( v2[0], v2[1], v2[2] )
    u2 . normalize ( )

    # Frame matrix from the NMR helper; its three columns are used as
    # u1, n and binorm (presumably u1 = v1/|v1| -- confirm in computeF).
    F = s . nmr . computeF ( A, B, C )
    u1 = F . getCol ( 0 )
    n  = F . getCol ( 1 )
    binorm  = F . getCol ( 2 ) 

    # Derivatives of u1.    
    # Entries have the form delta_ij/|v1| - v1_i*v1_j/|v1|^3, i.e. the
    # Jacobian of v1/|v1|.  The suffix names the coordinate (x, y, z) and
    # the atom (a, b, c) being differentiated against.
    lengthv1 = v1 . length ( )
    t1=1.0 /   lengthv1
    t2=1.0 / ( lengthv1*lengthv1*lengthv1 )

    du1xb = Vector ( t1-t2*v1[0]*v1[0],   -t2*v1[1]*v1[0],   -t2*v1[2]*v1[0] )
    du1yb = Vector (   -t2*v1[0]*v1[1], t1-t2*v1[1]*v1[1],   -t2*v1[2]*v1[1] )
    du1zb = Vector (   -t2*v1[0]*v1[2],   -t2*v1[1]*v1[2], t1-t2*v1[2]*v1[2] )
    # The A derivatives are the negated B derivatives.
    du1xa = Vector ( du1xb[0], du1xb[1], du1xb[2] )
    du1ya = Vector ( du1yb[0], du1yb[1], du1yb[2] )
    du1za = Vector ( du1zb[0], du1zb[1], du1zb[2] )
    du1xa . scale ( -1. )
    du1ya . scale ( -1. )
    du1za . scale ( -1. )
    # The C derivatives of u1 are zero.
    du1xc = Vector ( 0., 0., 0. )
    du1yc = Vector ( 0., 0., 0. )
    du1zc = Vector ( 0., 0., 0. )

    # Derivatives of u2 (unit vector).
    # Same construction from v2: nonzero for C, negated for A, zero for B.
    lengthv2 = v2 . length ( ) 
    t1=1.0 /   lengthv2
    t2=1.0 / ( lengthv2*lengthv2*lengthv2)
    
    du2xc = Vector ( t1-t2*v2[0]*v2[0],   -t2*v2[1]*v2[0],   -t2*v2[2]*v2[0] )
    du2yc = Vector (   -t2*v2[0]*v2[1], t1-t2*v2[1]*v2[1],   -t2*v2[2]*v2[1] )
    du2zc = Vector (   -t2*v2[0]*v2[2],   -t2*v2[1]*v2[2], t1-t2*v2[2]*v2[2] )
    du2xa = Vector ( du2xc[0], du2xc[1], du2xc[2] )
    du2ya = Vector ( du2yc[0], du2yc[1], du2yc[2] )
    du2za = Vector ( du2zc[0], du2zc[1], du2zc[2] )
    du2xa . scale ( -1. )
    du2ya . scale ( -1. )
    du2za . scale ( -1. )
    du2xb = Vector ( 0.0, 0.0, 0.0 )
    du2yb = Vector ( 0.0, 0.0, 0.0 )
    du2zb = Vector ( 0.0, 0.0, 0.0 )

    #
    # Unique elements of Skew Symmetric Matrix S
    #
    # In P, Q and R the component index selects the atom: 0 -> A, 1 -> B,
    # 2 -> C.  For P and Q the B entry is the negated A entry and the C
    # entry is left at zero.
    Px = Vector ( 0., 0., 0. )
    Px[0] = u1 . dot ( binorm . cross ( du1xa ) )
    Px[1] = -Px[0]
    Py = Vector ( 0., 0., 0. )
    Py[0] = u1 . dot ( binorm . cross ( du1ya ) )
    Py[1] = -Py[0]
    Pz = Vector ( 0., 0., 0. )
    Pz[0] = u1 . dot ( binorm . cross ( du1za ) )
    Pz[1] = -Pz[0]

    # t1 is reused from here on as 1 / |u1 x u2|.
    u1u2cross = u1 . cross ( u2 )
    t1 = 1.0 / u1u2cross . length ( ) 
    
    du1xacrossu2 = du1xa . cross ( u2 )
    du1xbcrossu2 = du1xb . cross ( u2 )
    du1xccrossu2 = du1xc . cross ( u2 )
    du1yacrossu2 = du1ya . cross ( u2 )
    du1ybcrossu2 = du1yb . cross ( u2 )
    du1yccrossu2 = du1yc . cross ( u2 )
    du1zacrossu2 = du1za . cross ( u2 )
    du1zbcrossu2 = du1zb . cross ( u2 )
    du1zccrossu2 = du1zc . cross ( u2 )

    Qx = Vector ( 0., 0., 0. )
    Qx[0] = t1 * ( u1 . dot ( du1xacrossu2 ) )
    Qx[1] = -Qx[0]
    Qy = Vector ( 0., 0., 0. )
    Qy[0] = t1 * ( u1 . dot ( du1yacrossu2 ) )
    Qy[1] = -Qy[0]
    Qz = Vector ( 0., 0., 0. )
    Qz[0] = t1 * ( u1 . dot ( du1zacrossu2 ) )
    Qz[1] = -Qz[0]

    # R has an explicit entry for each of the three atoms.
    Rx = Vector ( 0., 0., 0. )
    Rx[0] = t1 * n . dot ( u1 . cross ( du2xa ) . add ( du1xacrossu2 ) )
    Rx[1] = t1 * n . dot ( u1 . cross ( du2xb ) . add ( du1xbcrossu2 ) )
    Rx[2] = t1 * n . dot ( u1 . cross ( du2xc ) . add ( du1xccrossu2 ) )
    Ry = Vector ( 0., 0., 0. )
    Ry[0] = t1 * n . dot ( u1 . cross ( du2ya ) . add ( du1yacrossu2 ) )
    Ry[1] = t1 * n . dot ( u1 . cross ( du2yb ) . add ( du1ybcrossu2 ) )
    Ry[2] = t1 * n . dot ( u1 . cross ( du2yc ) . add ( du1yccrossu2 ) )
    Rz = Vector ( 0., 0., 0. )
    Rz[0] = t1 * n . dot ( u1 . cross ( du2za ) . add ( du1zacrossu2 ) )
    Rz[1] = t1 * n . dot ( u1 . cross ( du2zb ) . add ( du1zbcrossu2 ) )
    Rz[2] = t1 * n . dot ( u1 . cross ( du2zc ) . add ( du1zccrossu2 ) )

    # NOTE(review): Ft is only an alias of F; no transpose is taken here.
    # Whether postMultRowVec supplies the transpose implied by the name is
    # not visible from this file -- confirm in FSUMatrix.
    Ft = F
    FtB = Ft . postMultRowVec ( s . myB )
    SFtB = s . Sbar . postMultRowVec ( FtB )

    T = Vector ( 0., 0., 0. )
    dsigmaa = Vector ( 0., 0., 0. )
    dsigmab = Vector ( 0., 0., 0. )
    dsigmac = Vector ( 0., 0., 0. )
    
    # Three passes, one per coordinate (x, y, z).  In each pass i runs
    # over the atoms (0 -> A, 1 -> B, 2 -> C); T is built from the P, Q, R
    # entries for that atom and 2 * FtB . T is stored as that atom's
    # derivative component.
    for i in range ( 3 ) :
      T[0] =  Px[i] * SFtB [1] + Qx[i] * SFtB [2]
      T[1] = -Px[i] * SFtB [0] + Rx[i] * SFtB [2]
      T[2] = -Qx[i] * SFtB [0] - Rx[i] * SFtB [1]
      if i == 0:
        dsigmaa [0] = 2. * FtB . dot ( T )
      if i == 1 :
        dsigmab [0] = 2. * FtB . dot ( T )
      if i == 2 :
        dsigmac [0] = 2. * FtB . dot ( T )
    for i in range ( 3 ) :
      T[0] =  Py[i] * SFtB [1] + Qy[i] * SFtB [2]
      T[1] = -Py[i] * SFtB [0] + Ry[i] * SFtB [2]
      T[2] = -Qy[i] * SFtB [0] - Ry[i] * SFtB [1]
      if i == 0:
        dsigmaa [1] = 2. * FtB . dot ( T )
      if i == 1 :
        dsigmab [1] = 2. * FtB . dot ( T )
      if i == 2 :
        dsigmac [1] = 2. * FtB . dot ( T )
    for i in range ( 3 ) :
      T[0] =  Pz[i] * SFtB [1] + Qz[i] * SFtB [2]
      T[1] = -Pz[i] * SFtB [0] + Rz[i] * SFtB [2]
      T[2] = -Qz[i] * SFtB [0] - Rz[i] * SFtB [1]
      if i == 0:
        dsigmaa [2] = 2. * FtB . dot ( T )
      if i == 1 :
        dsigmab [2] = 2. * FtB . dot ( T )
      if i == 2 :
        dsigmac [2] = 2. * FtB . dot ( T )


    # Chain rule for scale * (calc - obs)**2; only the calculated shift
    # returned by calcChemShiftEnergy is used, its energy is discarded.
    dummy, calcsscs = s . calcChemShiftEnergy ( a, b, c, obs, scale )
    diff=2.0*scale*(calcsscs-obs)
    da[0] = da[0] + diff*dsigmaa[0]
    db[0] = db[0] + diff*dsigmab[0]
    dc[0] = dc[0] + diff*dsigmac[0] 
    da[1] = da[1] + diff*dsigmaa[1]
    db[1] = db[1] + diff*dsigmab[1]
    dc[1] = dc[1] + diff*dsigmac[1]
    da[2] = da[2] + diff*dsigmaa[2]
    db[2] = db[2] + diff*dsigmab[2]
    dc[2] = dc[2] + diff*dsigmac[2]

if __name__ == '__main__' :

  # Stand-alone check of the parser: read the restraint file named on
  # the command line and echo what was stored.
  if len ( sys . argv ) < 2 :
    # Without this guard a missing argument surfaces as a bare IndexError.
    sys . stderr . write ( "usage: " + sys . argv[0] + " dataFile\n" )
    sys . exit ( 1 )

  cs = ChemShiftPot ( "chemShiftPot" )
  cs . readData ( sys . argv[1] )
  cs . pprintData ( )
  print ( "---------------> data set has been parsed." )
  
