#! /usr/bin/env python
#
#
#  Builds most alpha-helical structure
#  possible from a given assignment
#
#

import copy
import getopt
import math
import os
import re
import sys

from FSUPeptidePlane import PeptidePlane
from FSUPeptidePlane import Propagator
from FSUNMRBase import NMRBase
from FSUPDB import *
from FSUUtils import Utils


class StructFit ( object ) :

    #
    # Class which encapsulates the control algorithm - calling the
    # propagator as necessary:
    #   readData()                - parse an assigned PISEMA data file
    #   computeTorsionAngleSets() - candidate (phi, psi) pairs for every
    #                               consecutive residue pair, each scored by
    #                               its rounded distance from (PHI, PSI)
    #   cherryPick()              - greedy chain satisfying continuity
    #   partitionTrees()          - exhaustive search, partition by
    #                               partition, for a lower-deviation chain
    #   populateChain()           - build full residues on the backbone
    #
    # A torsion-angle entry ("ta") is a tuple (dev, phi, psi, epsilons).
    #

    def __init__ ( s, prop, verbose = 0 ) :
        # prop    : Propagator used for peptide-plane geometry; the NMR
        #           constants are pushed to prop.nmr (see setGlycine)
        # verbose : print progress messages, including while loading the
        #           amino-acid structure db
        s.p = prop
        # must be set before createAminoAcidDB(), which consults it
        s.verbose = verbose
        installDir = os.environ.get("PIPATH_INSTALL_DIR")
        if installDir is None:
            sys.stderr.write("Error: PIPATH_INSTALL_DIR is not set\n")
            sys.exit(-1)
        s.installDir = installDir
        s.aaStructureFile = s.installDir + '/lib/FSUAminoAcids.pdb'
        # one-letter -> three-letter residue codes
        s.aaDict = { 'A':'ALA', 'R':'ARG', 'N':'ASN', 'D':'ASP', 'C':'CYS',
                     'Q':'GLN', 'E':'GLU', 'G':'GLY', 'H':'HIS', 'I':'ILE',
                     'L':'LEU', 'K':'LYS', 'M':'MET', 'F':'PHE', 'P':'PRO',
                     'S':'SER', 'T':'THR', 'W':'TRP', 'Y':'TYR', 'V':'VAL' }
        s.aaDB = s.createAminoAcidDB()
        s.seq = ""
        s.assign = []
        s.minScore = 99999.
        s.spectrum = []
        # magnetic Field in Lab Frame
        s.B = Vector(0., 0., 1)
        s.ERR_MATCH = 0.01
        s.NUM_ERR = 0.00001
        s.outputPDB = ""
        s.dataFile = ""
        s.correlate = 2 # default
        s.beta = 17.0
        s.nuParallel = 11.335
        s.sigma11 = 30.
        s.sigma22 = 60.
        s.sigma33 = 205.
        s.GLYsigma11 = 28.
        s.GLYsigma22 = 49.
        s.GLYsigma33 = 197.
        s.setGlycine(0)
        s.handleGLY = 0
        # ideal alpha-helix torsions every candidate is scored against
        s.PHI = -65.
        s.PSI = -40.
        s.TAS = []
        s.PTAS = []
        s.greedyTorsions = []
        s.bestTorsions = []
        s.devTot = 0
        s.maxTreeDev = 0
        s.maxTreeChain = []

    def verifySeq ( s, seq ) :
        # Return 1 if every character of seq (case-insensitive) is a known
        # one-letter residue code, 0 otherwise.
        for aa in seq:
            if aa.upper() not in s.aaDict:
                return 0
        return 1

    def dumpData ( s ) :
        #
        # Verbose dump of the parsed input: one "<3-letter name><index>" line
        # per residue, then "index shift<TAB>coupling" per data point, with
        # the sign from nmr.determineDipoleSign appended when both values
        # are non-zero.
        #
        out = sys.stdout
        for i in range(len(s.seq)):
            out.write(s.aaDict[s.seq[i].upper()] + str(i) + '\n')
        out.write("DATA:\n")
        # iterate the data points themselves, not the sequence string
        for i in range(len(s.spectrum)):
            cs = s.spectrum[i][0]
            dipole = s.spectrum[i][1]
            dString = str(i) + ' ' + str(cs) + "\t" + str(dipole)
            if cs != 0. and dipole != 0.:
                sgndip = s.p.nmr.determineDipoleSign(dipole, cs, s.correlate)
                out.write(dString + " sgn: " + str(sgndip) + '\n')
            else:
                out.write(dString + '\n')
        out.write('\n')

    def setGlycine ( s, setg ) :
        # Push the NMR constants to the NMR module.  GLY has a unique
        # chemical shift env, so a true setg selects the GLY sigma tensor;
        # beta and nuParallel are shared by both.
        nmr = s.p.nmr
        nmr.betaRad = nmr.toRad(s.beta)
        nmr.nuParallel = s.nuParallel
        if setg:
            nmr.sigma11 = s.GLYsigma11
            nmr.sigma22 = s.GLYsigma22
            nmr.sigma33 = s.GLYsigma33
        else:
            nmr.sigma11 = s.sigma11
            nmr.sigma22 = s.sigma22
            nmr.sigma33 = s.sigma33

    def setNMRParameters ( s, beta, nuParallel, sigma11, sigma22, sigma33 ) :
        # Store the generic (non-GLY) NMR constants and push them to the
        # NMR module.
        s.beta = beta
        s.nuParallel = nuParallel
        s.sigma11 = sigma11
        s.sigma22 = sigma22
        s.sigma33 = sigma33
        s.setGlycine(0)

    def readData ( s, file ) :
        #
        # Parse an assigned PISEMA data file, one residue per line:
        #     residueName chemShift dipCoupling [+|-]
        # Blank lines are skipped.  A residue whose shift and coupling are
        # both 0.0 is a wildcard and gets a zero placeholder vector; for any
        # other residue all degenerate orientations of B in the PAF are
        # computed once here (they never change), using the optional dipole
        # sign when given.
        #
        # Appends [cs, dipole, ipos, sgn, eps] to s.spectrum per residue and
        # sets s.seq / s.assign.  Exits (-1) on a malformed line or an
        # unknown residue code.
        #
        seqPattern = re.compile(r'(\S+)\s+(\S+)\s+(\S+)(\s+[+-])*')
        seq = ''
        lineno = 0
        ofile = open(file, 'r')
        try:
            for line in ofile.readlines():
                lineno += 1
                if not line.strip():
                    continue
                match = seqPattern.search(line)
                values = None
                if match:
                    try:
                        values = (float(match.group(2)), float(match.group(3)))
                    except ValueError:
                        values = None
                if values is None:
                    sys.stdout.write("Error parsing dataFile :\n")
                    sys.stdout.write(line + '\n')
                    sys.stdout.write("    see usage '-h'\n")
                    sys.exit(-1)
                cs, dipole = values

                residue = match.group(1)
                if not s.verifySeq(residue):
                    sys.stdout.write("Line  %d\n" % lineno)
                    sys.stdout.write("Error : invalid AA sequence specified in data file\n")
                    sys.exit(-1)
                seq += residue

                if s.handleGLY:
                    s.setGlycine(seq[-1].upper() == 'G')

                sgn = 0
                eps = []
                if cs == 0.0 and dipole == 0.0:
                    ipos = [Vector(0.0, 0.0, 0.0)]
                else:
                    # optional 4th field is the dipole sign; group 3 is the
                    # coupling value itself
                    sign = (match.group(4) or '').strip()
                    if sign == '+':
                        sgn = 1
                    elif sign == '-':
                        sgn = -1
                    # have to get ALL possible degeneracies here
                    ipos, eps = s.p.nmr.getBinPAF(cs, dipole, sgn)
                s.spectrum.append([cs, dipole, ipos, sgn, eps])
        finally:
            ofile.close()

        s.seq = seq
        s.assign = list(range(len(s.seq)))
        if s.verbose:
            sys.stdout.write(s.seq + '\n')
            for a in s.spectrum:
                sys.stdout.write(str(a) + '\n')

    def epsilonCheck( s, aEp, bEp, firstConditionOnly = 0, bypass = 0 ):
        #
        # Continuity test between the epsilon lists of consecutive torsion
        # entries aEp (position i-1) and bEp (position i).  Returns 1 when
        # compatible, 0 otherwise.  An epsilon list whose first element is 0
        # is a wildcard and matches anything; bypass forces a match.
        #
        if bypass:
            return 1
        if aEp[0] == 0 or bEp[0] == 0:
            return 1
        # e1p(i-1) = e1(i)
        if aEp[1] != bEp[0]:
            return 0
        # e2p(i-1) = -e2(i): equal values violate it
        if not firstConditionOnly and aEp[4] == bEp[3]:
            return 0
        return 1

    def continuityCondition ( s, chainData ):
        # Return 1 if every consecutive pair of [phi, psi, epsilons] entries
        # in chainData passes epsilonCheck, else 0.
        for prev, cur in zip(chainData[:-1], chainData[1:]):
            if not s.epsilonCheck(prev[2], cur[2]):
                return 0
        return 1

    def computeAlphaHelixDeviation ( s, phi, psi ) :
        # Euclidean distance (degrees) of (phi, psi) from (s.PHI, s.PSI),
        # each angular difference wrapped into [0, 180].
        dphi = math.fabs(phi - s.PHI)
        dpsi = math.fabs(psi - s.PSI)
        if dphi > 180.:
            dphi = 360. - dphi
        if dpsi > 180.:
            dpsi = 360. - dpsi
        return math.sqrt(dphi * dphi + dpsi * dpsi)

    def computeDeviation( s, chainDat ) :
        # Print and return the summed alpha-helical deviation of chainDat, a
        # sequence of (phi, psi, epsilons) entries.
        aDev = 0
        for entry in chainDat:
            aDev += s.computeAlphaHelixDeviation(entry[0], entry[1])
        sys.stdout.write("Alpha-Helical Deviation :  %d\n" % int(aDev + 0.5))
        return aDev

    def createAminoAcidDB ( s ) :
        #
        # Read all 20 amino-acid residue structures from the simple PDB file
        # s.aaStructureFile and return {one-letter code: residue}.  Exits
        # (-1) if the file does not hold exactly one residue per code.
        #
        if s.verbose:
            sys.stdout.write('loading AA db..\n')
        aas = Structure(s.aaStructureFile)
        if len(aas.residues) != len(s.aaDict):
            sys.stdout.write("Error: amino acid db mismatch\n")
            sys.exit(-1)

        byName = {}
        for res in aas.residues:
            byName.setdefault(res.name, res)

        db = {}
        for code, name in s.aaDict.items():
            if name not in byName:
                sys.stdout.write("Error:  " + code + "  not found in db\n")
                sys.exit(-1)
            db[code] = byName[name]
        if s.verbose:
            sys.stdout.write("db verified.\n")
        return db

    def populateChain ( s, chain, seq, buffer = 1 ) :
        #
        # Build full residues (side chains included) on the backbone chain
        # (a list of peptide planes) and return them as a new PeptideChain.
        #
        # chain  : peptide planes; residue i is built between planes i, i+1
        # seq    : residue codes, one character per residue built
        # buffer : trailing planes (after the temporary dummy plane is
        #          appended) that do not start a residue
        #
        # A dummy plane with phi = -65, psi = -40 is appended temporarily so
        # the last residue has a following plane for its carbonyl C and O;
        # it is removed again before returning, even on error.
        #
        chain.append(s.p.propagatePeptidePlane(chain[-1], -65., -40.))
        try:
            protein = PeptideChain(None, '0', '0')

            for i in range(len(chain) - buffer):
                pp = chain[i]
                npp = chain[i + 1]

                newResidue = copy.deepcopy(s.aaDB[seq[i].upper()])
                newResidue.number = str(i + 1)

                # orient the template: origin at this plane's CA,
                # A = -U3 of this plane, B = U1 of the next plane
                origin = pp.atoms['CA']
                A = pp.getU3()
                A.normalize()
                A.scale(-1.)
                A.cleanup()
                B = npp.getU1()
                B.normalize()
                B.cleanup()
                s.p.alignResidueWithVectors(newResidue, origin, A, B)

                # backbone atoms are taken directly from the peptide planes
                for atom in newResidue:
                    if atom.isHydrogen():
                        atom.position = pp.atoms['H']
                    if atom.isCarbonylCarbon():
                        atom.position = npp.atoms['C']
                    if atom.isAlphaCarbon():
                        atom.position = pp.atoms['CA']
                    if atom.isNitrogen():
                        atom.position = pp.atoms['N']
                    if atom.isOxygen():
                        atom.position = npp.atoms['O']

                protein.addResidue(newResidue)
        finally:
            del chain[-1]
        return protein

    def computeTorsionAngleSets ( s, offset ) :
        #
        # Rebuild s.TAS: for every consecutive residue pair (i, i+1) of
        # s.spectrum, a list of candidate entries (dev, phi, psi, epsilons)
        # sorted ascending, where dev is the rounded deviation from
        # (s.PHI, s.PSI).
        #
        # If either residue is a wildcard (shift and coupling both 0.0) the
        # pair gets the single entry (1, PHI, PSI, [0, 0, 0, 0, 0]).
        # Otherwise candidates come from computeTorsionsByChiralities for
        # every combination of the B orientations at indices offset,
        # offset+4, offset+8, ... of the two residues.
        #
        del s.TAS[:]
        for i in range(len(s.spectrum) - 1):
            cur = s.spectrum[i]
            nxt = s.spectrum[i + 1]

            if ((cur[0] == 0. and cur[1] == 0.) or
                (nxt[0] == 0. and nxt[1] == 0.)):
                # wild card - always match
                tas = [(1, s.PHI, s.PSI, [0, 0, 0, 0, 0])]
            else:
                tas = []
                B1s = cur[2]
                B2s = nxt[2]
                for bi in range(offset, len(B1s), 4):
                    for bj in range(offset, len(B2s), 4):
                        tors, exact, grams, eps = \
                              s.p.computeTorsionsByChiralities \
                              (B1s[bi], B2s[bj], s.p.nmr.betaRad, 0, 1)
                        for j in range(len(tors)):
                            phi, psi = tors[j]
                            dev = int(s.computeAlphaHelixDeviation(phi, psi) + 0.5)
                            tas.append((dev, phi, psi, eps[j]))

            tas.sort()
            s.TAS.append(tas)

    def toRad( s, deg ) :
        # Degrees -> radians.
        return deg * math.pi / 180.

    def taDev ( s, tas ):
        # Total deviation of a list of torsion entries.
        return sum([ta[0] for ta in tas])

    def maxtaDev ( s, tas ):
        # Largest deviation in a list of torsion entries (0 if empty).
        return max([0] + [ta[0] for ta in tas])

    def cherryPick ( s ) :
        #
        # Greedy chain: at each position of s.TAS take the lowest-deviation
        # entry that satisfies epsilonCheck against the previous pick.
        # Stores it in s.greedyTorsions and a copy in s.bestTorsions.  Exits
        # (-1) if some position has no compatible entry.
        #
        del s.greedyTorsions[:]
        lastEpsilons = None
        for i in range(len(s.TAS)):
            pick = None
            for ta in s.TAS[i]:
                if i == 0 or s.epsilonCheck(lastEpsilons, ta[3]):
                    pick = ta
                    break
            if pick is None:
                # later positions cannot be chained correctly any more
                break
            s.greedyTorsions.append(pick)
            lastEpsilons = pick[3]

        if len(s.greedyTorsions) != len(s.TAS):
            sys.stdout.write("Warning!!! Error!! Continuity Condition \n")
            sys.stdout.write(" can't be satisfied for this assignment!!\n")
            sys.exit(-1)

        if s.verbose:
            sStr = "Greedy STRUCT %s " % s.taDev(s.greedyTorsions)
            sStr += "".join([str(ta[0]) + " " for ta in s.greedyTorsions])
            sys.stdout.write(sStr + "\n")

        s.bestTorsions = list(s.greedyTorsions)

    def pruneTorsions ( s, tlist, bound ):
        # Return a copy of tlist keeping, at every position, only the
        # entries with dev <= bound.  Continuity is not checked here.
        return [[ta for ta in tas if ta[0] <= bound] for tas in tlist]

    def addTorsions ( s, tchain, tlist ):
        #
        # Alternative depth-first search (not called by this file's driver).
        # Extends tchain with entries of tlist[len(tchain)] that satisfy
        # continuity, re-pruning the lists to the remaining deviation budget
        # s.devTot - taDev(newchain) at every step; at the first position
        # only entries with dev < s.devTot are tried.  The first full chain
        # whose total is <= taDev(s.bestTorsions) is stored in
        # s.bestTorsions and -1 is returned to stop the whole search;
        # 0 means nothing was found below this point.
        #
        depth = len(tchain)
        lastEpsilons = None
        if depth:
            lastEpsilons = tchain[-1][3]

        for ta in tlist[depth]:
            if depth == 0:
                if not ta[0] < s.devTot:
                    continue
            elif not s.epsilonCheck(lastEpsilons, ta[3]):
                continue
            newchain = list(tchain) + [ta]

            if depth < len(tlist) - 1:
                # add another
                newList = s.pruneTorsions(tlist, s.devTot - s.taDev(newchain))
                if newList[depth + 1]:
                    if s.addTorsions(newchain, newList) == -1:
                        return -1
            elif s.taDev(newchain) <= s.taDev(s.bestTorsions):
                s.bestTorsions = list(newchain)
                return -1
        return 0

    def partitionTrees ( s, tchain, tlist, size ):
        #
        # Search tlist partition by partition (size positions each), keeping
        # the lowest-deviation continuity-satisfying prefix found in each
        # partition (see exhaust) as the fixed start of the next.
        #
        # Returns 1 when a full-length chain was found; it is left in
        # s.maxTreeChain and adopted as s.bestTorsions if its total
        # deviation is not worse than the current best.  Returns 0 when some
        # partition could not be completed.  Raises ValueError if size < 1.
        #
        if size < 1:
            raise ValueError("partition size must be positive, got %r" % (size,))

        numTrees = (len(tlist) + size - 1) // size
        s.maxTreeChain = list(tchain)

        for it in range(numTrees):
            itree = it * size
            s.maxTreeDev = 99999999
            limit = min(itree + size, len(tlist))

            if s.verbose:
                sys.stdout.write("NUMTREES %s %s %s\n" % (it, itree, limit))

            # only continue if every earlier partition was completed
            if len(s.maxTreeChain) == itree:
                if s.verbose:
                    prod = 1
                    for tas in tlist[itree:limit]:
                        prod *= len(tas)
                    sys.stdout.write("exhausting  %s possibilities.\n" % prod)
                s.exhaust(s.maxTreeChain, tlist, itree, limit)

        if len(s.maxTreeChain) == len(tlist):
            if (not s.bestTorsions or
                s.taDev(s.maxTreeChain) <= s.taDev(s.bestTorsions)):
                s.bestTorsions = list(s.maxTreeChain)
            return 1

        if s.verbose:
            sys.stdout.write("EXHAUST -  %s %s\n"
                             % (len(s.maxTreeChain), s.taDev(s.maxTreeChain)))
        return 0

    def exhaust ( s, tchain, tlist, i, bound, nest = 0 ):
        #
        # Depth-first enumeration of the continuity-satisfying extensions of
        # tchain through positions i .. bound-1 of tlist.  The first chain
        # of length bound with the strictly lowest total deviation is kept
        # in s.maxTreeChain / s.maxTreeDev.  nest is the recursion depth.
        #
        if nest == bound:
            return

        lastEpsilons = None
        if i:
            lastEpsilons = tchain[-1][3]

        for ta in tlist[i]:
            if i and not s.epsilonCheck(lastEpsilons, ta[3]):
                continue
            newchain = list(tchain) + [ta]

            # Deviations are non-negative (as produced by
            # computeTorsionAngleSets), so a chain whose running total
            # already reaches the best cannot end strictly lower: prune it.
            # This leaves the selected chain unchanged.
            tadev = s.taDev(newchain)
            if tadev >= s.maxTreeDev:
                continue

            if len(newchain) == bound:
                s.maxTreeDev = tadev
                s.maxTreeChain = newchain
            else:
                s.exhaust(newchain, tlist, i + 1, bound, nest + 1)

    def cleanExit ( s ) :
        # Terminate the program with exit status -1.
        sys.exit(-1)

if __name__ == '__main__' :

    def usage ( ) :
        # Print the command-line help and exit (-1).
        lines = [
            "Usage: " + sys . argv [0] + "[options] -i dataFile",
            "Builds most alpha-helical struture for a given assigned PISEMA data set",
            "where options are: ",
            "\t-h           print this message",
            "\t-v           verbose",
            "\t-b           ssNMR beta in degrees [17.0]",
            "\t-n           ssNMR nuParallel[11.335]",
            "\t-1           ssNMR sigma11[30]",
            "\t-2           ssNMR sigma22[60]",
            "\t-3           ssNMR sigma33[205]",
            "\t-g           handle GLY specifically[sigmas = 28, 49, 197]",
            "\t-r           build all 4 PDBs (they are just reflections)",
            "\t-o           outputPDB[bestA.pdb]",
            "\t-p           tree partition size[10]",
            "",
            "Input data file is of the form:",
            "residueName chemShift dipCoupling",
            "e.g:",
            "A 174.043 8.433",
            "",
        ]
        sys.stdout.write("\n".join(lines) + "\n")
        sys.exit(-1)

    def reportSizes ( tasList, bound ) :
        # Verbose helper: candidate count per position, their product (an
        # upper bound on the combinations searched) and the prune bound.
        prod = 1
        sStr = ""
        for tas in tasList:
            prod *= len(tas)
            sStr += str(len(tas)) + " "
        sys.stdout.write(sStr + " TOTAL " + str(prod) + " " + str(bound) + "\n")

    if len(sys.argv) == 1:
        usage()

    try:
        opts, args = getopt.getopt(sys.argv[1:], "rghvo:i:b:n:1:2:3:p:")
    except getopt.error:
        msg = sys.exc_info()[1]
        sys.stderr.write(sys.argv[0] + ': ' + str(msg) + '\n')
        usage()

    dataFile = None
    outputPDB = "bestA"
    verbose = 0
    nuParallel = 11.335
    sigma11 = 30.
    sigma22 = 60.
    sigma33 = 205.
    beta = 17.0
    handleGLY = 0
    reflect = 0
    partitionSize = 10

    try:
        for o, a in opts:
            if o == '-i':
                dataFile = a
            elif o == '-o':
                outputPDB = a
            elif o == '-v':
                verbose = 1
            elif o == '-g':
                handleGLY = 1
            elif o == '-h':
                usage()
            elif o == '-n':
                nuParallel = float(a)
            elif o == '-1':
                sigma11 = float(a)
            elif o == '-2':
                sigma22 = float(a)
            elif o == '-3':
                sigma33 = float(a)
            elif o == '-b':
                beta = float(a)
            elif o == '-r':
                reflect = 1
            elif o == '-p':
                partitionSize = int(a)
    except ValueError:
        sys.stderr.write(sys.argv[0] + ': invalid numeric option value\n')
        usage()

    if dataFile is None:
        sys.stderr.write(sys.argv[0] + ': no data file given (-i)\n')
        usage()
    if partitionSize < 1:
        sys.stderr.write(sys.argv[0] + ': tree partition size (-p) must be positive\n')
        usage()

    # make sure the output directory exists
    outDir = os.path.dirname(outputPDB)
    utils = Utils()
    if outDir:
        utils.makedirs(outDir)

    nmr = NMRBase()
    prop = Propagator(nmr)

    fit = StructFit(prop)
    fit.verbose = verbose
    fit.setNMRParameters(beta, nuParallel, sigma11, sigma22, sigma33)
    fit.handleGLY = handleGLY
    fit.readData(dataFile)

    if not fit.spectrum:
        sys.stderr.write(sys.argv[0] + ': no residues found in ' + dataFile + '\n')
        sys.exit(-1)

    if verbose:
        sys.stdout.write("\nNMR Constants:\n")
        sys.stdout.write("\tbeta(rad): %7.3f\n" % fit.p.nmr.betaRad)
        sys.stdout.write("\tnuParallel: %7.3f\n" % fit.p.nmr.nuParallel)
        sys.stdout.write("\tsigma11: %7.3f\n" % fit.p.nmr.sigma11)
        sys.stdout.write("\tsigma22: %7.3f\n" % fit.p.nmr.sigma22)
        sys.stdout.write("\tsigma33: %7.3f\n\n" % fit.p.nmr.sigma33)
        fit.dumpData()

    fit.outputPDB = outputPDB
    fit.dataFile = os.path.basename(dataFile)

    # B-in-PAF orientations of the first residue and their PISEMA-equation
    # epsilons: these seed the first peptide plane
    ipos = fit.spectrum[0][2]
    pisemaEps = fit.spectrum[0][4]

    fit.computeTorsionAngleSets(1)
    fit.cherryPick()

    devTot = fit.taDev(fit.greedyTorsions)
    # initial prune bound: the greedy chain's average deviation per position
    # (guarded for a single-residue data set, which has no positions)
    avgDev = devTot // max(1, len(fit.TAS))
    currentTAS = fit.pruneTorsions(fit.TAS, avgDev)

    # once the bound reaches the largest candidate deviation pruning removes
    # nothing, so raising it further cannot make the search succeed
    maxCandidateDev = 0
    for tas in fit.TAS:
        maxCandidateDev = max(maxCandidateDev, fit.maxtaDev(tas))

    if verbose:
        reportSizes(currentTAS, avgDev)

    tchain = []
    fit.devTot = devTot

    while not fit.partitionTrees(tchain, currentTAS, partitionSize):
        if avgDev >= maxCandidateDev:
            sys.stdout.write("Warning: partitioned search found no complete chain;"
                             " keeping the greedy structure\n")
            break
        # no struct found - increase bound
        avgDev += 10
        currentTAS = fit.pruneTorsions(fit.TAS, avgDev)
        if verbose:
            reportSizes(currentTAS, avgDev)

    if verbose:
        totdev = 0
        for dev, phi, psi, eps in fit.bestTorsions:
            sStr = "%3s (%6.2f,%6.2f) " % (dev, phi, psi)
            sStr += "["
            # NOTE(review): epsilons are printed negated - confirm convention
            for ep in eps:
                sStr += " %3d" % (ep * -1)
            sStr += " ]"
            sys.stdout.write(sStr + "\n")
            totdev += dev
        sys.stdout.write("DEV TOTAL %s IMPROVEMENT %s\n" % (totdev, totdev - devTot))

    chainData = []

    # with -r, one structure per B orientation of the first residue
    if reflect:
        blimit = len(ipos)
    else:
        blimit = 1

    for j in range(blimit):

        # begin construction of a peptide chain with
        # first peptide plane oriented to give these values
        chain = []
        pp = PeptidePlane()
        # NOTE(review): a zero vector marks a wildcard first residue, but this
        # test also skips alignment for any vector with a single zero
        # component - confirm whether 'or' was intended
        if ipos[j][0] != 0. and ipos[j][1] != 0. and ipos[j][2] != 0.:
            fit.p.alignWithPAF(pp, ipos[j], fit.B, nmr.betaRad)
            if verbose:
                sys.stdout.write("DEGENERACY  %s  PISEMA equation epsilons: %s\n"
                                 % (j, pisemaEps[j]))

        chain.append(pp)

        del chainData[:]

        # one torsion pair per following residue
        for dev, phi, psi, ep in fit.bestTorsions:
            pp = fit.p.propagatePeptidePlane(pp, phi, psi)
            chain.append(pp)
            chainData.append([phi, psi, ep])

        if len(chain) == len(fit.seq):
            newProtein = fit.populateChain(chain, fit.seq, 1)
            if reflect:
                oFile = outputPDB + "." + str(j + 1) + ".pdb"
            else:
                oFile = outputPDB + ".pdb"
            newProtein.writeToFile(oFile)
            if verbose:
                sys.stdout.write(oFile + " written.\n")
            if not reflect:
                devTot = fit.taDev(fit.bestTorsions)
                sys.stdout.write(str(devTot) + "\n")
        else:
            sys.stdout.write("PROBLEM!! only  %d  residues built.\n" % len(chain))

    # NOTE(review): cleanExit() exits with status -1 even after success
    fit.cleanExit()
   
    

