import numpy as np
import random
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme

from scipy.special import erf
import scipy.optimize as opt
from scipy import interpolate

from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC

# import pdb
"""
TODO:
 + fix handling of crossovers for atomic representation
 + map to atomic representation
    + add nicks
    - transform ssDNA nucleotides 
    - shrink ssDNA
    - shrink dsDNA backbone
    + make orientation continuous
    - sequence
    - handle circular dna
 + ensure crossover bead potentials aren't applied twice 
 + remove performance bottlenecks
 - test for large systems
 - assign sequence
 - ENM
 - rework Location class 
 - remove recursive calls
 - document
 - add unit test of helices connected to themselves
"""

class Location():
    """ Site for connection within an object (e.g. a Segment).

    Attributes:
        container: object owning this location; its ``name`` is used by ``__repr__``
        address: position along the container's contour length
        type_: label such as "end5", "end3", "5prime", "3prime" or "crossover"
        on_fwd_strand: True if the site is on the forward strand
        particle: bead associated with this location, if any
        connection: Connection attached to this location, or None
    """
    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segments
        self.on_fwd_strand = on_fwd_strand
        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the Location at the other end of this location's
        connection, or None if it is unconnected """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Record ``connection`` and which side of it this location is on """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def __repr__(self):
        ## on_fwd_strand is rendered via {:d}, i.e. 1 (forward) or 0 (reverse)
        return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, self.on_fwd_strand)
        
class Connection():
    """ Abstract base class for a connection joining two Locations.

    ``type_`` is an optional label (e.g. "crossover") used elsewhere to
    filter connections. """
    def __init__(self, A, B, type_ = None):
        for endpoint in (A, B):
            assert( isinstance(endpoint, Location) )
        self.A, self.B = A, B
        self.type_ = type_

    def other(self, location):
        """ Given one endpoint (matched by identity), return the opposite one.

        Raises:
            Exception: "OutOfBoundsError" if ``location`` is neither endpoint.
        """
        if location is self.A:
            return self.B
        if location is self.B:
            return self.A
        raise Exception("OutOfBoundsError")
        
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that own Locations and the
    Connections between them.

    ``get_location_at`` additionally requires the subclass to define
    ``num_nts``. """
    def __init__(self, connection_locations=None, connections=None):
        ## Create fresh lists per instance; list defaults would be shared by
        ## every instance constructed without explicit arguments.
        if connection_locations is None:
            connection_locations = []
        if connections is None:
            connections = []
        ## TODO decide on names
        ## `locations` and `connection_locations` alias the same list
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Return locations whose ``type_`` matches ``type_`` (any type if
        None) and is not in ``exclude``. Asserts that no Location object
        appears more than once. """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        ## Each Location object (by identity) must appear only once
        assert( len({id(l) for l in locs}) == len(locs) )
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
        """ Return the existing location at ``address`` on the given strand,
        or a new (unregistered) Location of type ``new_type`` if none exists.

        For single-nucleotide elements (``num_nts == 1``) the address is
        ignored and the unique unconnected location on the strand is used. """
        loc = None
        if (self.num_nts == 1):
            ## Assumes that intrahelical connections have been made before crossovers
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
            assert( loc is not None )
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other

        Connections are filtered to ``connection_type`` (all if None) and
        those whose type is not in ``exclude``.

        Raises:
            Exception: if a stored connection refers to neither endpoint in self.
        """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection, in_3prime_direction=None):
        """ Register ``connection`` between self and ``other``: attach it to
        both endpoint locations, record strand direction if given, and make
        sure each endpoint is listed in its container's locations. """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]
        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction

        A.connection = B.connection = connection
        self.connections.append(connection)
        other.connections.append(connection)
        l = A.container.locations
        if A not in l: l.append(A)
        l = B.container.locations
        if B not in l: l.append(B)

class SegmentParticle(PointParticle):
    """ Bead belonging to a Segment.

    ``contour_position`` is the bead's position along its parent segment
    (``self.parent``) in contour units; ``intrahelical_neighbors`` holds the
    (at most two) beads bonded to it along the same helix. """

    def __init__(self, type_, position, name="A", segname="A", **kwargs):
        self.name = name
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, segname=segname, **kwargs)
        self.intrahelical_neighbors = []
        self.other_neighbors = []

    def get_intrahelical_above(self):
        """ Returns bead directly above self (larger contour position in the
        parent segment), or None if there is none """
        assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) > self.contour_position:
                return b

    def get_intrahelical_below(self):
        """ Returns bead directly below self (smaller contour position in the
        parent segment), or None if there is none """
        assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) < self.contour_position:
                return b

    def get_nt_position(self,seg):
        """ Return this bead's nucleotide position measured in segment ``seg``.

        For the parent segment this converts ``contour_position`` directly.
        Otherwise the parent's connection to ``seg`` whose local address is
        closest to the bead is chosen, and the bead's nucleotide distance
        from that connection is carried over into ``seg``.

        Raises:
            Exception: if the parent has no connection to ``seg``.
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)

        cl = [e for e in self.parent.get_connections_and_locations() if e[2].container is seg]
        if len(cl) == 0:
            raise Exception("Particle {} has no connection to segment {}".format(self, seg))

        dc = [(self.contour_position - A.address)**2 for c,A,B in cl]
        c,A,B = cl[np.argmin(dc)]
        ## TODO: generalize, removing np.abs and conditional
        delta_nt = np.abs( A.container.contour_to_nt_pos(self.contour_position - A.address) )
        B_nt_pos = seg.contour_to_nt_pos(B.address)
        ## Step into seg away from the connection: downward if the connection
        ## is in the first half of seg's contour, upward otherwise
        if B.address < 0.5:
            return B_nt_pos-delta_nt
        else:
            return B_nt_pos+delta_nt

    def get_contour_position_old(self,seg):
        """ Deprecated alias of :meth:`get_contour_position`.

        The previous implementation referenced undefined names and could not
        run for any segment other than the parent, so it now delegates. """
        return self.get_contour_position(seg)

    def get_contour_position(self,seg):
        """ Return this bead's contour position measured in segment ``seg``
        (via its nucleotide position for non-parent segments) """
        if seg == self.parent:
            return self.contour_position
        else:
            nt_pos = self.get_nt_position(seg)
            return seg.nt_pos_to_contour(nt_pos)

## TODO break this class into smaller, better encapsulated pieces
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices.

    Positions along a segment are given either as a contour position
    (0 at the start, 1 at the end) or as a nucleotide position
    (0 .. num_nts-1); see ``contour_to_nt_pos`` and ``nt_pos_to_contour``.

    Subclasses provide ``distance_per_nt`` (needed by ``__init__`` when
    ``end_position`` is omitted), ``local_twist``, ``_get_num_beads`` and
    ``_generate_one_bead``. """

    ## Basic particle types
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,
                              )
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )

    # orientation_bond = HarmonicBond(10,2)
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,
                              )

    def __init__(self, name, num_nts,
                 start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Initialize a segment of ``num_nts`` nucleotides with no beads.

        start_position: defaults to the origin; a new array is created for
            each instance rather than sharing one default array object.
        end_position: defaults to ``start_position`` displaced along z by
            ``distance_per_nt*num_nts`` (the subclass must set
            ``distance_per_nt`` first).
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        self.resname = name
        self.start_orientation = None
        self.twist_per_nt = 0

        self.beads = [c for c in self.children] # self.beads will not contain orientation beads

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = int(num_nts)
        if end_position is None:
            end_position = np.array((0,0,self.distance_per_nt*num_nts)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Set up linear (k=1) interpolation of position along the contour
        a = np.array([self.start_position,self.end_position]).T
        tck, u = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck

        self.sequence = None

    def clear_all(self):
        """ Clear children/beads and detach beads from this segment's connection locations """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for c,loc,other in self.get_connections_and_locations():
            loc.particle = None

    def contour_to_nt_pos(self, contour_pos, round_nt=False):
        """ Convert a contour position to a nucleotide position
        (``contour_pos*(num_nts-1)``). With ``round_nt``, assert the result
        is within tolerance of an integer and round it. """
        nt = contour_pos*(self.num_nts-1)
        if round_nt:
            assert( (np.around(nt) - nt)**2 < 1e-3 )
            nt = np.around(nt)
        return nt

    def nt_pos_to_contour(self,nt_pos):
        """ Convert a nucleotide position to a contour position; a
        single-nucleotide segment only admits position 0 """
        if self.num_nts == 1:
            assert(nt_pos == 0)
            return 0
        else:
            return nt_pos/(self.num_nts-1)

    def contour_to_position(self,s):
        """ Cartesian position(s) at contour position(s) ``s`` """
        p = interpolate.splev( s, self.position_spline_params )
        if len(p) > 1: p = np.array(p).T
        return p

    def contour_to_tangent(self,s):
        """ Unit tangent(s) of the position spline at contour position(s) ``s`` """
        t = interpolate.splev( s, self.position_spline_params, der=1 )
        t = (t / np.linalg.norm(t,axis=0))
        return t.T

    def contour_to_orientation(self,s):
        """ Orientation matrix at scalar contour position ``s``, or None when
        ``start_orientation`` is None.

        Uses ``quaternion_spline_params`` (which the subclass must define)
        when it is not None; otherwise rotates about the local tangent by
        ``twist_per_nt`` times the nucleotide position. """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version
        orientation = None
        if self.start_orientation is not None:
            if self.quaternion_spline_params is None:
                axis = self.contour_to_tangent(s)
                orientation = rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=True )
            else:
                q = interpolate.splev( s, self.quaternion_spline_params )
                if len(q) > 1: q = np.array(q).T # TODO: is this needed?
                orientation = quaternion_to_matrix(q)
        return orientation

    def get_contour_sorted_connections_and_locations(self,type_=None):
        """ Connections/locations of type ``type_`` (all types if None),
        sorted by the address of the location in this segment """
        sort_fn = lambda c: c[1].address
        cl = self.get_connections_and_locations(type_)
        return sorted(cl, key=sort_fn)

    def randomize_unset_sequence(self):
        """ Fill missing bases (or the whole sequence if unset) with random bases """
        bases = list(seqComplement.keys())
        if self.sequence is None:
            self.sequence = [random.choice(bases) for i in range(self.num_nts)]
        else:
            assert(len(self.sequence) == self.num_nts) # TODO move
            for i in range(len(self.sequence)):
                if self.sequence[i] is None:
                    self.sequence[i] = random.choice(bases)

    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        raise NotImplementedError

    def _generate_one_bead(self, contour_position, nts):
        raise NotImplementedError

    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale):
        """ Seq should include modifications like 5T, T3 Tsinglet; direction matters too

        Builds the atomic template for ``seq``, rotates it to the local
        orientation and places it at ``contour_position``. When ``scale`` is
        given and not 1: for SingleStrandedSegment every atom position is
        scaled, otherwise atoms whose names end in "'", "P" or "T" are scaled
        about the base's N9 (A/G) or N1 atom; scaled atoms get beta = 0.

        Raises:
            NotImplementedError: if ``self.local_twist`` is false.
        """
        pos = self.contour_to_position(contour_position)
        if not self.local_twist:
            raise NotImplementedError

        orientation = self.contour_to_orientation(contour_position)
        ## TODO: move this code (?)
        if orientation is None:
            ## No stored orientation (start_orientation is None): build an
            ## arbitrary frame whose third column is the local tangent
            axis = self.contour_to_tangent(contour_position)
            angleVec = np.array([1,0,0])
            if axis.dot(angleVec) > 0.9: angleVec = np.array([0,1,0])
            angleVec = angleVec - angleVec.dot(axis)*axis
            angleVec = angleVec/np.linalg.norm(angleVec)
            y = np.cross(axis,angleVec)
            orientation = np.array([angleVec,y,axis]).T
            ## TODO: improve placement of ssDNA

        key = seq
        ## NOTE(review): the "Fwd" template dictionary is used when is_fwd is
        ## False (and vice versa). This looks inverted; confirm against
        ## CanonicalNucleotideAtoms before changing it.
        if not is_fwd:
            nt_dict = canonicalNtFwd
        else:
            nt_dict = canonicalNtRev
        atoms = nt_dict[ key ].generate() # TODO: clone?

        atoms.orientation = orientation.dot(atoms.orientation)
        if isinstance(self, SingleStrandedSegment):
            if scale is not None and scale != 1:
                for a in atoms:
                    a.position = scale*a.position
                    a.beta = 0
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsed_position()
        else:
            if scale is not None and scale != 1:
                if atoms.sequence in ("A","G"):
                    r0 = atoms.atoms_by_name["N9"].position
                else:
                    r0 = atoms.atoms_by_name["N1"].position
                for a in atoms:
                    if a.name[-1] in ("'","P","T"):
                        a.position = scale*(a.position-r0) + r0
                        a.beta = 0
            atoms.position = pos

        return atoms

    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Create a Location of ``type_`` at nucleotide ``nt`` and add it to this segment """
        c = self.nt_pos_to_contour(nt)
        assert(c >= 0 and c <= 1)
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        loc = Location( self, address=c, type_=type_, on_fwd_strand=on_fwd_strand )
        self.locations.append(loc)

    ## TODO? Replace with abstract strand-based model?
    def add_5prime(self, nt, on_fwd_strand=True):
        self.add_location(nt,"5prime",on_fwd_strand)

    def add_3prime(self, nt, on_fwd_strand=True):
        self.add_location(nt,"3prime",on_fwd_strand)

    def get_3prime_locations(self):
        return self.get_locations("3prime")

    def get_5prime_locations(self):
        ## TODO? ensure that data is consistent before _build_model calls
        return self.get_locations("5prime")

    def iterate_connections_and_locations(self, reverse=False):
        """ Yield all connections to other segments in contour order
        (reversed if ``reverse``) """
        cl = self.get_contour_sorted_connections_and_locations()
        if reverse:
            cl = cl[::-1]

        for c in cl:
            yield c

    ## TODO rename
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
        """ Walks through locations, checking for crossovers.

        Starting at nucleotide ``nt_pos`` on the strand selected by
        ``is_fwd``, visits locations in strand order and returns a 5-tuple
        ``(pos, container, container_nt_pos, container_is_fwd, value)``:

        - this strand's 3prime end: ``(pos, None, None, None, None)``
        - a connection on this strand: the connected location's container,
          nucleotide position and strand, with value 0.5
        - a crossover on the other strand: this segment, one nucleotide past
          ``pos`` in the walking direction, same strand, with value 0

        NOTE(review): ``move_at_least`` is overridden to 0 below, so the
        argument currently has no effect.

        Raises:
            Exception: if no terminating location is found.
        """
        move_at_least = 0

        ## Iterate through locations
        locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))

        for l in locations:
            pos = self.contour_to_nt_pos(l.address, round_nt=True)

            ## Skip locations encountered before our strand
            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
            ## TODO: remove move_at_least
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                print("  found end at",l)
                return pos, None, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                print("  passing through",l)
                print("from {}, connection {} to {}".format(nt_pos,l,B))
                Bpos = B.container.contour_to_nt_pos(B.address, round_nt=True)
                return pos, B.container, Bpos, B.on_fwd_strand, 0.5

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                if nt_pos == pos: continue
                print("  pausing at",l)
                return pos, l.container, pos+(2*is_fwd-1), is_fwd, 0

        ## Made it to the end of the segment without finding a connection
        raise Exception("Shouldn't be here")

    def get_end_of_strand_old(self, contour_pos, is_fwd):
        """ Walks through locations, checking for crossovers (older,
        contour-address based variant of ``get_strand_segment``).

        Returns ``(pos, container, address, on_fwd_strand)`` for the first
        terminating location after ``contour_pos``, with ``None`` in the
        last three slots at this strand's 3prime end.

        Raises:
            Exception: if no terminating location is found.
        """
        locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))

        for l in locations:
            pos = l.address

            ## Skip locations encountered before our strand
            if is_fwd:
                if pos <= contour_pos: continue
            elif pos >= contour_pos: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                return pos, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                return pos, B.container, B.address, B.on_fwd_strand

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                return pos, l.container, pos, is_fwd

        ## Made it to the end of the segment without finding a connection
        raise Exception("Shouldn't be here")

    def get_nearest_bead(self, contour_position):
        """ Return the bead whose contour position is closest to
        ``contour_position``, or None if there are no beads """
        if len(self.beads) < 1: return None
        cs = np.array([b.contour_position for b in self.beads]) # TODO: cache
        # TODO: include beads in connections?
        i = np.argmin((cs - contour_position)**2)

        return self.beads[i]

    def get_all_consecutive_beads(self, number):
        """ Return every run of ``number`` consecutive entries of self.beads """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        ret = []
        for i in range(len(self.beads)-number+1):
            tmp = [self.beads[i+j] for j in range(0,number)]
            ret.append( tmp )
        return ret

    def _add_bead(self,b,set_contour=False):
        """ Move bead ``b`` (and its orientation bead, if present) into this
        segment, bonding the two with ``orientation_bond`` """
        if set_contour:
            b.contour_position = b.get_contour_position(self)

        if b.parent is not None:
            b.parent.children.remove(b)
        self.add(b)
        self.beads.append(b) # don't add orientation bead
        if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
            o = b.orientation_bead
            o.contour_position = b.contour_position
            if o.parent is not None:
                o.parent.children.remove(o)
            self.add(o)
            self.add_bond(b,o, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        """ Replace children/beads with ``new_children`` (each followed by its
        orientation bead, if any), preserving order and dropping duplicates.
        Asserts the rebuild neither gains nor loses particles. """
        old_children = self.children
        old_beads = self.beads
        self.children = []
        self.beads = []

        ## Remove duplicates, preserving order
        tmp = []
        for c in new_children:
            if c not in tmp:
                tmp.append(c)
            else:
                print("  duplicate particle found!")
        new_children = tmp

        for b in new_children:
            self.beads.append(b)
            self.children.append(b)
            if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
                self.children.append(b.orientation_bead)

        assert(len(old_children) == len(self.children))
        assert(len(old_beads) == len(self.beads))

    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):

        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions

        Beads already attached to this segment's locations are kept; end
        beads at contour 0 and 1 are created if missing, and
        ``_get_num_beads`` beads are inserted evenly between each consecutive
        pair. Intrahelical neighbor lists are linked along the way and the
        children are rebuilt in contour order. """
        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        ## First find points between-which beads must be generated
        ## (built from a set, so there are no duplicates)
        existing_beads = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( list(existing_beads), key=lambda b: b.get_contour_position(self) )
        for b in existing_beads:
            assert(b.parent is not None)

        ## Add ends if they don't exist yet
        ## TODOTODO: test 1 nt segments?
        if len(existing_beads) == 0 or existing_beads[0].get_contour_position(self) > 0:
            if len(existing_beads) > 0:
                ## First existing bead must be at least half a nt from the start
                assert(existing_beads[0].get_nt_position(self) >= 0.5)
            b = self._generate_one_bead(0, 0)
            existing_beads = [b] + existing_beads
        if existing_beads[-1].get_contour_position(self) < 1:
            ## Last existing bead must be at least half a nt from the end
            assert(self.num_nts-1-existing_beads[-1].get_nt_position(self) >= 0.5)
            b = self._generate_one_bead(1, 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
        last = None
        for eb1, eb2 in zip(existing_beads[:-1], existing_beads[1:]):
            assert( eb1 is not eb2 )

            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )
            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nts
            ## Credit half of one spacing's nucleotides to each endpoint bead
            eb1.num_nts += 0.5*nts
            eb2.num_nts += 0.5*nts

            ## Add beads
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
                last.intrahelical_neighbors.append(eb1)
                eb1.intrahelical_neighbors.append(last)
                assert(len(last.intrahelical_neighbors) <= 2)
                assert(len(eb1.intrahelical_neighbors) <= 2)
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)

                last.intrahelical_neighbors.append(b)
                b.intrahelical_neighbors.append(last)
                assert(len(last.intrahelical_neighbors) <= 2)
                assert(len(b.intrahelical_neighbors) <= 2)
                last = b
                tmp_children.append(b)

        ## Link the final existing bead (eb2 from the last pair)
        last.intrahelical_neighbors.append(eb2)
        eb2.intrahelical_neighbors.append(last)
        assert(len(last.intrahelical_neighbors) <= 2)
        assert(len(eb2.intrahelical_neighbors) <= 2)

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None,
                 twist_persistence_length = 90 ):
        """ Create a dsDNA segment of num_nts base pairs.

        num_turns: number of helical turns over the segment; defaults to
            num_nts / helical_rise
        start_orientation: 3x3 orientation matrix at the start of the
            segment; defaults to the identity
        twist_persistence_length: stored on the segment (default 90)

        Four end Locations are created: start5/start3 at address 0 and
        end3/end5 at address 1.  The forward strand runs from start5 to
        end3; the reverse strand runs from end5 to start3.
        """
        self.helical_rise = 10.44     # used below as nt per helical turn
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_nts, 
                         start_position,
                         end_position, 
                         segment_model)

        self.local_twist = local_twist
        if num_turns is None:
            num_turns = float(num_nts) / self.helical_rise
        self.twist_per_nt = float(360 * num_turns) / num_nts   # degrees per nt

        if start_orientation is None:
            start_orientation = np.eye(3) # np.array(((1,0,0),(0,1,0),(0,0,1)))
        self.start_orientation = start_orientation
        self.twist_persistence_length = twist_persistence_length

        self.nicks = []

        self.start = self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )

        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )
        for l in (self.start5,self.start3,self.end3,self.end5):
            self.locations.append(l)

        ## Set up a linear (k=1) position spline between the two end points
        a = np.array([self.start_position,self.end_position]).T
        tck, u = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck
        
        ## TODO: initialize sensible spline for orientation
        self.quaternion_spline_params = None


    ## Convenience methods
    ## TODO: add errors if unrealistic connections are made
    ## TODO: make connections automatically between unconnected strands
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect start5 to end3: a 3' Location, or a SingleStrandedSegment
        (whose end3 is used) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )
    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect start3 to end5: a 5' Location, or a SingleStrandedSegment
        (whose start5 is used) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.start3, end5, type_, force_connection = force_connection )
    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect end3 to end5: a 5' Location, or a SingleStrandedSegment
        (whose start5 is used) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.end3, end5, type_, force_connection = force_connection )
    def connect_end5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect end5 to end3: a 3' Location, or a SingleStrandedSegment
        (whose end3 is used) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.end5, end3, type_, force_connection = force_connection )

    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False)):
        """ Add a crossover between two helices

        nt / other_nt: nucleotide positions on self / other
        strands_fwd: pair (on self, on other) telling whether the crossover
            sits on the forward strand.  The default is a tuple, so it is not
            a shared mutable list.
        If other is a SingleStrandedSegment, the work is delegated to it with
        the arguments swapped.
        """
        ## Validate other, nt, other_nt
        ##   TODO

        if isinstance(other,SingleStrandedSegment):
            other.add_crossover(other_nt, self, nt, strands_fwd[::-1])
        else:

            ## Create locations, connections and add to segments
            c = self.nt_pos_to_contour(nt)
            assert(c >= 0 and c <= 1)

            loc = self.get_location_at(c, strands_fwd[0])

            c = other.nt_pos_to_contour(other_nt)
            assert(c >= 0 and c <= 1)
            other_loc = other.get_location_at(c, strands_fwd[1])
            self._connect(other, Connection( loc, other_loc, type_="crossover" ))
            loc.is_3prime_side_of_connection = not strands_fwd[0]
            other_loc.is_3prime_side_of_connection = not strands_fwd[1]
            

    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection):
        """ Connect two end Locations of opposite type ("end3"/"end5").

        The segment owning the 5' Location registers the Connection, with
        the 5' Location listed first.  force_connection is currently unused.
        """
        ## TODO remove self?
        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )
        ## Create and add connection
        if end2.type_ == "end3":
            end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ), in_3prime_direction=True )
        else:
            end2.container._connect( end1.container, Connection( end2, end1, type_=type_ ), in_3prime_direction=True )
    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a contour fraction, at most
        max_basepairs_per_bead per bead (max_nucleotides_per_bead is unused
        for dsDNA) """
        return int(contour*self.num_nts // max_basepairs_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create a dsDNA bead at contour_position representing nts base
        pairs, register it with _add_bead, and return it.

        With local_twist, an orientation bead is also created and offset
        from the bead by orientation_bond.r0 along the local x axis.
        """
        pos = self.contour_to_position(contour_position)
        if self.local_twist:
            orientation = self.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            opos = pos + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )
            ## Name passed by keyword; previously nts was passed positionally
            ## into the name slot, giving the bead a numeric name
            o = SegmentParticle( Segment.orientation_particle, opos, name="O",
                                 num_nts=nts, parent=self )
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self, 
                                    orientation_bead=o,
                                    contour_position=contour_position )

        else:
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    contour_position=contour_position )
        self._add_bead(bead)
        return bead

class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None):
        """ Create an ssDNA segment of num_nts nucleotides.

        It has a 5' Location (start5) at address 0 and a 3' Location (end3)
        at address 1.
        """
        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts, 
                         start_position,
                         end_position, 
                         segment_model)

        self.start = self.start5 = Location( self, address=0, type_= "end5" ) # TODO change type_?
        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        for l in (self.start5,self.end3):
            self.locations.append(l)

    def connect_end3(self, end5, force_connection=False):
        """ Connect this segment's 3' end to end5, a 5' Location """
        self._connect_end( end5,  _5_to_3 = False, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False): # TODO: change name or possibly deprecate
        """ Connect this segment's 5' end to end3, a 3' Location """
        self._connect_end( end3,  _5_to_3 = True, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Create an intrahelical Connection between one end of this segment
        and the Location other.

        _5_to_3 True: join this segment's 5' end (start5) to other, which
            must be a 3' end.
        _5_to_3 False: join this segment's 3' end (end3) to other, which
            must be a 5' end.
        In both cases, the segment owning the 5' Location registers the
        Connection, with the 5' Location listed first.  force_connection is
        currently unused.
        """
        assert( isinstance(other, Location) )
        if _5_to_3 == True:
            ## A single strand's 5' end is start5; this class has no end5
            my_end = self.start5
            assert( other.type_ == "end3" )
            conn = Connection( my_end, other, type_="intrahelical" )
            self._connect( other.container, conn, in_3prime_direction=True )
        else:
            my_end = self.end3
            assert( other.type_ == "end5" )
            conn = Connection( other, my_end, type_="intrahelical" )
            other.container._connect( self, conn, in_3prime_direction=True )

    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False)):
        """ Add a crossover between two helices

        nt must lie at one end of this segment.  If other is also single
        stranded, other_nt must lie at one of its ends, and the two ends are
        joined 3'->5'.  Otherwise an "sscrossover" connection is made to the
        location on other's strand selected by strands_fwd[1].  The default
        for strands_fwd is a tuple, so it is not a shared mutable list.
        """
        ## Validate other, nt, other_nt
        ##   TODO
       
        ## TODO: fix direction

        c1 = self.nt_pos_to_contour(nt)
        ## Ensure connections occur at ends, otherwise the structure doesn't make sense
        assert(np.isclose(c1,0) or np.isclose(c1,1))
        loc = self.get_location_at(c1, True)

        c2 = other.nt_pos_to_contour(other_nt)
        if isinstance(other,SingleStrandedSegment):
            ## Ensure connections occur at opposing ends
            assert(np.isclose(c2,0) or np.isclose(c2,1))
            other_loc = other.get_location_at( c2, True )
            assert( loc.type_ in ("end3","end5"))
            assert( other_loc.type_ in ("end3","end5"))
            if loc.type_ == "end3":
                self.connect_end3( other_loc )
            else:
                assert( other_loc.type_ == "end3" )
                ## connect_end3 expects a Location (our 5' end), not a segment
                other.connect_end3( loc )

        else:
            assert( loc.type_ in ("end3","end5"))
            assert(c2 >= 0 and c2 <= 1)
            other_loc = other.get_location_at( c2, strands_fwd[1] )
            if loc.type_ == "end3":
                self._connect(other, Connection( loc, other_loc, type_="sscrossover" ), in_3prime_direction=True )
            else:
                other._connect(self, Connection( other_loc, loc, type_="sscrossover" ), in_3prime_direction=True )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a contour fraction, at most
        max_nucleotides_per_bead per bead (max_basepairs_per_bead is unused
        for ssDNA) """
        return int(contour*self.num_nts // max_nucleotides_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create an ssDNA bead at contour_position representing nts
        nucleotides, register it with _add_bead, and return it """
        pos = self.contour_to_position(contour_position)
        b = SegmentParticle( Segment.ssDNA_particle, pos, 
                             name="NAS",
                             num_nts=nts, parent=self,
                             contour_position=contour_position )
        self._add_bead(b)
        return b

    
class StrandInSegment(Group):
    """ Class that holds atomic model, maps to segment

    Covers the nucleotides start..end (inclusive, in nt coordinates) of one
    segment, on its forward strand when is_fwd is true and on its reverse
    strand otherwise.
    """

    def __init__(self, segment, start, end, is_fwd):
        """ start/end are nucleotide positions (inclusive); their separation
        must be integral to within 1e-5 (asserted) """
        Group.__init__(self)
        # self.sequence = []
        self.segment = segment
        self.start, self.end = start, end
        self.is_fwd = is_fwd

        span = np.abs(end - start) + 1        # inclusive nucleotide count
        self.num_nts = int(round(span))
        assert( np.abs(self.num_nts - span) < 1e-5 )

        # print(" Creating {}-nt StrandInSegment in {} from {} to {} {}".format(self.num_nts, segment.name, start, end, is_fwd))

    def _nucleotide_ids(self):
        """ Range of nucleotide indices covered, starting at start and
        stepping +1 on the forward strand or -1 on the reverse strand """
        first = self.start
        assert( np.abs(first - round(first)) < 1e-5 )
        first = int(round(first))
        assert( (self.end-self.start) >= 0 or not self.is_fwd )

        step = 2*self.is_fwd - 1
        stop = first + step*self.num_nts
        return range(first, stop, step)

    def get_sequence(self):
        """ Return this piece's bases in 5'-to-3' order, complementing the
        segment's stored (forward-strand) sequence on the reverse strand """
        # TODOTODO test
        stored = self.segment.sequence
        if not self.is_fwd:
            return [seqComplement[stored[i]] for i in self._nucleotide_ids()]
        return [stored[i] for i in self._nucleotide_ids()]

    def get_contour_points(self):
        """ num_nts evenly spaced contour positions from start to end """
        to_contour = self.segment.nt_pos_to_contour
        first = to_contour(self.start)
        last = to_contour(self.end)
        return np.linspace(first, last, self.num_nts)
            
class Strand(Group):
    """ Class that holds atomic model, maps to segments

    The children are StrandInSegment pieces in 5'-to-3' order. Each piece
    is a stretch of the strand that lies on one segment.
    """
    def __init__(self, segname = None):
        Group.__init__(self)
        self.num_nts = 0
        ## children and strand_segments are the same list object
        self.children = self.strand_segments = []
        self.segname = segname

    ## TODO disambiguate names of functions
    def add_dna(self, segment, start, end, is_fwd):
        """ Append the nucleotides start..end of segment to this strand.

        start/end are nucleotide positions on segment (StrandInSegment
        treats them as nt coordinates). is_fwd selects the strand.
        Suspicious input produces a warning, not an interactive debugger.
        """
        ## NOTE(review): start/end are nt positions, but they are converted
        ## with contour_to_nt_pos here. This appears to flag near-zero-length
        ## pieces (start == end), which generate_atomic_model rejects. Confirm
        ## the intended units.
        if not (segment.contour_to_nt_pos(np.abs(start-end)) > 0.9):
            print("WARNING: adding very short strand piece (start={}, end={}, is_fwd={})".format(start, end, is_fwd))
        for s in self.strand_segments:
            if s.segment == segment and s.is_fwd == is_fwd:
                # assert( s.start not in (start,end) )
                # assert( s.end not in (start,end) )
                ## Presumably, touching an existing piece on the same strand
                ## of the same segment means the strand closes on itself
                if s.start in (start,end) or s.end in (start,end):
                    print("  CIRCULAR DNA")

        s = StrandInSegment( segment, start, end, is_fwd )
        self.add( s )
        self.num_nts += s.num_nts

    def set_sequence(self,sequence): # , set_complement=True):
        """ Write sequence (5'-to-3', letters A/T/C/G) onto the segments this
        strand covers.

        Each segment's sequence list is created on demand. Reverse-strand
        pieces store the complement, so seg.sequence is always indexed along
        the forward strand.
        """
        ## validate input
        assert( len(sequence) >= self.num_nts )
        assert( np.all( [i in ('A','T','C','G') for i in sequence] ) )
        
        seq_idx = 0
        ## set sequence on each segment
        for s in self.children:
            seg = s.segment
            if seg.sequence is None:
                seg.sequence = [None for i in range(seg.num_nts)]

            if s.is_fwd:
                for nt in s._nucleotide_ids():
                    seg.sequence[nt] = sequence[seq_idx]
                    seq_idx += 1
            else:
                for nt in s._nucleotide_ids():
                    seg.sequence[nt] = seqComplement[sequence[seq_idx]]
                    seq_idx += 1

    def generate_atomic_model(self,scale):
        """ Build atomic nucleotides for every piece of the strand and join
        consecutive nucleotides with backbone bond/angle/dihedral terms.

        The strand's first nucleotide gets a "5" prefix on its sequence code.
        Its last nucleotide gets a "3" suffix, whichever direction or contour
        position it ends at. Residue ids run 1..N in strand order.
        """
        last = None
        resid = 1
        num_strand_segments = len(self.strand_segments)
        for strand_segment_count, s in enumerate(self.strand_segments, start=1):
            seg = s.segment
            contour = s.get_contour_points()
            assert(s.end != s.start)
            assert(np.linalg.norm( seg.contour_to_position(contour[-1]) - seg.contour_to_position(contour[0]) ) > 0.1)
            ## Index of the strand's final nucleotide within this piece
            ## (None unless this is the last piece)
            last_nt_index = len(contour)-1 if strand_segment_count == num_strand_segments else None
            for i,(c,seq) in enumerate(zip(contour,s.get_sequence())):
                if last is None:
                    seq = "5"+seq
                if i == last_nt_index:
                    seq = seq+"3"

                nt = seg._generate_atomic_nucleotide( c, s.is_fwd, seq, scale )
                s.add(nt)
                ## Join to the previous nucleotide through the O3'-P backbone link
                if last is not None:
                    o3,c3,c4,c2,h3 = [last.atoms_by_name[n] 
                                      for n in ("O3'","C3'","C4'","C2'","H3'")]
                    p,o5,o1,o2,c5 = [nt.atoms_by_name[n] 
                                     for n in ("P","O5'","O1P","O2P","C5'")]
                    self.add_bond( o3, p, None )
                    self.add_angle( c3, o3, p, None )
                    for x in (o5,o1,o2):
                        self.add_angle( o3, p, x, None )
                        self.add_dihedral(c3, o3, p, x, None )
                    for x in (c4,c2,h3):
                        self.add_dihedral(x, c3, o3, p, None )
                    self.add_dihedral(o3, p, o5, c5, None)
                ## written through __dict__, presumably to bypass attribute
                ## hooks on the nucleotide object; confirm
                nt.__dict__['resid'] = resid
                resid += 1
                last = nt

    def update_atomic_orientations(self,default_orientation):
        """ Set nt.orientation for every atomic nucleotide to the segment
        orientation at its contour position, composed with
        default_orientation.

        When the segment provides no orientation (None), a frame is built
        from the local tangent and a reference vector perpendicular to it.
        """
        for s in self.strand_segments:
            seg = s.segment
            contour = s.get_contour_points()
            for c,nt in zip(contour,s.children):

                orientation = seg.contour_to_orientation(c)
                ## TODO: move this code (?)
                if orientation is None:
                    axis = seg.contour_to_tangent(c)
                    angleVec = np.array([1,0,0])
                    ## Avoid references (anti)parallel to the axis: their
                    ## perpendicular component would be ~0 and normalizing it
                    ## would produce NaNs
                    if np.abs(axis.dot(angleVec)) > 0.9: angleVec = np.array([0,1,0])
                    angleVec = angleVec - angleVec.dot(axis)*axis
                    angleVec = angleVec/np.linalg.norm(angleVec)
                    y = np.cross(axis,angleVec)
                    orientation = np.array([angleVec,y,axis]).T

                nt.orientation = orientation.dot(default_orientation) # this one should be correct

class SegmentModel(ArbdModel):
    def __init__(self, segments=None, local_twist=True, escapable_twist=True,
                 max_basepairs_per_bead=7,
                 max_nucleotides_per_bead=4,
                 dimensions=(1000,1000,1000), temperature=291,
                 timestep=50e-6, cutoff=50, 
                 decompPeriod=10000, pairlistDistance=None, 
                 nonbondedResolution=0,DEBUG=0):
        """ Build a bead model from a list of segments.

        segments: segments forming the model. A fresh empty list is used when
            omitted, so instances never share a default list.
        local_twist, escapable_twist: forwarded to _generate_bead_model
        max_basepairs_per_bead: bead resolution for dsDNA
        max_nucleotides_per_bead: bead resolution for ssDNA
        dimensions, temperature, timestep, cutoff, decompPeriod,
        pairlistDistance, nonbondedResolution: forwarded to
            ArbdModel.__init__. pairlistDistance and nonbondedResolution
            were previously ignored.
        DEBUG: print progress messages when > 0
        """
        if segments is None:
            segments = []
        self.DEBUG = DEBUG
        if DEBUG > 0: print("Building ARBD Model")
        ArbdModel.__init__(self,segments,
                           dimensions, temperature, timestep, cutoff, 
                           decompPeriod, pairlistDistance=pairlistDistance,
                           nonbondedResolution=nonbondedResolution)

        # self.max_basepairs_per_bead = max_basepairs_per_bead     # dsDNA
        # self.max_nucleotides_per_bead = max_nucleotides_per_bead # ssDNA
        self.children = self.segments = segments

        self._bonded_potential = dict() # cache for bonded potentials

        self._generate_bead_model( max_basepairs_per_bead, max_nucleotides_per_bead, local_twist, escapable_twist)

        self.useNonbondedScheme( nbDnaScheme )


    def get_connections(self,type_=None,exclude=[]):
        """ Collect (connection, location, location) entries from all segments.

        A connection already reported by an earlier segment is skipped, so a
        connection shared by two segments appears once. Entries from a
        single segment are not filtered against each other.
        """
        seen = set()
        connections = []
        for segment in self.segments:
            fresh = [entry for entry in segment.get_connections_and_locations(type_,exclude=exclude)
                     if entry[0] not in seen]
            connections.extend(fresh)
            seen.update(entry[0] for entry in fresh)
        return connections
    
    def _recursively_get_beads_within_bonds(self,b1,bonds,done=[]):
        """ Enumerate paths of exactly `bonds` beads reachable from b1 by
        hopping along intrahelical_neighbors without revisiting a bead.

        Each path is a list of beads that excludes b1. bonds == 0 yields a
        single empty path. `done` lists beads already on the current path;
        it is copied, never modified.
        """
        visited = list(done) + [b1]
        if bonds == 0:
            return [[]]

        paths = []
        for neighbor in b1.intrahelical_neighbors:
            if neighbor not in visited:
                paths.extend( [neighbor] + tail
                              for tail in self._recursively_get_beads_within_bonds(neighbor, bonds-1, visited) )
        return paths

    def _get_intrahelical_beads(self,num=2):
        """ Return lists of `num` consecutive intrahelically linked beads.

        Each chain is reported once. Of the two directions a chain could be
        walked, only the one whose first bead has the smaller idx than its
        last bead is kept.
        """
        ## TODO: add check that this is not called before adding intrahelical_neighbors in _generate_bead_model

        assert(num >= 2)

        chains = []
        every_bead = (bead for segment in self.segments for bead in segment.beads)
        for first in every_bead:
            for rest in self._recursively_get_beads_within_bonds(first, num-1):
                assert(len(rest) == num-1)
                if first.idx < rest[-1].idx: # avoid double-counting
                    chains.append([first] + rest)
        return chains


    def _get_intrahelical_angle_beads(self):
        """ Triplets of consecutive intrahelical beads (the input for angle potentials) """
        beads_per_angle = 3
        return self._get_intrahelical_beads(num=beads_per_angle)

    def _get_potential(self, type_, kSpring, d, max_potential = None):
        """ Return the bonded potential of kind type_ ("bond", "angle" or
        "dihedral").

        The potential is created on first request and cached under
        (type_, kSpring, d, max_potential), so identical terms share one
        object. Raises Exception for any other type_.
        """
        cache = self._bonded_potential
        key = (type_, kSpring, d, max_potential)
        if key in cache:
            return cache[key]

        if type_ == "bond":
            potential = HarmonicBond(kSpring,d, rRange=(0,500), max_potential=max_potential)
        elif type_ == "angle":
            potential = HarmonicAngle(kSpring,d, max_potential=max_potential)
            # , resolution = 1, maxForce=0.1)
        elif type_ == "dihedral":
            potential = HarmonicDihedral(kSpring,d, max_potential=max_potential)
        else:
            raise Exception("Unhandled potential type '%s'" % type_)

        cache[key] = potential
        return potential
    def get_bond_potential(self, kSpring, d):
        """ Cached harmonic bond potential with spring constant kSpring and rest length d """
        return self._get_potential(type_="bond", kSpring=kSpring, d=d)
    def get_angle_potential(self, kSpring, d):
        """ Cached harmonic angle potential with spring constant kSpring and rest angle d """
        return self._get_potential(type_="angle", kSpring=kSpring, d=d)
    def get_dihedral_potential(self, kSpring, d, max_potential=None):
        """ Cached harmonic dihedral potential with rest angle d (degrees).

        d is wrapped by multiples of 360 into [-180, 180]. Values above 180
        land in (-180, 180]; values below -180 land in [-180, 180).

        Raises ValueError for infinite d, which the +/-360 loops could never
        reduce.
        """
        import math
        if math.isinf(d):
            raise ValueError("Dihedral angle must be finite, got %s" % d)
        ## Pre-reduce very large angles with the exact fmod (sign of d kept).
        ## Otherwise the loops below take ~|d|/360 steps, or never finish
        ## once d - 360 == d in floating point.
        if abs(d) > 3600:
            d = math.fmod(d, 360.0)
        while d > 180: d-=360
        while d < -180: d+=360
        return self._get_potential("dihedral", kSpring, d, max_potential)


    def _getParent(self, *beads ):
        """ Return the parent shared by all beads, or the model itself (self)
        when the beads do not all have the same parent """
        pairs = zip(beads[:-1], beads[1:])
        same_parent = all([left.parent == right.parent for left, right in pairs])
        if not same_parent:
            return self
        return beads[0].parent

    def _get_twist_spring_constant(self, sep, twist_persistence_length=90):
        """ Return the twist spring constant for beads separated by sep nt.

        Solves <cos(q)> = exp(-sep/Lp) numerically (opt.leastsq) for the
        Gaussian width of the twist distribution, then scales it by 2*kT.

        sep: bead separation in nt
        twist_persistence_length: twist persistence length in nm. The default
            of 90 is set semi-arbitrarily, as there is a large spread in the
            literature. Callers may pass a segment's own
            twist_persistence_length.
        """
        kT = 0.58622522         # kcal/mol
        ## <cos(q)> = exp(-s/Lp) = integrate( cos[x] exp(-A x^2), {x, 0, pi} ) / integrate( exp(-A x^2), {x, 0, pi} )
        ##   Assume A is small
        ## int[B_] :=  Normal[Integrate[ Series[Cos[x] Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]/
        ##             Integrate[Series[Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]]

        ## Actually, without assumptions I get fitFun below
        ## From http://www.annualreviews.org/doi/pdf/10.1146/annurev.bb.17.060188.001405
        ##   units "3e-19 erg cm/ 295 k K" "nm" =~ 73
        Lp = twist_persistence_length/0.34   # nm -> nt at 0.34 nm per bp

        fitFun = lambda x: np.real(erf( (4*np.pi*x + 1j)/(2*np.sqrt(x)) )) * np.exp(-1/(4*x)) / erf(2*np.sqrt(x)*np.pi) - np.exp(-sep/Lp)
        k = opt.leastsq( fitFun, x0=np.exp(-sep/Lp) )
        ## 0.00030461742 ~= (pi/180)**2, i.e. converts per-rad^2 to per-deg^2
        return k[0][0] * 2*kT*0.00030461742

    """ Mapping between different resolution models """
    def _clear_beads(self):
        """ Call clear_all() on every segment and on the model (keeping the
        segments as children).

        Then assert that no bead remains and that no connection Location
        still references a particle.
        """
        for s in self.segments:
            s.clear_all()
        self.clear_all(keep_children=True)
        assert( len([b for b in self]) == 0 )
        locParticles = []
        for s in self.segments:
            for c,A,B in s.get_connections_and_locations():
                for l in (A,B):
                    if l.particle is not None:
                        ## Record the particle actually found (was A.particle)
                        locParticles.append(l.particle)
        assert( len(locParticles) == 0 )
        assert( len([b for s in self.segments for b in s.beads]) == 0 )

    def _update_segment_positions(self, bead_coordinates):
        """ Refit each segment's position and orientation splines to new bead
        coordinates.

        bead_coordinates is indexed by bead idx in its rows, with coordinates
        in its columns (bead_coordinates[ids,:]). Presumably the shape is
        (num_beads, 3); TODO confirm.

        For every segment this sets:
        - s.position_spline_params: a cubic spline through the beads, falling
          back to linear if the cubic fit fails.
        - s.quaternion_spline_params: set only when more than three beads
          carry an orientation_bead.
        These are presumably read by contour_to_position and
        contour_to_orientation.
        """

        for s in self.segments:
            cabs = s.get_connections_and_locations("intrahelical")
            if np.any( [B.particle is None for c,A,B in cabs] ):
                print( "WARNING: none type found in connection, skipping" )
                cabs = [e for e in cabs if e[2].particle is not None]
            beads = set(s.beads + [A.particle for c,A,B in cabs])

            ## Add beads up to two intrahelical hops past each junction
            ## (presumably to constrain the spline near the segment ends)
            filter_fn = lambda x: x is not None and x not in beads
            for c,A,B in cabs:
                ## TODOTODO test?
                bs = list( filter( filter_fn, B.particle.intrahelical_neighbors ) )
                beads.update(bs)
                bs = list( filter( filter_fn, [n for b in bs for n in b.intrahelical_neighbors] ) )
                beads.update(bs)

            ## Skip beads that are None (some locations were not assigned a particle to avoid double-counting) 
            beads = [b for b in beads if b is not None]
            contours = [b.get_contour_position(s) for b in beads]

            cb = sorted( zip(contours,beads), key=lambda a:a[0] )
            beads = [b for c,b in cb] 
            contours = [c for c,b in cb] 

            ids = [b.idx for b in beads]
            
            ## Get positions
            positions = bead_coordinates[ids,:].T

            try:
                tck, u = interpolate.splprep( positions, u=contours, s=0, k=3 )
            except Exception:
                ## e.g. too few beads for a cubic spline: fall back to linear
                tck, u = interpolate.splprep( positions, u=contours, s=0, k=1 )

            s.position_spline_params = tck

            ## Get twist from the beads that carry an orientation bead
            cb = [e for e in cb if 'orientation_bead' in e[1].__dict__]
            beads = [b for c,b in cb] 
            contours = [c for c,b in cb] 
            if len(beads) > 3:
                tangents = s.contour_to_tangent(contours)
                quats = []
                lastq = None
                for b,t in zip(beads,tangents):
                    o = b.orientation_bead
                    ## Orientation vector: bead -> orientation bead, made
                    ## perpendicular to the tangent and normalized
                    angleVec = bead_coordinates[o.idx,:] - bead_coordinates[b.idx,:]
                    angleVec = angleVec - angleVec.dot(t)*t
                    angleVec = angleVec/np.linalg.norm(angleVec)
                    y = np.cross(t,angleVec)
                    assert( np.abs(np.linalg.norm(y) - 1) < 1e-2 )
                    q = quaternion_from_matrix( np.array([angleVec,y,t]).T)
                    
                    ## q and -q describe the same rotation. Keep consecutive
                    ## quaternions in the same hemisphere so the spline
                    ## does not jump between them.
                    if lastq is not None:
                        if q.dot(lastq) < 0:
                            q = -q
                    quats.append( q )
                    lastq = q

                quats = np.array(quats)
                ## TODOTODO test smoothing
                try:
                    tck, u = interpolate.splprep( quats.T, u=contours, s=3, k=3 )
                except Exception:
                    tck, u = interpolate.splprep( quats.T, u=contours, s=0, k=1 )
                s.quaternion_spline_params = tck

    def _generate_bead_model(self,
                             max_basepairs_per_bead = 7,
                             max_nucleotides_per_bead = 4,
                             local_twist=False,
                             escapable_twist=True):


        segments = self.segments
        for s in segments:
            s.local_twist = local_twist

        """ Simplify connections """
        # d_nt = dict()           # 
        # for s in segments:
        #     d_nt[s] = 1.5/(s.num_nts-1)
        # for s in segments:
        #     ## replace consecutive crossovers with
        #     cl = sorted( s.get_connections_and_locations("crossover"), key=lambda x: x[1].address )
        #     last = None
        #     for entry in cl:
        #         c,A,B = entry
        #         if last is not None and \
        #            (A.address - last[1].address) < d_nt[s]:
        #             same_type = c.type_ == last[0].type_
        #             same_dest_seg = B.container == last[2].container
        #             if same_type and same_dest_seg:
        #                 if np.abs(B.address - last[2].address) < d_nt[B.container]:
        #                     ## combine
        #                     A.combine = last[1]
        #                     B.combine = last[2]

        #                     ...
        #         # if last is not None:
        #         #     s.bead_locations.append(last)
        #         ...
        #         last = entry
        # del d_nt

        """ Generate beads at intrahelical junctions """
        if self.DEBUG: print( "Adding intrahelical beads at junctions" )

        ## Loop through all connections, generating beads at appropriate locations
        for c,A,B in self.get_connections("intrahelical"):
            s1,s2 = [l.container for l in (A,B)]

            ## TODO be more elegant!
            if A.on_fwd_strand == False: continue # TODO verify this avoids double-counting
            assert( A.particle is None )
            assert( B.particle is None )

            ## TODO: offload the work here to s1
            a1,a2 = [l.address   for l in (A,B)]            
            # a1,a2 = [a - s.nt_pos_to_contour(0.5) if a == 0 else a + s.nt_pos_to_contour(0.5) for a,s in zip((a1,a2),(s1,s2))]
            for a in (a1,a2):
                assert( a in (0,1,0.0,1.0) )
            a1,a2 = [a - s.nt_pos_to_contour(0.5) if a == 0 else a + s.nt_pos_to_contour(0.5) for a,s in zip((a1,a2),(s1,s2))]
            
            ## TODO improve this for combinations of ssDNA and dsDNA (maybe a1/a2 should be calculated differently)
            if isinstance(s2,DoubleStrandedSegment):
                b = s2._generate_one_bead(a2,0)
            else:
                b = s1._generate_one_bead(a1,0)
            A.particle = B.particle = b

        """ Generate beads at other junctions """
        for c,A,B in self.get_connections(exclude="intrahelical"):
            s1,s2 = [l.container for l in (A,B)]
            if A.particle is not None and B.particle is not None:
                continue
            # assert( A.particle is None )
            # assert( B.particle is None )

            ## TODO: offload the work here to s1/s2 (?)
            a1,a2 = [l.address   for l in (A,B)]

            if A.particle is None:
                b = s1.get_nearest_bead(a1)
                if b is not None and s1.contour_to_nt_pos(np.abs(b.contour_position-a1)) < 1.9:
                    ## combine beads
                    b.contour_position = 0.5*(b.contour_position + a1) # avg position
                    A.particle = b
                else:
                    A.particle = s1._generate_one_bead(a1,0)

            if B.particle is None:
                b = s2.get_nearest_bead(a2)
                if b is not None and s2.contour_to_nt_pos(np.abs(b.contour_position-a2)) < 1.9:
                    ## combine beads
                    b.contour_position = 0.5*(b.contour_position + a2) # avg position
                    B.particle = b
                else:
                    B.particle = s2._generate_one_bead(a2,0)

        """ Some tests """
        for c,A,B in self.get_connections("intrahelical"):
            for l in (A,B):
                if l.particle is None: continue
                assert( l.particle.parent is not None )

        """ Generate beads in between """
        if self.DEBUG: print("Generating beads")
        for s in segments:
            s._generate_beads( self, max_basepairs_per_bead, max_nucleotides_per_bead )

        # """ Combine beads at junctions as needed """
        # for c,A,B in self.get_connections():
        #    ...

        """ Add intrahelical neighbors at connections """
        for c,A,B in self.get_connections("intrahelical"):
            b1,b2 = [l.particle for l in (A,B)]
            if b1 is b2:
                ## already handled by Segment._generate_beads
                continue
            else:
                for b in (b1,b2): assert( b is not None )
                if b2 not in b1.intrahelical_neighbors:
                    b1.intrahelical_neighbors.append(b2)
                    b2.intrahelical_neighbors.append(b1)
                assert(len(b1.intrahelical_neighbors) <= 2)
                assert(len(b2.intrahelical_neighbors) <= 2)

        """ Reassign bead types """
        if self.DEBUG: print("Assigning bead types")
        beadtype_s = dict()
        for segment in segments:
            for b in segment:
                b.num_nts = np.around( b.num_nts, decimals=1 )
                key = (b.type_.name[0].upper(), b.num_nts)
                if key in beadtype_s:
                    b.type_ = beadtype_s[key]
                else:
                    t = deepcopy(b.type_)
                    if key[0] == "D":
                        t.__dict__["nts"] = b.num_nts*2
                    elif key[0] == "S":
                        t.__dict__["nts"] = b.num_nts
                    elif key[0] == "O":
                        t.__dict__["nts"] = b.num_nts
                    else:
                        raise Exception("TODO")
                    # print(t.nts)
                    t.name = t.name + "%03d" % (10*t.nts)
                    beadtype_s[key] = b.type_ = t

        """ Update bead indices """
        self._countParticleTypes() # probably not needed here
        self._updateParticleOrder()

        """ Add intrahelical bond potentials """
        if self.DEBUG: print("Adding intrahelical bond potentials")
        dists = dict()          # intrahelical distances built for later use
        intra_beads = self._get_intrahelical_beads() 
        if self.DEBUG: print("  Adding %d bonds" % len(intra_beads))
        for b1,b2 in intra_beads:
            parent = self._getParent(b1,b2)

            ## TODO: could be sligtly smarter about sep
            sep = 0.5*(b1.num_nts+b2.num_nts)

            conversion = 0.014393265 # units "pN/AA" kcal_mol/AA^2
            if b1.type_.name[0] == "D" and b2.type_.name[0] == "D":
                elastic_modulus = 1000 # pN http://markolab.bmbcb.northwestern.edu/marko/Cocco.CRP.02.pdf
                d = 3.4*sep
                k = conversion*elastic_modulus/d
            else:
                ## TODO: get better numbers our ssDNA model
                elastic_modulus = 800 # pN http://markolab.bmbcb.northwestern.edu/marko/Cocco.CRP.02.pdf
                d = 5*sep
                k = conversion*elastic_modulus/d
                # print(sep,d,k)
              
            if b1 not in dists:
                dists[b1] = dict()
            if b2 not in dists:
                dists[b2] = dict()
            # dists[b1].append([b2,sep])
            # dists[b2].append([b1,sep])
            dists[b1][b2] = sep
            dists[b2][b1] = sep

            # dists[[b1,b2]] = dists[[b2,b1]] = sep
            bond = self.get_bond_potential(k,d)
            parent.add_bond( b1, b2, bond, exclude=True )

        """ Add intrahelical angle potentials """
        if self.DEBUG: print("Adding intrahelical angle potentials")
        for b1,b2,b3 in self._get_intrahelical_angle_beads():
            ## TODO: could be slightly smarter about sep
            sep = 0.5*b1.num_nts+b2.num_nts+0.5*b3.num_nts
            parent = self._getParent(b1,b2,b3)

            kT = 0.58622522         # kcal/mol
            if b1.type_.name[0] == "D" and b2.type_.name[0] == "D" and b3.type_.name[0] == "D":
                ## <cos(q)> = exp(-s/Lp) = integrate( x^4 exp(-A x^2) / 2, {x, 0, pi} ) / integrate( x^2 exp(-A x^2), {x, 0, pi} )
                ## <cos(q)> ~ 1 - 3/4A                                                                                            
                ## where A = k_spring / (2 kT)                                                                                    
                k = 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/147))) * 0.00030461742; # kcal_mol/degree^2
                if local_twist:
                    ## TODO optimize this paramter
                    k *= 0.5    # halve because orientation beads have similar springs
            else:
                ## TODO: get correct number from ssDNA model
                k = 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/3))) * 0.00030461742; # kcal_mol/degree^2

            angle = self.get_angle_potential(k,180)
            parent.add_angle( b1, b2, b3, angle )

        """ Add intrahelical exclusions """
        if self.DEBUG: print("Adding intrahelical exclusions")
        beads = dists.keys()
        def _recursively_get_beads_within(b1,d,done=[]):
            ret = []
            for b2,sep in dists[b1].items():
                if b2 in done: continue
                if sep < d:
                    ret.append( b2 )
                    done.append( b2 )
                    tmp = _recursively_get_beads_within(b2, d-sep, done)
                    if len(tmp) > 0: ret.extend(tmp)
            return ret

        exclusions = set()
        for b1 in beads:
            exclusions.update( [(b1,b) for b in _recursively_get_beads_within(b1, 20, done=[b1])] )

        if self.DEBUG: print("Adding %d exclusions" % len(exclusions))
        for b1,b2 in exclusions:
            parent = self._getParent(b1,b2)
            parent.add_exclusion( b1, b2 )

        """ Twist potentials """
        if local_twist:
            if self.DEBUG: print("Adding twist potentials")

            for b1 in beads:
                if "orientation_bead" not in b1.__dict__: continue
                for b2,sep in dists[b1].items():
                    if "orientation_bead" not in b2.__dict__: continue
                    if b2.idx < b1.idx: continue # Don't double-count

                    p1,p2 = [b.parent for b in (b1,b2)]
                    o1,o2 = [b.orientation_bead for b in (b1,b2)]

                    parent = self._getParent( b1, b2 )

                    """ Add heuristic 90 degree potential to keep orientation bead orthogonal """
                    k = (1.0/2) * 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/147))) * 0.00030461742; # kcal_mol/degree^2
                    pot = self.get_angle_potential(k,90)
                    parent.add_angle(o1,b1,b2, pot)
                    parent.add_angle(b1,b2,o2, pot)
                                        
                    ## TODO: improve this
                    twist_per_nt = 0.5 * (p1.twist_per_nt + p2.twist_per_nt)
                    angle = sep*twist_per_nt
                    if angle > 360 or angle < -360:
                        print("WARNING: twist angle out of normal range... proceeding anyway")
                        # raise Exception("The twist between beads is too large")
                        
                    k = self._get_twist_spring_constant(sep)
                    if escapable_twist:
                        pot = self.get_dihedral_potential(k,angle,max_potential=1)
                    else:
                        pot = self.get_dihedral_potential(k,angle)
                    parent.add_dihedral(o1,b1,b2,o2, pot)

        """ Add connection potentials """
        for c,A,B in self.get_connections("terminal_crossover"):
            ## TODO: use a better description here
            b1,b2 = [loc.particle for loc in (c.A,c.B)]
            pot = self.get_bond_potential(4,18.5)
            self.add_bond(b1,b2, pot)

        
        crossover_bead_pots = set()
        for c,A,B in self.get_connections("crossover"):
            b1,b2 = [loc.particle for loc in (c.A,c.B)]

            ## Avoid double-counting
            if (b1,b2) in crossover_bead_pots: continue
            crossover_bead_pots.add((b1,b2))
            crossover_bead_pots.add((b2,b1))
                
            pot = self.get_bond_potential(4,18.5)
            self.add_bond(b1,b2, pot)


            ## Get beads above and below
            u1,u2 = [b.get_intrahelical_above() for b in (b1,b2)]
            d1,d2 = [b.get_intrahelical_below() for b in (b1,b2)]


            k_fn = lambda sep: (1.0/2) * 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/147))) * 0.00030461742; # kcal_mol/degree^2
            if u1 is not None and u2 is not None:
                t0 = 0
                k = k_fn( 0.5*(dists[b1][u1]+dists[b2][u2]) )
                pot = self.get_dihedral_potential(k,t0)
                self.add_dihedral( u1,b1,b2,u2,  pot )
            elif d1 is not None and d2 is not None:
                t0 = 0
                k = k_fn( 0.5*(dists[b1][d1]+dists[b2][d2]) )
                pot = self.get_dihedral_potential(k,t0)
                self.add_dihedral( d1,b1,b2,d2, pot )
            elif d1 is not None and u2 is not None:
                t0 = 180
                k = k_fn( 0.5*(dists[b1][d1]+dists[b2][u2]) )
                pot = self.get_dihedral_potential(k,t0)
                self.add_dihedral( d1,b1,b2,u2, pot )
            elif u1 is not None and d2 is not None:
                t0 = 180
                k = k_fn( 0.5*(dists[b1][u1]+dists[b2][d2]) )
                pot = self.get_dihedral_potential(k,t0)
                self.add_dihedral( u1,b1,b2,d2, pot )


            if local_twist:
                k = (1.0/2) * 1.5 * kT * (1.0 / (1-np.exp(-float(1)/147))) * 0.00030461742; # kcal_mol/degree^2
                if 'orientation_bead' in b1.__dict__:
                    # t0 = 90 + 60
                    t0 = 150
                    if A.on_fwd_strand: t0 = 30 # TODO handle antiparallel segments
                    o = b1.orientation_bead
                    pot = self.get_angle_potential(k,t0)
                    self.add_angle( o,b1,b2, pot )
                else:
                    t0 = 150
                    if B.on_fwd_strand: t0 = 30
                    o = b2.orientation_bead
                    pot = self.get_angle_potential(k,t0)
                    self.add_angle( b1,b2,o, pot )


                t0 = 90
                if 'orientation_bead' in b1.__dict__:
                    o1 = b1.orientation_bead
                    if u2 is not None:
                        k = k_fn( dists[b2][u2] )
                        pot = self.get_dihedral_potential(k,t0)
                        self.add_dihedral( o1,b1,b2,u2, pot )
                    elif d2 is not None:
                        k = k_fn( dists[b2][d2] )
                        pot = self.get_dihedral_potential(k,t0)
                        self.add_dihedral( o1,b1,b2,d2, pot )
                if 'orientation_bead' in b2.__dict__:
                    o2 = b2.orientation_bead
                    if u1 is not None:
                        k = k_fn( dists[b1][u1] )
                        pot = self.get_dihedral_potential(k,t0)
                        self.add_dihedral( o2,b2,b1,u1, pot )
                    elif d1 is not None:
                        k = k_fn( dists[b1][d1] )
                        pot = self.get_dihedral_potential(k,t0)
                        self.add_dihedral( o2,b2,b1,d1, pot )

            
        ## TODOTODO check that this works
        for crossovers in self.get_consecutive_crossovers():
            ## filter crossovers
            new_cl = []
            lastParticle = None
            for cl in crossovers:
                c,A,B,d = cl
                if A.particle is not lastParticle:
                    new_cl.append(cl)
                    lastParticle = A.particle
            crossovers = new_cl
            for i in range(len(crossovers)-2):
                c1,A1,B1,dir1 = crossovers[i]
                c2,A2,B2,dir2 = crossovers[i+1]
                s1,s2 = [l.container for l in (A1,A2)]
                sep = A1.particle.get_nt_position(s1) - A2.particle.get_nt_position(s2)
                sep = np.abs(sep)

                n1,n2,n3,n4 = (B1.particle, A1.particle, A2.particle, B2.particle)

                ## <cos(q)> = exp(-s/Lp) = integrate( cos[x] exp(-A x^2), {x, 0, pi} ) / integrate( exp(-A x^2), {x, 0, pi} )
                ##   Assume A is small
                ## int[B_] :=  Normal[Integrate[ Series[Cos[x] Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]/
                ##             Integrate[Series[Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]]
                ## Actually, without assumptions I get fitFun below

                ## From http://www.annualreviews.org/doi/pdf/10.1146/annurev.bb.17.060188.001405
                ##   units "3e-19 erg cm/ 295 k K" "nm" =~ 73
                Lp = s1.twist_persistence_length/0.34  # set semi-arbitrarily as there is a large spread in literature

                fitFun = lambda x: np.real(erf( (4*np.pi*x + 1j)/(2*np.sqrt(x)) )) * np.exp(-1/(4*x)) / erf(2*np.sqrt(x)*np.pi) - np.exp(-sep/Lp)
                k = opt.leastsq( fitFun, x0=np.exp(-sep/Lp) )
                k = k[0][0] * 2*kT*0.00030461742
                        
                t0 = sep*s1.twist_per_nt # TODO weighted avg between s1 and s2
                
                # pdb.set_trace()
                if A1.on_fwd_strand: t0 -= 120
                if dir1 != dir2:
                    A2_on_fwd = not A2.on_fwd_strand
                else:
                    A2_on_fwd = A2.on_fwd_strand
                if A2_on_fwd: t0 += 120
   
                # t0 = (t0 % 360

                # if n2.idx == 0:
                #     print( n1.idx,n2.idx,n3.idx,n4.idx,k,t0,sep )
                pot = self.get_dihedral_potential(k,t0)
                self.add_dihedral( n1,n2,n3,n4, pot )
                

    @staticmethod
    def walk_through_helices(segment, direction=1, processed_segments=None):
        """ Generate (segment, direction) pairs along intrahelically connected dsDNA segments.

        Starting at `segment`, follow the "intrahelical" connection that leaves
        the segment through the end being walked towards (direction 1: contour
        address 1, direction -1: contour address 0) into the connected segment.
        Walking continues while the next segment is a DoubleStrandedSegment that
        is not in `processed_segments`.

        The starting segment is only marked as processed when it is revisited,
        so for circular helices the first and last yielded segment are the same.

        Args:
            segment: segment at which the walk starts
            direction: 1 or -1, see above
            processed_segments: optional set of segments to skip; updated in place

        Yields:
            (segment, direction) for every visited segment, where direction is
            the walking direction through that segment
        """
        assert( direction in (1,-1) )
        if processed_segments is None:
            processed_segments = set()

        def segment_is_new_helix(s):
            return isinstance(s,DoubleStrandedSegment) and s not in processed_segments

        s = segment
        first = True
        ## iterate intrahelically connected dsDNA segments
        while segment_is_new_helix(s):
            ## Mark segments one step late: leaving the start unmarked lets a
            ## circular helix walk back to it (and terminate there)
            if not first:
                processed_segments.add(s)
            first = False

            ## Leave through the end we are walking towards
            exit_address = 1*(direction == 1)
            new_s = None
            new_dir = None
            for c,A,B in s.get_contour_sorted_connections_and_locations("intrahelical")[::direction]:
                ## TODO: handle change of direction
                if A.address == exit_address and segment_is_new_helix(B.container):
                    new_s = B.container
                    assert(B.address in (0,1))
                    ## Entering at address 0 means walking forward (1), at 1 backward (-1)
                    new_dir = 2*(B.address == 0) - 1
                    break

            yield s,direction
            s = new_s   # loop ends when no next segment was found (None)
            direction = new_dir


    def get_consecutive_crossovers(self):
        """ Collect the crossovers of each helix in order along the helix.

        For every DoubleStrandedSegment not yet visited, walk through its
        intrahelically connected segments with direction -1 to find one end of
        the stack. Then walk back the other way from that end, gathering each
        segment's crossover entries from
        ``get_contour_sorted_connections_and_locations("crossover")`` in walking
        order.

        Returns:
            list with one list per helix. Each entry is ``cl + [d]``, where ``cl``
            is a (connection, location, location) crossover entry and ``d`` is
            the walking direction (1 or -1) through its segment. When the walk
            returns to its starting segment (circular helix), only that segment's
            first crossover (if any) is appended again to close the loop.
        """
        ## TODOTODO TEST
        crossovers = []
        processed_segments = set()
        for s1 in self.segments:
            if not isinstance(s1,DoubleStrandedSegment):
                continue
            if s1 in processed_segments: continue

            ## Find the segment at one end of this helix stack
            s0,d0 = list(SegmentModel.walk_through_helices(s1,direction=-1))[-1]

            tmp = []
            for i,(s,d) in enumerate(SegmentModel.walk_through_helices(s0,-d0)):
                seg_crossovers = s.get_contour_sorted_connections_and_locations("crossover")[::d]
                if i > 0 and s == s0:
                    ## End of circular helix: only add the first crossover again.
                    ## Slicing (not [0]) avoids IndexError when s0 has no crossovers.
                    seg_crossovers = seg_crossovers[:1]
                tmp.extend( [cl + [d] for cl in seg_crossovers] )
                processed_segments.add(s)
            crossovers.append(tmp)
        return crossovers

    def _generate_strands(self):
        self.strands = strands = []

        """ Ensure unconnected ends have 5prime Location objects """
        for seg in self.segments:
            ## TODO move into Segment calls
            import pdb
            if seg.start5.connection is None:
                add_end = True
                for l in seg.get_locations("5prime"):
                    if l.address == 0 and l.on_fwd_strand:
                        add_end = False
                        break
                if add_end:
                    seg.add_5prime(0) 
            if 'end5' in seg.__dict__ and seg.end5.connection is None:
                add_end = True
                for l in seg.get_locations("5prime"):
                    if l.address == 1 and (l.on_fwd_strand is False):
                        add_end = False
                        break
                if add_end:
                    seg.add_5prime(seg.num_nts-1,on_fwd_strand=False)

            if 'start3' in seg.__dict__ and seg.start3.connection is None:
                add_end = True
                for l in seg.get_locations("3prime"):
                    if l.address == 0 and (l.on_fwd_strand is False):
                        add_end = False
                        break
                if add_end:
                    seg.add_3prime(0,on_fwd_strand=False)
            if seg.end3.connection is None:
                add_end = True
                for l in seg.get_locations("3prime"):
                    if l.address == 1 and l.on_fwd_strand:
                        add_end = False
                        break
                if add_end:
                    seg.add_3prime(seg.num_nts-1)

            # print( [(l,l.get_connected_location()) for l in seg.locations] )
            # addresses = np.array([l.address for l in seg.get_locations("5prime")])
            # if not np.any( addresses == 0 ):
            #     ## check if end is connected
        # for c,l,B in self.get_connections_and_locations():
        #     if c[0]

        """ Build strands from connectivity of helices """
        def _recursively_build_strand(strand, segment, pos, is_fwd, mycounter=0, move_at_least=0.5):
            mycounter+=1
            if mycounter > 1000:
                import pdb
                pdb.set_trace()
            s,seg = [strand, segment]

            end_pos, next_seg, next_pos, next_dir, move_at_least = seg.get_strand_segment(pos, is_fwd, move_at_least)
            s.add_dna(seg, pos, end_pos, is_fwd)

            if next_seg is not None:
                # print("  next_dir: {}".format(next_dir))
                _recursively_build_strand(s, next_seg, next_pos, next_dir, mycounter, move_at_least)

        for seg in self.segments:
            locs = seg.get_5prime_locations()
            if locs is None: continue
            # for pos, is_fwd in locs:
            for l in locs:
                print("Tracing",l)
                pos = seg.contour_to_nt_pos(l.address, round_nt=True)
                is_fwd = l.on_fwd_strand
                s = Strand()
                _recursively_build_strand(s, seg, pos, is_fwd)
                # print("{} {}".format(seg.name,s.num_nts))
                strands.append(s)
        
        self.strands = sorted(strands, key=lambda s:s.num_nts)[::-1] # or something
        ## relabel segname
        counter = 0
        for s in self.strands:
            if s.segname is None:
                s.segname = "D%03d" % counter
            counter += 1

    def _update_orientations(self,orientation):
        for s in self.strands:
            s.update_atomic_orientations(orientation)


    def _generate_atomic_model(self, scale=1):
        """ Build the atomic representation of every strand.

        Sets ``self.children`` to ``self.strands`` and calls each strand's
        ``generate_atomic_model(scale)``.

        Args:
            scale: passed unchanged to each strand's ``generate_atomic_model``
        """
        self.children = self.strands
        for s in self.strands:
            s.generate_atomic_model(scale)
        return

        ## NOTE(review): everything below is unreachable because of the
        ## unconditional `return` above. It is an experimental scan that applies
        ## rotations about the z axis (180 angles in [-180, 180]) to every strand
        ## via update_atomic_orientations. It scores each angle by the mean squared
        ## O3'(last nt of a strand segment) -> P(first nt of the next strand
        ## segment) distance, then prints the best angle and all scores.
        ## If re-enabled, note that `sum2/count` divides by zero when no strand
        ## has more than one strand segment.
        ## Angle optimization
        angles = np.linspace(-180,180,180)
        score = []
        for a in angles:
            o = rotationAboutAxis([0,0,1], a)
            sum2 = count = 0
            for s in self.strands:
                s.update_atomic_orientations(o)
                for s1,s2 in zip(s.strand_segments[:-1],s.strand_segments[1:]):
                    nt1 = s1.children[-1]
                    nt2 = s2.children[0]
                    o3 = nt1.atoms_by_name["O3'"]
                    p = nt2.atoms_by_name["P"]
                    sum2 += np.sum((p.collapsedPosition()-o3.collapsedPosition())**2)
                    count += 1
            score.append(sum2/count)
        print(angles[np.argmin(score)])
        print(score)