import numpy as np
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme


"""
# TODO:


"""

class Location:
    """ A site within a container object (e.g. one end of a segment) at
    which a Connection can be attached.

    ``container``, ``address`` and ``type_`` are stored exactly as given.
    ``particle`` starts out as None and is filled in later, e.g. by a
    segment's ``_assign_particles_to_locations``. """
    def __init__(self, container, address, type_):
        self.container, self.address, self.type_ = container, address, type_
        self.particle = None  # bead bound to this site; assigned after beads exist

class Connection:
    """ Abstract base class for a connection between two elements.

    Parameters:
        A, B: the two connected sites; each must be a ``Location``
            (checked with ``assert``, A first, then B)
        type_: optional label for the connection, e.g. "intrahelical" """
    def __init__(self, A, B, type_ = None):
        for site in (A, B):
            assert isinstance(site, Location)
        self.A, self.B, self.type_ = A, B, type_
        
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that can be joined by Connections.

    Each element keeps the list of Connection objects it participates in;
    ``_connect`` records a connection on both participating elements. """
    def __init__(self, connections=None):
        """ ``connections`` defaults to a new empty list for every instance.
        A shared ``[]`` default would make all default-constructed elements
        see each other's connections. A list passed in is stored as-is
        (not copied). """
        self.connections = [] if connections is None else connections

    def _connect(self, other, connection):
        """ Record ``connection`` on both ``self`` and ``other``. """
        self.connections.append(connection)
        other.connections.append(connection)

    def _find_connections(self, loc):
        """ Return this element's connections that have ``loc`` as endpoint A or B. """
        return [c for c in self.connections if c.A == loc or c.B == loc]
       
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    # Basic particle types shared by all segments of a given kind
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,
                              )

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,
                              )

    def __init__(self, name, num_nts,
                 start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Initialize a segment that has no beads yet.

        Parameters:
            name: passed to ``Group.__init__``
            num_nts: number of nucleotides represented by the segment
            start_position: position of the segment start; defaults to the
                origin. Always stored as a fresh float array, so instances
                never share (and accidentally mutate) one array.
            end_position: position of the segment end. When None, it is
                placed ``self.distance_per_nt * num_nts`` along +z from the
                start. That requires subclasses to set ``self.distance_per_nt``
                before calling this constructor.
            segment_model: stored unchanged
        """
        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connections=[])

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = num_nts
        if start_position is None:
            start_position = np.array((0,0,0))
        start_position = np.array(start_position, dtype=float)
        if end_position is None:
            end_position = start_position + np.array((0,0,self.distance_per_nt*num_nts))
        else:
            end_position = np.array(end_position, dtype=float)
        self.start_position = start_position
        self.end_position = end_position

    def _generate_one_bead(self, pos, nts):
        """ Create a single bead at ``pos`` representing ``nts`` nucleotides (subclass hook). """
        raise NotImplementedError

    def _assign_particles_to_locations(self):
        """ Bind end beads to this segment's Location objects (subclass hook). """
        raise NotImplementedError

    def get_all_consecutive_beads(self, number):
        """ Return every run of ``number`` consecutive beads in ``self.children``.

        Each run is a list. Assumes that consecutive children are bonded. """
        assert(number >= 1)
        ret = []
        for i in range(len(self.children)-number+1):
            ret.append( [self.children[i+j] for j in range(0,number)] )
        return ret

    def get_beads_before_bead(self, bead, number, inclusive=True):
        """ Return beads preceding ``bead`` in ``self.children``, nearest first.

        The result is ``[children[i-j] for j in range(start, number)]``, where
        ``i`` is the index of ``bead`` and ``start`` is 0 when ``inclusive``
        (so ``bead`` itself comes first) and 1 otherwise.

        Raises:
            Exception: if ``i < number``
            ValueError: if ``bead`` is not a child of this segment

        NOTE(review): the bound check is one stricter than the indices require.
        A non-inclusive call returns only ``number - 1`` beads. Confirm the
        intended semantics before relying on this.
        """
        ## Assume that consecutive beads in self.children are bonded
        i = self.children.index(bead)
        if i-number < 0:
            raise Exception("Not enough beads before bead")

        start = 0 if inclusive else 1
        return [self.children[i-j] for j in range(start,number)]

    def get_beads_after_bead(self, bead, number, inclusive=True):
        """ Return beads following ``bead`` in ``self.children``, nearest first.

        The result is ``[children[i+j] for j in range(start, number)]``, where
        ``i`` is the index of ``bead`` and ``start`` is 0 when ``inclusive``
        (so ``bead`` itself comes first) and 1 otherwise.

        Raises:
            Exception: if ``i + number >= len(self.children)``
            ValueError: if ``bead`` is not a child of this segment

        NOTE(review): same off-by-one strictness as ``get_beads_before_bead``.
        """
        ## Assume that consecutive beads in self.children are bonded
        i = self.children.index(bead)
        if i+number >= len(self.children):
            raise Exception("Not enough beads after bead")

        start = 0 if inclusive else 1
        # Offset from i by j (the original used i+i, which ignored j)
        return [self.children[i+j] for j in range(start,number)]

    def _generate_beads(self, bead_model, max_nts_per_bead=4):
        """ Generate beads along the segment, append them to ``self.children``
        and bind the end Locations to the terminal beads.

        The segment is split into ``num_nts // max_nts_per_bead + 1`` equal
        intervals, with one bead at each interval boundary. The two terminal
        beads carry half the nucleotides of an interior bead. Bonds, angles
        and exclusions are NOT created here.

        ``bead_model`` is currently unused.
        """
        direction = self.end_position - self.start_position
        num_beads = (self.num_nts // max_nts_per_bead) + 1
        nts_per_bead = float(self.num_nts)/num_beads

        for i in range(num_beads+1):
            nts = nts_per_bead
            if i == 0 or i == num_beads:
                nts *= 0.5

            # Fractional contour position (0 at start, 1 at end). Equal to
            # i*nts_per_bead/num_nts, but without dividing by num_nts.
            s = float(i)/num_beads
            pos = direction * s + self.start_position

            self.children.append( self._generate_one_bead(pos,nts) )
        self._assign_particles_to_locations()

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ...

    def _generate_atomic(self, atomic_model):
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None,
                 twist = None,
                 start_orientation = None):
        """ Create a dsDNA segment.

        Parameters:
            name, num_nts, end_position, segment_model: see ``Segment.__init__``
            start_position: defaults to the origin. A fresh array is built on
                every call instead of sharing one default array between instances.
            twist, start_orientation: accepted but not used yet
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        # Spacing used by Segment.__init__ to place a default end_position
        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        self.nicks = []

        ## Each end of the duplex exposes both a 5' and a 3' site
        self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3" )

        self.end5 = Location( self, address=-1, type_= "end5" )
        self.end3 = Location( self, address=-1, type_ = "end3" )

    ## Convenience methods: each accepts a Location or a SingleStrandedSegment
    ## (whose matching end is then used)
    def connect_start5(self, end3, force_connection=False):
        """ Connect this segment's start 5' site to a 3' end. """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, force_connection = force_connection )
    def connect_start3(self, end5, force_connection=False):
        """ Connect this segment's start 3' site to a 5' end. """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.end5
        self._connect_ends( self.start3, end5, force_connection = force_connection )
    def connect_end3(self, end5, force_connection=False):
        """ Connect this segment's end 3' site to a 5' end. """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.end5
        self._connect_ends( self.end3, end5, force_connection = force_connection )
    def connect_end5(self, end3, force_connection=False):
        """ Connect this segment's end 5' site to a 3' end. """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.end5, end3, force_connection = force_connection )


    ## Real work
    def _connect_ends(self, end1, end2, force_connection):
        """ Create an "intrahelical" Connection between two end Locations of
        opposite polarity and record it on both containers.

        ``force_connection`` is currently unused. Validation is done with
        ``assert``, so it raises AssertionError. """

        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )

        end1.container._connect( end2.container, Connection( end1, end2, type_="intrahelical" ) )

    def _generate_one_bead(self, pos, nts):
        """ Create one dsDNA bead at ``pos`` representing ``nts``. """
        return PointParticle( Segment.dsDNA_particle, pos, nts,
                              num_nts=nts, parent=self )

    def _assign_particles_to_locations(self):
        """ Both sites at each end share that end's terminal bead. """
        self.start3.particle =  self.start5.particle = self.children[0]
        self.end3.particle =  self.end5.particle = self.children[-1]

    def _generate_atomic(self, atomic_model):
        ...


    # def add_crossover(self, locationInA, B, locationInB):
    #     j = Crossover( [self, B], [locationInA, locationInB] )
    #     self._join(B,j)

    # def add_internal_crossover(self, locationInA, B, locationInB):
    #     j = Crossover( [self, B], [locationInA, locationInB] )
    #     self._join(B,j)


    # def stack_end(self, myEnd):
    #     ## Perhaps this should not really be possible; these ends should be part of same helix
    #     ...

    # def connect_strand(self, other):
    #     ...

    # def break_apart(self):
    #     """Break into smaller pieces so that "crossovers" are only at the ends"""
    #     ...

class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Create an ssDNA segment.

        Parameters:
            name, num_nts, end_position, segment_model: see ``Segment.__init__``
            start_position: defaults to the origin. A fresh array is built on
                every call instead of sharing one default array between instances.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        # Spacing used by Segment.__init__ to place a default end_position
        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        ## `start`/`end` are aliases of the 5' and 3' end Locations
        self.start = self.end5 = Location( self, address=0, type_= "end5" )
        self.end = self.end3 = Location( self, address=-1, type_ = "end3" )

    def connect_3end(self, end5, force_connection=False):
        """ Connect this strand's 3' end to the 5' end Location ``end5``. """
        self._connect_end( end5,  _5_to_3 = False, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False):
        """ Connect this strand's 5' end to the 3' end Location ``end3``. """
        self._connect_end( end3,  _5_to_3 = True, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Create an "intrahelical" Connection from one of this segment's ends
        to ``other`` and record it on both containers.

        When ``_5_to_3`` is True, this strand's 5' end is joined to a 3' end;
        otherwise its 3' end is joined to a 5' end. ``force_connection`` is
        currently unused. """
        assert( isinstance(other, Location) )
        if _5_to_3 == True:
            my_end = self.end5
            assert( other.type_ == "end3" )
        else:
            my_end = self.end3
            assert( other.type_ == "end5" )

        self._connect( other.container, Connection( my_end, other, type_="intrahelical" ) )

    def _generate_one_bead(self, pos, nts):
        """ Create one ssDNA bead at ``pos`` representing ``nts`` nucleotides. """
        return PointParticle( Segment.ssDNA_particle, pos, nts,
                              num_nts=nts, parent=self )

    def _assign_particles_to_locations(self):
        """ Bind the 5' (start) and 3' (end) Locations to the terminal beads. """
        self.start.particle = self.children[0]
        self.end.particle = self.children[-1]

    def _generate_atomic(self, atomic_model):
        ...
    

class SegmentModel(ArbdModel):
    """ ARBD model assembled from a list of DNA segments.

    Construction does three things:
      1. generates beads for every segment;
      2. re-assigns bead types according to the nucleotides each bead represents;
      3. adds harmonic bonds between consecutive and connected beads.
    """
    def __init__(self, segments = None,
                 max_basepairs_per_bead = 7,
                 max_nucleotides_per_bead = 4,
                 dimensions=(1000,1000,1000), temperature=291,
                 timestep=50e-6, cutoff=50,
                 decompPeriod=10000, pairlistDistance=None,
                 nonbondedResolution=0):
        """ Parameters:
            segments: list of Segment objects. A new empty list is used when
                None, instead of one list shared by all models. It is stored
                as ``self.segments`` / ``self.children``.
            max_basepairs_per_bead: resolution limit for DoubleStrandedSegments
            max_nucleotides_per_bead: resolution limit for all other segments
            dimensions, temperature, timestep, cutoff, decompPeriod,
            pairlistDistance, nonbondedResolution: forwarded to ``ArbdModel.__init__``
        """
        if segments is None:
            segments = []

        ## Forward the caller's values; previously pairlistDistance and
        ## nonbondedResolution were silently replaced by None and 0
        ArbdModel.__init__(self,segments,
                           dimensions, temperature, timestep, cutoff,
                           decompPeriod, pairlistDistance=pairlistDistance,
                           nonbondedResolution=nonbondedResolution)

        self.max_basepairs_per_bead = max_basepairs_per_bead     # dsDNA
        self.max_nucleotides_per_bead = max_nucleotides_per_bead # ssDNA
        self.children = self.segments = segments

        self._bonded_potential = dict() # cache bonded potentials

        self._generate_bead_model(segments, max_basepairs_per_bead, max_nucleotides_per_bead)

    def _get_intrahelical_beads(self):
        """ Return [b1, b2] pairs to bond: consecutive beads within each
        segment, plus the end beads of every "intrahelical" connection.

        Connections are counted once, from the segment holding endpoint A. """
        ret = []
        for s in self.segments:
            ret.extend( s.get_all_consecutive_beads(2) )

        for s in self.segments:
            for c in s.connections:
                if c.type_ == "intrahelical":
                    if c.A.container == s: # avoid double-counting
                        b1,b2 = [loc.particle for loc in (c.A,c.B)]
                        for b in (b1,b2): assert( b is not None )
                        ret.append( [b1,b2] )
        return ret

    @staticmethod
    def _neighbor_bead(segment, bead, offset):
        """ Return the bead ``offset`` positions away from ``bead`` in
        ``segment.children``. Returns None if that index is out of range or
        ``bead`` is not in the segment. """
        try:
            i = segment.children.index(bead) + offset
        except ValueError:
            return None
        if 0 <= i < len(segment.children):
            return segment.children[i]
        return None

    def _get_intrahelical_angle_beads(self):
        """ Return [a, b, c] bead triples for angle potentials.

        Triples are consecutive beads within each segment, plus, for every
        "intrahelical" connection b1 (in s1) -- b2 (in s2), the angles centred
        on b1 and on b2 that use an in-segment neighbour on the far side.

        Neighbours are looked up directly by index. The previous
        ``get_beads_before/after_bead(b, 1)`` calls returned the bead itself,
        which produced degenerate triples such as [b1, b1, b2]. """
        ret = []
        for s in self.segments:
            ret.extend( s.get_all_consecutive_beads(3) )

        for s1 in self.segments:
            for c in s1.connections:
                if c.A.container != s1: continue   # count each connection once
                if c.type_ != "intrahelical": continue
                s2 = c.B.container
                b1,b2 = [loc.particle for loc in (c.A,c.B)]
                for b in (b1,b2): assert( b is not None )

                b0 = self._neighbor_bead(s1, b1, -1)
                if b0 is not None: ret.append( [b0,b1,b2] )
                b0 = self._neighbor_bead(s1, b1, +1)
                if b0 is not None: ret.append( [b2,b1,b0] )
                b3 = self._neighbor_bead(s2, b2, -1)
                if b3 is not None: ret.append( [b3,b2,b1] )
                b3 = self._neighbor_bead(s2, b2, +1)
                if b3 is not None: ret.append( [b1,b2,b3] )
        return ret

    def _get_potential(self, type_, kSpring, d):
        """ Return a cached Harmonic{Bond,Angle,Dihedral} for (type_, kSpring, d).

        Raises Exception for an unknown ``type_``. """
        key = (type_,kSpring,d)
        if key not in self._bonded_potential:
            if type_ == "bond":
                self._bonded_potential[key] = HarmonicBond(kSpring,d)
            elif type_ == "angle":
                self._bonded_potential[key] = HarmonicAngle(kSpring,d)
            elif type_ == "dihedral":
                self._bonded_potential[key] = HarmonicDihedral(kSpring,d)
            else:
                raise Exception("Unhandled potential type '%s'" % type_)
        return self._bonded_potential[key]
    def get_bond_potential(self, kSpring, d):
        return self._get_potential("bond", kSpring, d)
    def get_angle_potential(self, kSpring, d):
        return self._get_potential("angle", kSpring, d)
    def get_dihedral_potential(self, kSpring, d):
        return self._get_potential("dihedral", kSpring, d)


    def _generate_bead_model(self, segments,
                             max_basepairs_per_bead = 7,
                             max_nucleotides_per_bead = 4):
        """ Generate beads for ``segments``, re-assign bead types and add
        intrahelical bonds. """

        ## Generate beads. DoubleStrandedSegments use the base-pair limit and
        ## everything else uses the nucleotide limit. The limit is passed by
        ## keyword; it previously landed in the unused `bead_model` slot.
        for s in segments:
            if isinstance(s, DoubleStrandedSegment):
                max_nts = max_basepairs_per_bead
            else:
                max_nts = max_nucleotides_per_bead
            s._generate_beads( self, max_nts_per_bead=max_nts )

        ## TODO: combine beads at junctions as needed

        ## Reassign bead types: one copied type per (kind, num_nts)
        beadtype_s = dict()
        for segment in segments:
            for b in segment:
                key = (b.type_.name[0].upper(), b.num_nts)
                if key in beadtype_s:
                    b.type_ = beadtype_s[key]
                    continue
                t = deepcopy(b.type_)
                if key[0] == "D":
                    # presumably num_nts counts base pairs for dsDNA, hence x2
                    t.__dict__["nts"] = b.num_nts*2
                elif key[0] == "S":
                    t.__dict__["nts"] = b.num_nts
                else:
                    raise Exception("TODO")
                t.name = t.name + "%03d" % (100*t.nts)
                beadtype_s[key] = b.type_ = t


        ## Add intrahelical bonds
        for b1,b2 in self._get_intrahelical_beads():
            if b1.parent == b2.parent:
                sep = 0.5*(b1.num_nts+b2.num_nts)
                parent = b1.parent
            else:
                sep = 1
                parent = self

            ## Use dsDNA parameters only when *both* beads are dsDNA
            if b1.type_.name[0] == "D" and b2.type_.name[0] == "D":
                k = 10.0/np.sqrt(sep) # TODO: determine from simulations
                d = 3.4*sep
            else:
                ## TODO: get correct numbers from ssDNA model
                k = 1.0/np.sqrt(sep)
                d = 5*sep

            bond = self.get_bond_potential(k,d)
            parent.add_bond( b1, b2, bond, exclude=True )


        ## TODO: add connection potentials


    # def get_bead(self, location):
    #     if type(location.container) is not list:
    #         s = self.segments.index(location.container)
    #         s.get_bead(location.address)
    #     else:
    #         r
    #         ...
    


if __name__ == "__main__":

    ## Demo: a dsDNA segment with two ssDNA segments attached
    seg1 = DoubleStrandedSegment("strand", num_nts = 46)
    seg2 = SingleStrandedSegment("strand",
                                 start_position = seg1.end_position + np.array((1,0,1)),
                                 num_nts = 12)

    seg3 = SingleStrandedSegment("strand",
                                 start_position = seg1.start_position + np.array((-1,0,-1)),
                                 end_position = seg1.end_position + np.array((-1,0,1)),
                                 num_nts = 128)

    ## seg2's 5' end joins seg1's end 3' site. seg3 forms a loop: its 3' end
    ## joins seg1's end 5' site and its 5' end joins seg1's start 3' site.
    seg1.connect_end3(seg2)
    seg1.connect_end5(seg3)
    seg1.connect_start3(seg3)

    model = SegmentModel( [seg1, seg2, seg3],
                          dimensions=(5000,5000,5000),
                      )
    model.useNonbondedScheme( nbDnaScheme )
    model.simulate( outputPrefix = 'strand-test', outputPeriod=1e4, numSteps=1e6, gpu=1 )