# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
from scipy.optimize import curve_fit
import sys, os

## Local imports
from chrispy.util import hermite
# from arbdmodel import ArbdModel, ParticleType, PointParticle, Group, get_resource_path
from arbdmodel import ArbdModel, ParticleType, PointParticle, Group
from arbdmodel.abstract_polymer import PolymerSection, AbstractPolymerGroup
from arbdmodel.interactions import TabulatedPotential, HarmonicBond, HarmonicAngle, HarmonicDihedral
from arbdmodel.coords import quaternion_to_matrix, readArbdCoords

import MDAnalysis as mda
from gridData import Grid
from writeDx import writeDx

import arbdmodel.kh_polymer_model as khm
from arbdmodel.kh_polymer_model import _types as kh_types
from arbdmodel.kh_polymer_model import KhNonbonded

## Screening length used by this script; the same number is called ld_182
## inside create_PB_NBPP() (presumably Angstroms -- confirm).
ld_182 = 7.1603735
debye_length = ld_182

## Executable handed to model.simulate() through its `arbd` argument.
arbd = '/home/cmaffeo2/development/arbd.after_server5/src/arbd'

## Assign the same damping coefficient to every KH particle type.
for k, t in kh_types.items():
    t.damping_coefficient = 1
        
## Copy of the KH epsilon table taken at import time, so that every call to
## update_epsilon_dict() starts again from the unmodified entries.
_kh_eps_orig = dict(khm.epsilon_mj)
def update_epsilon_dict(P_keys=('P',), B_keys=('B',)):
    """Replace khm.epsilon_mj with a copy extended for DNA bead type names.

    For every name that occurs as the second element of a key in the
    original table, each entry of P_keys receives the value stored for
    'PHE' and each entry of B_keys the value stored for 'TYR'.  Both key
    orders, (bead, name) and (name, bead), are written.

    Parameters
    ----------
    P_keys : iterable of str
        Type names that borrow the 'PHE' epsilon values.
    B_keys : iterable of str
        Type names that borrow the 'TYR' epsilon values; handled after
        P_keys, so a name listed in both ends up with the 'TYR' value.
    """
    table = dict(_kh_eps_orig)
    partner_names = {second for _, second in table}

    additions = {}
    for partner in partner_names:
        for stand_in, bead_names in (('PHE', P_keys), ('TYR', B_keys)):
            for bead in bead_names:
                # Looked up per bead so that an empty key list never
                # touches the table, exactly as before.
                eps = table[(stand_in, partner)]
                additions[(bead, partner)] = eps
                additions[(partner, bead)] = eps

    table.update(additions)
    khm.epsilon_mj = table

"""Define particle types"""

def get_resource_path(x):
    """Return *x* with '../' prepended (a path relative to the parent directory)."""
    return "../" + format(x)

## Replica count; run_round() creates strands_per_replica*n_replicas strands.
n_replicas = 1

## Module-level bead types used by the DnaStrandFromPolymer template.
## The two lines below are kept from the original and look like `units`
## queries related to the diffusivities (their inputs do not obviously
## match the masses used here -- confirm before relying on them):
## units "295 k K/(160 amu * 1.24/ps)" "AA**2/ns"
## units "295 k K/(180 amu * 1.24/ps)" "AA**2/ns"
_P = ParticleType(
    "P",
    diffusivity=1621,
    mass=121,
    radius=5,
    charge=-1,
    sigma=5.58,                 # after ASP
    rigid_body_key='Pbead',
    nts=0.5,                    # made compatible with nbPot
)

_B = ParticleType(
    "B",
    diffusivity=1093,
    mass=181,                   # thymine
    radius=3,
    charge=0,
    sigma=6.46,                 # After TYR
    rigid_body_key='Bbead',
    nts=0.5,                    # made compatible with nbPot
)

class DnaStrandFromPolymer(Group):
    """Two-bead-per-nucleotide strand laid out along a polymer object.

    Each nucleotide is a small Group holding a "P" and a "B" point
    particle.  The nucleotide template is looked up as ``self.nt``: here
    that is the class-level template built from the module-level _P/_B
    types; subclasses may set an instance-level ``nt`` before
    _generate_beads() runs to use different particle types.
    """

    ## Class-level nucleotide template: P at the origin, B offset from it,
    ## joined by a tabulated bond that also excludes the pair.
    p = PointParticle(_P, (0,0,0), "P")
    b = PointParticle(_B, (3,0,1), "B")
    nt = Group( name = "nt", children = [p,b])
    nt.add_bond( i=p, j=b, bond = get_resource_path('two_bead_model/BPB.dat'), exclude = True )

    def __init__(self, polymer, **kwargs):
        """Store *polymer* and forward the remaining keywords to Group."""
        self.polymer = polymer
        Group.__init__(self, **kwargs)

    def _clear_beads(self):
        """Placeholder; currently does nothing."""
        ...

    def _generate_beads(self):
        """Add one nucleotide per monomer, then the tabulated bonded terms.

        Nucleotides are duplicated from ``self.nt`` and placed using the
        polymer's contour-to-position/orientation mapping.  Bonds, angles,
        dihedrals and exclusions are then added between 2, 3 and 4
        consecutive nucleotides, in that order.
        """
        ## self.nts aliases self.children, so it fills up as nucleotides are added
        nts = self.nts = self.children
        polymer = self.polymer

        for i in range(polymer.num_monomers):
            c = polymer.monomer_index_to_contour(i)
            new = self.nt.duplicate()
            new.orientation = polymer.contour_to_orientation(c)
            new.position = polymer.contour_to_position(c)
            self.add(new)

        ## Two consecutive nts
        for nt1, nt2 in zip(nts, nts[1:]):
            p1,b1 = nt1.children
            p2,b2 = nt2.children
            self.add_bond( i=b1, j=p2, bond = get_resource_path('two_bead_model/BBP.dat'), exclude=True )
            self.add_bond( i=p1, j=p2, bond = get_resource_path('two_bead_model/BPP.dat'), exclude=True )
            self.add_angle( i=p1, j=p2, k=b2, angle = get_resource_path('two_bead_model/p1p2b2.dat') )
            self.add_angle( i=b1, j=p2, k=b2, angle = get_resource_path('two_bead_model/b1p2b2.dat') )
            self.add_dihedral( i=b1, j=p1, k=p2, l=b2, dihedral = get_resource_path('two_bead_model/b1p1p2b2.dat') )
            self.add_exclusion( i=b1, j=b2 )
            self.add_exclusion( i=p1, j=b2 )

        ## Three consecutive nts
        for nt1, nt2, nt3 in zip(nts, nts[1:], nts[2:]):
            p1,b1 = nt1.children
            p2,_ = nt2.children
            p3,b3 = nt3.children
            self.add_angle( i=p1, j=p2, k=p3, angle = get_resource_path('two_bead_model/p1p2p3.dat') )
            self.add_angle( i=b1, j=p2, k=p3, angle = get_resource_path('two_bead_model/b1p2p3.dat') )
            self.add_dihedral( i=b1, j=p2, k=p3, l=b3, dihedral = get_resource_path('two_bead_model/b1p2p3b3.dat') )
            self.add_exclusion( i=p1, j=p3 )
            self.add_exclusion( i=b1, j=p3 )

        ## Four consecutive nts
        for nt1, nt2, nt3, nt4 in zip(nts, nts[1:], nts[2:], nts[3:]):
            p1,_ = nt1.children
            p2,_ = nt2.children
            p3,_ = nt3.children
            p4,_ = nt4.children
            self.add_dihedral( i=p1, j=p2, k=p3, l=p4, dihedral = get_resource_path('two_bead_model/p0p1p2p3.dat') )

class IndependentDnaStrandFromPolymer(DnaStrandFromPolymer):
    """Strand whose beads use particle types named after an integer index.

    Types are created once per index ("P000"/"B000", "P001"/"B001", ...)
    and cached in the class-level ``particle_type_dict``; strands built
    with the same index share the same pair, exposed as ``self.types``.
    Bead generation is inherited from DnaStrandFromPolymer and picks up
    the instance-level ``self.nt`` template created here.
    """

    ## index -> (P type, B type); run_round() rebinds this to {} to drop the cache
    particle_type_dict = {}

    def __init__(self, polymer, index, grid_path, **kwargs):
        """Build the per-index nucleotide template and initialise the Group.

        Parameters
        ----------
        polymer
            Object providing num_monomers and the contour mapping methods
            used by _generate_beads().
        index : int
            Selects (and, on first use, creates) the particle type pair.
        grid_path
            Currently unused by this constructor; it was only referenced
            by per-type grid settings that are disabled.
        **kwargs
            Forwarded to Group.__init__.
        """
        cache = IndependentDnaStrandFromPolymer.particle_type_dict
        if index not in cache:
            p_type = ParticleType("P{:03d}".format(index),
                                  diffusivity = 1621,
                                  mass = 121,
                                  radius = 5,
                                  charge=-1,
                                  sigma=5.58,   # after ASP
                                  nts = 0.5,      # made compatible with nbPot
                                  rigid_body_key='Pbead',
                                  )
            b_type = ParticleType("B{:03d}".format(index),
                                  diffusivity = 1093,
                                  mass = 181,    # thymine
                                  radius = 3,
                                  charge=0,
                                  sigma=6.46,   # After TYR
                                  nts = 0.5,      # made compatible with nbPot
                                  rigid_body_key='Bbead',
                                  )
            cache[index] = (p_type, b_type)
        p_type, b_type = self.types = cache[index]

        p = PointParticle(p_type, (0,0,0), "P")
        b = PointParticle(b_type, (3,0,1), "B")

        ## Instance-level template; shadows the class-level `nt` of the base class
        self.nt = nt = Group( name = "nt", children = [p,b])
        nt.add_bond( i=p, j=b, bond = get_resource_path('two_bead_model/BPB.dat'), exclude = True )

        self.polymer = polymer
        Group.__init__(self, **kwargs)

class DnaModel(ArbdModel):
    """ARBD model with one IndependentDnaStrandFromPolymer per input polymer.

    In addition to the strands, the model carries extra configuration text
    (``extra_bd_file_lines``) that declares a rigid body named 'ssb' with
    potential grids keyed on 'Pbead' and 'Bbead'.
    """

    def __init__(self, polymers, grid_path, num_polymers_per_replica,
                 DEBUG=False,
                 **kwargs):
        """Create the strands, register nonbonded schemes and generate beads.

        Parameters
        ----------
        polymers
            Polymer objects wrapped in an AbstractPolymerGroup.
        grid_path
            Forwarded to every strand constructor.
            NOTE(review): the grid files named in extra_bd_file_lines are
            hard-coded to '../grids-1/...' and do not follow this argument.
        num_polymers_per_replica : int
            Polymer ``i`` gets the particle types of index
            ``i // num_polymers_per_replica``.
        DEBUG : bool
            Accepted but not used here.
        **kwargs
            Forwarded to ArbdModel.__init__.  'timestep', 'temperature',
            'cutoff', 'pairlist_distance' and 'decomp_period' are always
            overwritten with the values below.
        """
        kwargs['timestep'] = 20e-6
        kwargs['temperature'] = 291
        kwargs['cutoff'] = 35
        kwargs['pairlist_distance'] = 60
        kwargs['decomp_period'] = 1000

        self.polymer_group = AbstractPolymerGroup(polymers)
        self.strands = [IndependentDnaStrandFromPolymer(p,i//num_polymers_per_replica, grid_path)
                        for i,p in enumerate(self.polymer_group.polymers)]
        ArbdModel.__init__(self, self.strands, **kwargs)
        self.nbSchemes = []

        self.extra_bd_file_lines = """
## RigidBodies
rigidBody ssb
num 1
mass 48856.41410648823
inertia 19375322 18923208 13511100
position 0 0 0
orientation 1 0 0 0 1 0 0 0 1

transDamping 779.075531386 757.10491996 735.261090254
rotDamping 3096.68320292 3027.72621629 3547.0215724
# attachedParticles ../ssb_kh_particles.txt

potentialGrid Pbead ../grids-1/grid-P.dx
potentialGrid Bbead ../grids-1/grid-B.dx
potentialGridScale Pbead 0.57827709
potentialGridScale Bbead 0.57827709
"""

        ## Register the tabulated nonbonded potentials once per distinct
        ## (P type, B type) pair.  Strands of the same replica index share
        ## a pair, so each pair has to be recorded after it is handled.
        processed = set()
        P_types = []
        B_types = []

        for strand in self.strands:
            if strand.types in processed:
                continue
            processed.add(strand.types)

            p_type, b_type = strand.types
            P_types.append(p_type)
            B_types.append(b_type)
            self.useNonbondedScheme( TabulatedPotential(get_resource_path('two_bead_model/NBBB.dat')), typeA=b_type, typeB=b_type )
            self.useNonbondedScheme( TabulatedPotential(get_resource_path('two_bead_model/NBPB.dat')), typeA=p_type, typeB=b_type )
            self.useNonbondedScheme( TabulatedPotential( '../NBPP.pb_correction.dat'), typeA=p_type, typeB=p_type )

        ## Make the KH epsilon table aware of the per-replica type names.
        update_epsilon_dict([t.name for t in P_types], [t.name for t in B_types])

        ## KhNonbonded schemes between the DNA beads and KH "dummy" types
        ## are not registered here at present.
        self.generate_beads()


    def generate_beads(self):
        """Ask every strand to create its nucleotides and bonded terms."""
        for s in self.strands:
            s._generate_beads()
        

def run_round(force, replica=1, last_coordinates = None, dry_run=False):
    """Set up and run one force-extension simulation of a 70-nt strand.

    The first and last "P" beads of the model are given derived particle
    types that add the grids '../down.dx' and '../up.dx', each scaled by
    *force*.

    Parameters
    ----------
    force
        Scale applied to both pulling grids; also part of the output
        directory name ('force-ext.<force>pN.rep<replica>').
    replica
        Only used in the output directory name.
    last_coordinates
        Accepted but currently unused.
    dry_run : bool
        Forwarded to model.simulate().

    Returns
    -------
    The coordinates read from '<directory>/output/dna-70.restart', or None
    when that file cannot be read (for example after a dry run).
    """
    strands_per_replica = 1
    dimensions = [3000]*3
    name = 'dna-70'

    ## Ugly hack to clear cached particle types that have wrong grids
    IndependentDnaStrandFromPolymer.particle_type_dict = {}

    strands = [PolymerSection("D{}".format(i), num_monomers=70, monomer_length=3,
                              start_position = np.array((20,0,-3*70/2)) )
               for i in range(strands_per_replica*n_replicas)]

    model = DnaModel( strands, grid_path='grids-1', num_polymers_per_replica=strands_per_replica, dimensions=dimensions )

    path = 'force-ext.{}pN.rep{}'.format(force,replica)

    ## Re-type the two terminal P beads so that each feels its own grid.
    ## bead.type_ is read inside the loop, so if both ends are the same
    ## bead the second type is derived from the first, as before.
    P_beads = [p for p in model if p.name.startswith('P')]
    terminal_grids = ((P_beads[0], 'Pdow', '../down.dx'),
                      (P_beads[-1], 'Pup', '../up.dx'))
    for bead, type_name, grid_file in terminal_grids:
        bead.type_ = ParticleType(type_name, grid=[(grid_file,force)], parent=bead.type_)

    ## NOTE(review): this call used to be annotated "20 ns"; with the
    ## timestep of 20e-6 set in DnaModel and 1e8 steps that figure does not
    ## look consistent -- confirm the intended run length.
    model.simulate( output_name = name, output_period=1e3, num_steps=1e8, directory=path, arbd=arbd, dry_run=dry_run )

    ## Best effort: a missing or unreadable restart file yields None, but
    ## KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        coords = readArbdCoords('{}/output/{}.restart'.format(path,name))
    except Exception:
        coords = None
    return coords

def create_attached_kh_particles_file():
    """Write 'ssb_kh_particles.txt' from the CA atoms of 1eyg, if missing.

    Each output line is '<resname> <x> <y> <z>' for one atom selected by
    "name CA" from the universe built from '1eyg.psf' and '1eyg.pdb'.
    Nothing is done when the output file already exists.

    Raises
    ------
    ValueError
        If a selected residue name is not the name of a KH particle type.
        The check runs before the file is opened, so a failure cannot
        leave behind a partial file that would make later calls skip
        regeneration.
    """
    filename = 'ssb_kh_particles.txt'
    if Path(filename).exists():
        return

    kh_type_names = {t.name for t in kh_types.values()}

    u = mda.Universe('1eyg.psf','1eyg.pdb')
    sel = u.select_atoms("name CA")
    resnames = sel.resnames

    unknown = sorted({str(rn) for rn in resnames if rn not in kh_type_names})
    if unknown:
        raise ValueError(
            "residue names without a KH particle type: {}".format(", ".join(unknown)))

    with open(filename,'w') as fh:
        for rn, r in zip(resnames, sel.positions):
            fh.write('{} {} {} {}\n'.format(rn,*r))

def create_PB_NBPP():
    """Write 'NBPP.pb_correction.dat' unless that file already exists.

    The tabulated potential from 'two_bead_model/NBPP.dat' (resolved with
    get_resource_path) is read as two columns (r, u).  A correction equal
    to the difference of two screened-Coulomb terms, with screening
    lengths ld_182 and ld_100, is added; the slope of the correction is
    limited to +/- maxForce before it is re-integrated.  The result is
    shifted so that its last value is zero and saved as two columns.
    """
    out_name = "NBPP.pb_correction.dat"
    if Path(out_name).exists():
        return

    fname = get_resource_path('two_bead_model/NBPP.dat')
    print(fname)
    r,u = np.loadtxt(fname).T

    q1=q2=1
    D = 80                  # dielectric of water
    ## units "e**2 / (4 * pi * epsilon0 AA)" kcal_mol
    A =  332.06371

    ## units "sqrt( 80 epsilon0 295 k K / (2*(100 mM/particle) e**2) )" AA
    ld_100 = 9.6598719
    ld_182 = 7.1603735

    coulomb = A*q1*q2/D
    screening_difference = np.exp(-r/ld_182)-np.exp(-r/ld_100)
    u_correction = (coulomb/r)*screening_difference
    u_correction[0] = u_correction[1] # avoid nans

    maxForce = 2
    if maxForce is not None:
        assert(maxForce > 0)
        ## Limit the slope of the correction, then rebuild it by summation
        dr = np.diff(r)
        slope = np.clip(np.diff(u_correction)/dr, -maxForce, maxForce)
        u_correction[0] = 0
        u_correction[1:] = np.cumsum(slope*dr)

    total = u+u_correction
    shifted = total-total[-1]
    np.savetxt(out_name, np.array([r,shifted]).T, fmt='%f')
    
                
if __name__ == "__main__":
    # Script entry point: prepare the corrected P-P table, then set up the
    # force-extension runs.
    # create_attached_kh_particles_file()
    # Writes NBPP.pb_correction.dat (skipped when the file already exists);
    # DnaModel refers to it as '../NBPP.pb_correction.dat'.
    create_PB_NBPP()
        
    # Toggle idiom: the opening START marker below is commented out, so the
    # loop that follows is live and the matching END line is only a
    # stand-alone string expression with no effect.
    # """ #START
    c = None      # not used below
    start=1       # not used below
    # Four replicas for each force value; dry_run=True is forwarded by
    # run_round() to model.simulate().
    for rep in [1,2,3,4]:
        for f in list(range(1,11)) + [12,15,17,20,25]:
            run_round( f, replica=rep, dry_run=True )
    """ #END """

    
    # Disabled section: everything from the START marker to the END marker
    # below is a single string literal and is never executed.  It refers to
    # _transform_grid_values and _original_grids, which are not defined in
    # the visible part of this file.
    """ #START
    import matplotlib as mpl
    mpl.use('agg')
    from matplotlib import pyplot as plt

    fig = plt.figure(figsize=(2.1,1.7))
    ax = fig.gca()
    x = np.linspace(-2,1,50)
    ax.plot(x,x,label='identity')
    for factor in [0.8,0.9,1.0,1.2]:
        factor = factor**2
        y = _transform_grid_values(np.array(x),factor)
        ax.plot( x, y, label=str(factor) )
        ax.legend()
        fig.savefig('test.transform_grid_values.pdf')
    print(np.min(_original_grids['P'].grid))
    print(np.min(_original_grids['B'].grid))
    """ #END """