# -*- coding: utf-8 -*-
from pathlib import Path
import numpy as np
from scipy.optimize import curve_fit
import sys, os

## Local imports
from arbdmodel import ArbdModel, ParticleType, PointParticle, Group, get_resource_path
from arbdmodel.abstract_polymer import PolymerSection, AbstractPolymerGroup
from arbdmodel.interactions import TabulatedPotential, HarmonicBond, HarmonicAngle, HarmonicDihedral
from arbdmodel.coords import quaternion_to_matrix, readArbdCoords

import MDAnalysis as mda
from gridData import Grid
from writeDx import writeDx

"""Define particle types"""

n_replicas = 10

## units "295 k K/(160 amu * 1.24/ps)" "AA**2/ns"
## units "295 k K/(180 amu * 1.24/ps)" "AA**2/ns"
_P = ParticleType("P",
                 diffusivity = 1621,
                 mass = 121,
                 radius = 5,
                 nts = 0.5      # made compatible with nbPot
)

_B = ParticleType("B",
                 diffusivity = 1093,
                 mass = 181,    # thymine
                 radius = 3,                 
                 nts = 0.5      # made compatible with nbPot
)

class DnaStrandFromPolymer(Group):
    """Two-bead ssDNA strand (one P bead and one B bead per nucleotide) laid along a polymer.

    ``_generate_beads`` places one copy of the template nucleotide ``self.nt``
    at each monomer of ``polymer``. Each copy takes its position and orientation
    from the polymer contour. The method then adds the tabulated bonds, angles,
    dihedrals and exclusions that couple two, three and four consecutive
    nucleotides.

    For instances of this class ``self.nt`` resolves to the class-level template
    below. Subclasses may supply their own template by assigning ``self.nt``
    before ``_generate_beads`` runs.
    """

    ## Class-level template nucleotide (P at the local origin, B offset from it)
    p = PointParticle(_P, (0,0,0), "P")
    b = PointParticle(_B, (3,0,1), "B")
    nt = Group( name = "nt", children = [p,b])
    nt.add_bond( i=p, j=b, bond = '../../common/two_bead_model/BPB.dat', exclude = True )

    def __init__(self, polymer, **kwargs):
        """Store ``polymer`` and forward ``kwargs`` to ``Group.__init__``; no beads are created here."""
        self.polymer = polymer
        Group.__init__(self, **kwargs)

    def _clear_beads(self):
        ## Placeholder: removing generated beads is not implemented
        ...

    def _generate_beads(self):
        """Add one nucleotide per polymer monomer, then the bonded terms linking them.

        ``self.nts`` aliases ``self.children``. Once the nucleotides have been
        added, it lists them in contour order.
        """
        nts = self.nts = self.children

        for i in range(self.polymer.num_monomers):
            c = self.polymer.monomer_index_to_contour(i)
            r = self.polymer.contour_to_position(c)
            o = self.polymer.contour_to_orientation(c)

            ## self.nt (not the class attribute) so subclasses can provide their own template
            new = self.nt.duplicate()
            new.orientation = o
            new.position = r
            self.add(new)

        ## Two consecutive nts
        for nt1, nt2 in zip(nts, nts[1:]):
            p1,b1 = nt1.children
            p2,b2 = nt2.children
            self.add_bond( i=b1, j=p2, bond = '../../common/two_bead_model/BBP.dat', exclude=True )
            self.add_bond( i=p1, j=p2, bond = '../../common/two_bead_model/BPP.dat', exclude=True )
            self.add_angle( i=p1, j=p2, k=b2, angle = '../../common/two_bead_model/p1p2b2.dat' )
            self.add_angle( i=b1, j=p2, k=b2, angle = '../../common/two_bead_model/b1p2b2.dat' )
            self.add_dihedral( i=b1, j=p1, k=p2, l=b2, dihedral = '../../common/two_bead_model/b1p1p2b2.dat' )
            self.add_exclusion( i=b1, j=b2 )
            self.add_exclusion( i=p1, j=b2 )

        ## Three consecutive nts
        for nt1, nt2, nt3 in zip(nts, nts[1:], nts[2:]):
            p1,b1 = nt1.children
            p2,_ = nt2.children
            p3,b3 = nt3.children
            self.add_angle( i=p1, j=p2, k=p3, angle = '../../common/two_bead_model/p1p2p3.dat' )
            self.add_angle( i=b1, j=p2, k=p3, angle = '../../common/two_bead_model/b1p2p3.dat' )
            self.add_dihedral( i=b1, j=p2, k=p3, l=b3, dihedral = '../../common/two_bead_model/b1p2p3b3.dat' )
            self.add_exclusion( i=p1, j=p3 )
            self.add_exclusion( i=b1, j=p3 )

        ## Four consecutive nts: backbone (P-P-P-P) dihedral only
        for nt1, nt2, nt3, nt4 in zip(nts, nts[1:], nts[2:], nts[3:]):
            p1,_ = nt1.children
            p2,_ = nt2.children
            p3,_ = nt3.children
            p4,_ = nt4.children
            self.add_dihedral( i=p1, j=p2, k=p3, l=p4, dihedral = '../../common/two_bead_model/p0p1p2p3.dat' )

class IndependentDnaStrandFromPolymer(DnaStrandFromPolymer):
    """DNA strand whose P/B particle types belong to one replica and carry that replica's grid potentials.

    Strands built with the same ``index`` and ``grid_path`` share one pair of
    ParticleType objects, named ``P{index:03d}`` / ``B{index:03d}``. Beads and
    bonded terms are generated by the inherited ``_generate_beads``, which
    duplicates the per-instance template ``self.nt``.
    """

    ## Cache of (P type, B type) keyed by (index, grid_path). grid_path is part of
    ## the key so a new grid directory never silently reuses types built for old grids.
    particle_type_dict = {}

    def __init__(self, polymer, index, grid_path, **kwargs):
        """Create (or reuse) the replica's particle types and a matching template nucleotide.

        Parameters
        ----------
        polymer : polymer section providing the contour (see DnaStrandFromPolymer)
        index : int
            Replica index; used in the particle-type names and the cache key.
        grid_path : str
            Directory holding ``grid-P.dx`` / ``grid-B.dx``, referenced as ``../{grid_path}/``.
        """
        cache = IndependentDnaStrandFromPolymer.particle_type_dict
        key = (index, grid_path)
        if key not in cache:
            ## NOTE(review): 0.57827709 matches kT in kcal/mol at 291 K (the DnaModel
            ## temperature), presumably converting grids written in kT units -- confirm
            p_type = ParticleType("P{:03d}".format(index),
                                  diffusivity = 1621,
                                  mass = 121,
                                  radius = 5,
                                  nts = 0.5,      # made compatible with nbPot
                                  grid=[('../{}/grid-P.dx'.format(grid_path), 0.57827709)]
                                  )
            b_type = ParticleType("B{:03d}".format(index),
                                  diffusivity = 1093,
                                  mass = 181,    # thymine
                                  radius = 3,
                                  nts = 0.5,      # made compatible with nbPot
                                  grid=[('../{}/grid-B.dx'.format(grid_path), 0.57827709)]
                                  )
            cache[key] = (p_type, b_type)
        p_type, b_type = self.types = cache[key]
        p = PointParticle(p_type, (0,0,0), "P")
        b = PointParticle(b_type, (3,0,1), "B")

        ## Per-instance template nucleotide; the inherited _generate_beads duplicates self.nt
        self.nt = nt = Group( name = "nt", children = [p,b])
        nt.add_bond( i=p, j=b, bond = '../../common/two_bead_model/BPB.dat', exclude = True )

        self.polymer = polymer
        Group.__init__(self, **kwargs)

class DnaModel(ArbdModel):
    """ARBD model containing one IndependentDnaStrandFromPolymer per polymer.

    Consecutive blocks of ``num_polymers_per_replica`` polymers form one replica
    and share that replica's particle types, and therefore the grids under
    ``grid_path``. The integrator settings below always replace the
    corresponding entries of ``kwargs``. The remaining ``kwargs`` (e.g.
    ``dimensions``) are passed to ``ArbdModel.__init__``. ``DEBUG`` is unused.
    """

    def __init__(self, polymers, grid_path, num_polymers_per_replica,
                 DEBUG=False,
                 **kwargs):
        ## Fixed simulation settings; these override any caller-supplied values
        kwargs.update(
            particle_integrator = 'Langevin',
            extra_bd_file_lines = 'ParticleLangevinIntegrator BAOAB',
            timestep = 20e-6,
            temperature = 291,
            cutoff = 35,
            pairlist_distance = 60,
            decomp_period = 1000,
        )

        self.polymer_group = AbstractPolymerGroup(polymers)
        self.strands = [
            IndependentDnaStrandFromPolymer(polymer, n // num_polymers_per_replica, grid_path)
            for n, polymer in enumerate(self.polymer_group.polymers)
        ]
        ArbdModel.__init__(self, self.strands, **kwargs)
        self.nbSchemes = []

        ## Tabulated non-bonded potentials are registered only between types of the
        ## same replica, once per distinct (P, B) pair, in order of first appearance
        for p_type, b_type in dict.fromkeys(strand.types for strand in self.strands):
            self.useNonbondedScheme( TabulatedPotential('../../common/two_bead_model/NBBB.dat'), typeA=b_type, typeB=b_type )
            self.useNonbondedScheme( TabulatedPotential('../../common/two_bead_model/NBPB.dat'), typeA=p_type, typeB=b_type )
            self.useNonbondedScheme( TabulatedPotential( '../../common/two_bead_model/NBPP.dat'), typeA=p_type, typeB=p_type )
        self.generate_beads()

    def generate_beads(self):
        """Have every strand create its beads and bonded terms."""
        for strand in self.strands:
            strand._generate_beads()
        

def run_round(index, last_coordinates = None, generate_grid=False):
    """Run IBI round ``index`` and return the final coordinates.

    The round builds a fresh model of ``strands_per_replica * n_replicas`` short
    strands at random positions, using the grids in ``grids-{index}``. It
    simulates the model in ``iter-{index}`` and then writes ``grids-{index+1}``
    via ``generate_new_grid``.

    Parameters
    ----------
    index : int
        Round number.
    last_coordinates : sequence or None
        Positions assigned, in iteration order, to the model's particles in
        place of the random placement. ``None`` keeps the random placement.
    generate_grid : bool
        If true, first (re)generate ``grids-{index}`` from round ``index-1``
        and, best effort, resume from that round's restart file.

    Returns
    -------
    The coordinates read by ``readArbdCoords`` from ``iter-{index}/output/{name}.restart``.
    """
    strands_per_replica = 31
    dimensions = [106.620003]*3
    name = 'many-strands'

    if generate_grid:
        path = 'iter-{}'.format(index-1)
        restart = '{}/output/{}.restart'.format(path,name)
        try:
            last_coordinates = readArbdCoords(restart)
        except Exception as err:
            ## Best effort: the previous round's output may not exist (e.g. the very first round)
            print("Warning: could not read {} ({}); continuing without it".format(restart, err))
        generate_new_grid( index-1, name )

    IndependentDnaStrandFromPolymer.particle_type_dict = {} # Ugly hack to clear cached particle types that have wrong grids

    ## Randomly place strands through system
    strands = []
    for i in range(strands_per_replica*n_replicas):
        r0 = np.array( [(a-0.5)*b for a,b in
                        zip( np.random.uniform(size=3), dimensions )] )
        r1 = r0 + (np.random.uniform(size=3)-0.5)*5*5

        s = PolymerSection("D{}".format(i), num_monomers=5, monomer_length=5,
                           start_position=r0, end_position=r1)
        strands.append(s)

    model = DnaModel( strands, grid_path='grids-{}'.format(index), num_polymers_per_replica=strands_per_replica, dimensions=dimensions )

    if last_coordinates is not None:
        ## NOTE(review): zip stops at the shorter of the two; a count mismatch is not reported
        for p,c in zip(model, last_coordinates):
            p.position = c

    path = 'iter-{}'.format(index)
    model.simulate( output_name = name, output_period=1e3, num_steps=1e6, directory=path, gpu=1 ) # 20 ns

    coords = readArbdCoords('{}/output/{}.restart'.format(path,name))
    generate_new_grid( index, name )
    return coords

def symmetrize_grid(asym):
    """Average a 3D grid over the 180-degree rotations about the x, y and z axes.

    Returns ``(g + g[::-1,:,::-1] + g[::-1,::-1,:] + g[:,::-1,::-1]) / 4``.
    Each flipped term reverses two of the three axes, i.e. a 180-degree
    rotation about the grid centre. The result is invariant under those
    rotations, and applying the function again leaves it (up to rounding)
    unchanged.

    NOTE(review): this works in index space, so it assumes the grid is
    centred on the physical origin -- confirm for every input grid.

    Parameters
    ----------
    asym : array-like, 3D
        Input grid; it is not modified. Nested lists are accepted as well as ndarrays.

    Returns
    -------
    numpy.ndarray
        A new symmetrized array with the same shape as ``asym``.
    """
    ## asarray for every term: the original only converted the first term, so
    ## the flipped-index slicing failed on non-ndarray input
    a = np.asarray(asym)
    sym = a + a[::-1,:,::-1]
    sym = sym + a[::-1,::-1,:]
    sym = sym + a[:,::-1,::-1]
    return 0.25 * sym

## Target P/B densities, loaded once at import time and symmetrized in place
target_density_grids = {c: Grid('target_density/grid-{}.dx'.format(c)) for c in ('P', 'B')}
for k, g in target_density_grids.items():
    g.grid = symmetrize_grid(g.grid)

## Most recent potential grid per bead type; None until generated or loaded from disk
_last_grids = dict.fromkeys(('P', 'B'))

def get_mask():
    """Return the 3D weight in [0, 1] that multiplies every potential update.

    If ``mask.dx`` exists, its grid is returned unchanged. Otherwise the mask
    is built from ``ssb-density.dx`` in these steps:
    - Gaussian-smooth it (sigma = 1.5 in the grid's length units).
    - Symmetrize it.
    - Map it linearly from 0 at density 0.001 to 1 at density 0.025.
    - Clip to [0, 1] and symmetrize again.
    - Write the result to ``mask.dx`` for later calls.
    """
    if Path('mask.dx').exists():
        return Grid('mask.dx').grid

    from scipy.ndimage import gaussian_filter
    source = Grid('ssb-density.dx')
    lo, hi = 0.001, 0.025

    ## delta converts the 1.5 length-unit width into voxels per axis
    smoothed = gaussian_filter( source.grid, sigma=1.5/source.delta, mode='constant' )
    ramp = (symmetrize_grid(smoothed) - lo) / (hi - lo)
    ramp = symmetrize_grid(np.clip(ramp, 0, 1))

    writeDx('mask.dx', ramp,
            origin=source.origin, delta=source.delta, fmt='%.6f')
    return ramp

## Module-level mask shared by generate_initial_grid and generate_new_grid
## (computed or loaded at import time)
mask = get_mask()
    
def get_avg_edge_value(u):
    """Return the unweighted mean of the six boundary-face averages of the 3D array ``u``.

    Each face (first/last slice along each axis) is averaged on its own. The
    six averages then count equally, whatever the face sizes. Callers subtract
    this value to zero the potential at the box boundary.
    """
    face_means = [u[0].mean(), u[-1].mean(),
                  u[:, 0].mean(), u[:, -1].mean(),
                  u[:, :, 0].mean(), u[:, :, -1].mean()]
    ## Start the sum from the first mean so the additions happen in the same order
    return sum(face_means[1:], face_means[0]) / 6
    
def generate_initial_grid( d ):
    """Write the first-round potentials ``{d}/grid-P.dx`` and ``{d}/grid-B.dx``.

    Each potential is the direct Boltzmann inversion ``-ln(target + 1e-15)``
    (in kT) of the symmetrized target density. It is multiplied by the
    module-level ``mask`` and shifted so its average boundary-face value is
    zero. The result is also stored in ``_last_grids`` for the next IBI update.
    """
    for bead in ('P','B'):
        tgt = target_density_grids[bead]
        density = symmetrize_grid(np.array(tgt.grid))
        potential = -1.0*np.log(density + 1e-15) * mask
        potential = potential - get_avg_edge_value(potential)

        writeDx('{}/grid-{}.dx'.format(d,bead), potential,
                origin=tgt.origin, delta=tgt.delta, fmt='%.6f')
        _last_grids[bead] = Grid(potential, origin=tgt.origin, delta=tgt.delta)

def generate_new_grid( index, name ):
    """Write the potentials used by round ``index+1`` into ``grids-{index+1}``.

    For ``index == 0`` the grids come from ``generate_initial_grid``.
    Otherwise the update proceeds as follows:
    - VMD's ``volmap`` measures the P and B densities of round ``index``
      (frames from 200 on, read from ``iter-{index}``).
    - Each density is divided by ``n_replicas`` and symmetrized.
    - An IBI correction ``ln(density/target)`` (in kT, masked and zeroed at
      the boundary) is added to the previous potential.
    - The result is capped at 20 and written out.

    Raises
    ------
    subprocess.CalledProcessError
        If the VMD run exits with a non-zero status.
    """
    import subprocess

    d = 'grids-{}'.format(index+1)
    d_old = 'grids-{}'.format(index)
    os.makedirs(d, exist_ok=True)

    if index == 0:
        return generate_initial_grid(d)

    path = 'iter-{}'.format(index)

    vmdin = """
set ID [mol new {path}/{name}.psf]
mol addfile {path}/output/{name}.dcd beg 200 waitfor all
    
foreach type "P B" {{
    set sel [atomselect $ID "name $type"]
    $sel set radius 3 
    volmap density $sel -o {path}/{name}.$type-density.dx -res 0.5 -combine avg -minmax "{{-38 -36 -50}} {{38 36 50}}" -allframes -checkpoint 0
}}
    """.format( path=path, name=name )
    ## check=True: fail here rather than later on a missing or stale density file
    subprocess.run(['vmd','-dispdev','text'], input=vmdin, encoding='ascii',
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)

    for bead_type in ('P','B'):
        density = Grid('{}/{}.{}-density.dx'.format(path,name,bead_type)).grid
        ## The selection covers all replicas in the box; presumably this
        ## normalizes to a single replica's density
        density = density / n_replicas
        density = symmetrize_grid(density)

        target = target_density_grids[bead_type]

        ## IBI correction (kT): ln(density/target), regularized by 0.5e-6 and masked
        dU = -1.0*np.log( (target.grid+0.5*1e-6)/(density+0.5*1e-6) ) * mask
        dU = dU - get_avg_edge_value(dU)

        ## ibi-combine: add the correction to the previous potential, reloading it
        ## from disk if this process has not produced one yet (e.g. on resume)
        if _last_grids[bead_type] is None:
            _last_grids[bead_type] = Grid('{}/grid-{}.dx'.format(d_old,bead_type))

        ulast = _last_grids[bead_type].grid
        out = ulast+dU
        ## Cap at 20 kT, and force the cap where the target density is ~0
        sl = (target.grid < 1e-6) | (out > 20)
        out[sl] = 20

        writeDx('{}/grid-{}.dx'.format(d,bead_type), out,
                origin=target.origin, delta=target.delta, fmt='%.6f')
        _last_grids[bead_type] = Grid(out, origin=target.origin, delta=target.delta)

        
if __name__ == "__main__":
    
    # """ #START
    c = None
    start=1
    for i in range(start,start+200):
        c = run_round( i, c, generate_grid = (i==start) )
        print("Round {}, c.shape {}".format(i,c.shape))
    """ #END """