Skip to content
Snippets Groups Projects
segmentmodel.py 32.61 KiB
import numpy as np
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme

from scipy.special import erf
import scipy.optimize as opt
from scipy import interpolate

import types

"""
TODO:
 - document
 - handle crossovers
    - connections in the middle of a segment?
    - merge beads at ends of connected helices?
 - map to atomic representation
 - remove performance bottlenecks
 - test for large systems
 - assign sequence
"""

class Location:
    """ A connection site inside a container object.

    container -- the object owning the site
    address   -- position of the site within the container (segments use
                 0 for their start and -1 for their end)
    type_     -- label for the kind of site, e.g. "end5" or "end3"
    """

    def __init__(self, container, address, type_):
        self.container, self.address, self.type_ = container, address, type_
        # Bead occupying this site; filled in once the container has beads
        self.particle = None

class Connection:
    """ Abstract base class for a connection joining two Locations.

    A and B are the joined Location objects; type_ is an optional label
    (e.g. "intrahelical" or "terminal_crossover") used to filter connections.
    """

    def __init__(self, A, B, type_ = None):
        # Both endpoints must be Location instances (A is checked first)
        for site in (A, B):
            assert( isinstance(site, Location) )
        self.A, self.B = A, B
        self.type_ = type_
        
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that can be joined by Connections """
    def __init__(self, connections=None):
        ## Use a fresh list per instance: a shared default list would make
        ## every default-constructed element see every other's connections
        self.connections = [] if connections is None else connections

    def get_connections_and_locations(self, type_=None):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other

        Only connections whose type_ equals `type_` are returned, unless
        `type_` is None. Raises Exception if a stored connection refers to
        this object at neither end. """
        ret = []
        for c in self.connections:
            if type_ is None or c.type_ == type_:
                if   c.A.container == self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container == self:
                    ret.append( [c, c.B, c.A] )
                else:
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection):
        """ Record `connection` on both this element and `other` """
        self.connections.append(connection)
        other.connections.append(connection)
       
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices.

    Subclasses are expected to set ``self.distance_per_nt`` before calling
    ``Segment.__init__`` (it is read there to place a default end position)
    and to implement ``_get_num_beads``, ``_generate_one_bead`` and
    ``_assign_particles_to_locations``. """

    ## Basic particle types
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,
                              )
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )

    # orientation_bond = HarmonicBond(10,2)
    ## Bond tying each orientation bead to its parent bead (see _generate_beads)
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,
                              )

    def __init__(self, name, num_nts,
                 start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Create a segment of `num_nts` nucleotides.

        start_position defaults to the origin. end_position defaults to
        start_position + (0, 0, distance_per_nt*num_nts). The path is
        stored as a linear spline through both ends, parameterized by a
        contour coordinate from 0 (start) to 1 (end). """

        ## Avoid sharing one default array between all instances
        if start_position is None:
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connections=[])

        self.start_orientation = None
        self.twist_per_nt = 0
        ## Fitted orientation spline; None means "derive orientation from
        ## start_orientation and twist" (see contour_to_orientation)
        self.quaternion_spline_params = None

        self.beads = list(self.children) # self.beads will not contain orientation beads

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = int(num_nts)
        if end_position is None:
            end_position = np.array((0,0,self.distance_per_nt*num_nts)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Set up linear interpolation for positions along the contour
        a = np.array([self.start_position,self.end_position]).T
        tck, u = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck


    def contour_to_position(self,s):
        """ Return the position at contour coordinate(s) `s`; for a sequence
        of contour values the result has one row per value """
        p = interpolate.splev( s, self.position_spline_params )
        if len(p) > 1: p = np.array(p).T
        return p

    def contour_to_tangent(self,s):
        """ Return d(position)/d(contour) at `s` (not normalized); for a
        sequence of contour values the result has one row per value """
        t = interpolate.splev( s, self.position_spline_params, der=1 )
        if len(t) > 1: t = np.array(t).T
        return t


    def contour_to_orientation(self,s):
        """ Return the orientation matrix at contour `s`, or None if the
        segment has no start_orientation.

        Without a fitted quaternion spline, start_orientation is rotated
        about the local tangent by s*twist_per_nt*num_nts (presumably
        degrees, as twist_per_nt is derived from 360 per turn). """
        if self.start_orientation is not None:
            if self.quaternion_spline_params is None:
                axis = self.contour_to_tangent(s)
                orientation = rotationAboutAxis( axis, s*self.twist_per_nt*self.num_nts, normalizeAxis=False )
                ## TODO: ensure this is correct
                orientation = orientation.dot( self.start_orientation )
            else:
                q = interpolate.splev(s, self.quaternion_spline_params)
                orientation = quaternion_to_matrix(q)
        else:
            orientation = None
        return orientation


    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        raise NotImplementedError

    def _generate_one_bead(self, pos, nts, orientation=None):
        raise NotImplementedError

    def _assign_particles_to_locations(self):
        raise NotImplementedError

    def get_all_consecutive_beads(self, number):
        """ Return every run of `number` consecutive beads in self.beads,
        each as a list """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        return [self.beads[i:i+number] for i in range(len(self.beads)-number+1)]

    def get_beads_before_bead(self, bead, number, inclusive=False):
        """ Return the `number` beads preceding `bead` in self.beads,
        nearest first. If `inclusive`, `bead` itself is prepended.

        Raises Exception if fewer than `number` beads precede `bead`. """
        ## Assume that consecutive beads in self.beads are bonded
        i = self.beads.index(bead)
        if i-number < 0:
            raise Exception("Not enough beads before bead")

        start = 0 if inclusive else 1
        return [self.beads[i-j] for j in range(start,number+1)]

    def get_beads_after_bead(self, bead, number, inclusive=False):
        """ Return the `number` beads following `bead` in self.beads,
        nearest first. If `inclusive`, `bead` itself is prepended.

        Raises Exception if fewer than `number` beads follow `bead`. """
        ## Assume that consecutive beads in self.beads are bonded
        i = self.beads.index(bead)
        if i+number >= len(self.beads):
            raise Exception("Not enough beads after bead")

        start = 0 if inclusive else 1
        return [self.beads[i+j] for j in range(start,number+1)]

    # def get_bead_pairs_within(self, cutoff):
    #     for b1,b2 in self.get_all_consecutive_beads(self, number)


    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):

        """ Generate this segment's beads (and orientation beads/bonds).

        Beads sit at evenly spaced contour positions including both ends;
        the two end beads carry half as many nucleotides as interior beads.
        Very short segments (num_beads <= 2) get a single bead carrying all
        nucleotides, which is currently only supported for ssDNA. """

        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        num_beads = self._get_num_beads( max_basepairs_per_bead, max_nucleotides_per_bead )
        nts_per_bead = float(self.num_nts)/num_beads

        if num_beads <= 2:
            ## not yet implemented for dsDNA
            assert( isinstance(self, SingleStrandedSegment) )
            ## NOTE(review): contour position 0.5 is recorded but the bead is
            ##               placed at start_position -- confirm intended
            b = self._generate_one_bead(0.5,
                                        self.start_position,
                                        self.num_nts,
                                        self.start_orientation)
            self.children.append(b)
            self.beads.append(b) # don't add orientation bead
            self._assign_particles_to_locations()
            return


        for i in range(num_beads+1):
            nts = nts_per_bead
            if i == 0 or i == num_beads:
                nts *= 0.5

            s = i*float(nts_per_bead)/(self.num_nts) # contour
            pos = self.contour_to_position(s)
            if self.start_orientation is not None:
                axis = self.start_orientation.dot( np.array((0,0,1)) )
                orientation = rotationAboutAxis( axis, s*self.twist_per_nt*self.num_nts, normalizeAxis=False )
                ## TODO: ensure this is correct
                orientation = orientation.dot( self.start_orientation )
            else:
                orientation = None

            b = self._generate_one_bead(s,pos,nts,orientation)
            self.children.append(b)
            self.beads.append(b) # don't add orientation bead

            if "orientation_bead" in b.__dict__:
                o = b.orientation_bead
                self.children.append(o)
                self.add_bond(b,o, Segment.orientation_bond, exclude=True)

        self._assign_particles_to_locations()

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ...

    def _generate_atomic(self, atomic_model):
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None):
        """ Create a dsDNA segment of `num_nts` base pairs.

        start_position defaults to the origin. num_turns defaults to
        num_nts/10.44 helical turns. start_orientation defaults to the
        identity matrix. When local_twist is set, each dsDNA bead gets an
        orientation bead (see _generate_one_bead). """

        ## Avoid sharing one default array between all instances
        if start_position is None:
            start_position = np.array((0,0,0))

        ## NOTE: despite its name, helical_rise holds base pairs per helical
        ##       turn (num_turns = num_nts / helical_rise below); the rise
        ##       per base pair is distance_per_nt
        self.helical_rise = 10.44
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        self.local_twist = local_twist
        if num_turns is None:
            num_turns = float(num_nts) / self.helical_rise
        ## Twist per base pair, in degrees (360 per turn)
        self.twist_per_nt = float(360 * num_turns) / num_nts

        if start_orientation is None:
            start_orientation = np.array(((1,0,0),(0,1,0),(0,0,1)))
        self.start_orientation = start_orientation

        self.nicks = []

        self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3" )

        self.end5 = Location( self, address=-1, type_= "end5" )
        self.end3 = Location( self, address=-1, type_ = "end3" )

        ## The linear position spline was already built by Segment.__init__

        ## TODO: initialize sensible spline for orientation
        self.quaternion_spline_params = None


    ## Convenience methods: join one of this segment's ends to an end of the
    ## opposite polarity. A SingleStrandedSegment argument is replaced by
    ## its matching end Location.
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )
    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.end5
        self._connect_ends( self.start3, end5, type_, force_connection = force_connection )
    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.end5
        self._connect_ends( self.end3, end5, type_, force_connection = force_connection )
    def connect_end5(self, end3, type_="intrahelical", force_connection=False):
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.end5, end3, type_, force_connection = force_connection )


    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection):
        """ Record a Connection of `type_` between end Locations `end1` and
        `end2` (which must have opposite "end3"/"end5" types) on both
        containers. force_connection is currently unused. """

        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )

        end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ) )

    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ One bead per max_basepairs_per_bead base pairs, plus one """
        return (self.num_nts // max_basepairs_per_bead) + 1

    def _generate_one_bead(self, contour_position, pos, nts, orientation=None):
        """ Create a dsDNA bead at `pos`; with local_twist, also create an
        orientation bead displaced by orientation_bond.r0 along the
        orientation's first axis and attach it as `orientation_bead` """
        if self.local_twist:
            assert(orientation is not None)
            opos = pos + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )
            o = PointParticle( Segment.orientation_particle, opos, nts,
                               num_nts=nts, parent=self )
            bead = PointParticle( Segment.dsDNA_particle, pos, nts,
                                  num_nts=nts, parent=self,
                                  orientation_bead=o,
                                  contour_position=contour_position )

        else:
            bead = PointParticle( Segment.dsDNA_particle, pos, nts,
                                  num_nts=nts, parent=self,
                                  contour_position=contour_position )
        return bead


    def _assign_particles_to_locations(self):
        """ Attach the first/last dsDNA beads to the start/end Locations """
        self.start3.particle =  self.start5.particle = self.beads[0]
        self.end3.particle   =  self.end5.particle   = self.beads[-1]

    def _generate_atomic(self, atomic_model):
        ...


    # def add_crossover(self, locationInA, B, locationInB):
    #     j = Crossover( [self, B], [locationInA, locationInB] )
    #     self._join(B,j)

    # def add_internal_crossover(self, locationInA, B, locationInB):
    #     j = Crossover( [self, B], [locationInA, locationInB] )
    #     self._join(B,j)


    # def stack_end(self, myEnd):
    #     ## Perhaps this should not really be possible; these ends should be part of same helix
    #     ...

    # def connect_strand(self, other):
    #     ...

    # def break_apart(self):
    #     """Break into smaller pieces so that "crossovers" are only at the ends"""
    #     ...

class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Create an ssDNA segment of `num_nts` nucleotides.
        start_position defaults to the origin. """

        ## Avoid sharing one default array between all instances
        if start_position is None:
            start_position = np.array((0,0,0))

        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        ## start/end are aliases of the 5'/3' end Locations
        self.start = self.end5 = Location( self, address=0, type_= "end5" )
        self.end = self.end3 = Location( self, address=-1, type_ = "end3" )

    def connect_3end(self, end5, force_connection=False):
        """ Connect this segment's 3' end to the "end5" Location `end5` """
        self._connect_end( end5,  _5_to_3 = False, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False):
        """ Connect this segment's 5' end to the "end3" Location `end3` """
        self._connect_end( end3,  _5_to_3 = True, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Record an "intrahelical" Connection between one of this segment's
        ends and the Location `other` (which must have the opposite end
        type). force_connection is currently unused. """
        assert( isinstance(other, Location) )
        if _5_to_3:
            my_end = self.end5
            assert( other.type_ == "end3" )
        else:
            my_end = self.end3
            assert( other.type_ == "end5" )

        self._connect( other.container, Connection( my_end, other, type_="intrahelical" ) )

    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ One bead per max_nucleotides_per_bead nucleotides, plus one """
        return (self.num_nts // max_nucleotides_per_bead) + 1

    def _generate_one_bead(self, contour_position, pos, nts, orientation=None):
        """ Create an ssDNA bead at `pos`; `orientation` is ignored """
        return PointParticle( Segment.ssDNA_particle, pos, nts,
                              num_nts=nts, parent=self,
                              contour_position=contour_position )

    def _assign_particles_to_locations(self):
        """ Attach the first/last beads to the start/end Locations """
        ## ssDNA beads have no orientation beads, so self.beads matches
        ## self.children here; use beads for consistency with dsDNA
        self.start.particle = self.beads[0]
        self.end.particle = self.beads[-1]

    def _generate_atomic(self, atomic_model):
        ...
    

class SegmentModel(ArbdModel):
    """ ARBD bead model assembled from a list of Segment objects.

    Construction generates every segment's beads, then adds intrahelical
    bonds, angles and exclusions, optional twist potentials (local_twist)
    and bonds across "terminal_crossover" connections. """

    def __init__(self, segments=None, local_twist=True,
                 max_basepairs_per_bead=7,
                 max_nucleotides_per_bead=4,
                 dimensions=(1000,1000,1000), temperature=291,
                 timestep=50e-6, cutoff=50,
                 decompPeriod=10000, pairlistDistance=None,
                 nonbondedResolution=0,DEBUG=0):
        """ Build the model. segments defaults to an empty list;
        pairlistDistance and nonbondedResolution are forwarded to ArbdModel. """
        ## Avoid sharing one default list between models (it becomes self.children)
        if segments is None:
            segments = []
        self.DEBUG = DEBUG
        if DEBUG > 0: print("Building ARBD Model")
        ArbdModel.__init__(self,segments,
                           dimensions, temperature, timestep, cutoff,
                           decompPeriod, pairlistDistance=pairlistDistance,
                           nonbondedResolution=nonbondedResolution)

        # self.max_basepairs_per_bead = max_basepairs_per_bead     # dsDNA
        # self.max_nucleotides_per_bead = max_nucleotides_per_bead # ssDNA
        self.children = self.segments = segments

        self._bonded_potential = dict() # cache bonded potentials

        self._generate_bead_model( max_basepairs_per_bead, max_nucleotides_per_bead, local_twist)

        self.useNonbondedScheme( nbDnaScheme )


    def get_connections(self,type_=None):
        """ Find all connections in model, without double-counting """
        added=set()
        ret=[]
        for s in self.segments:
            items = [e for e in s.get_connections_and_locations(type_) if e[0] not in added]
            added.update([e[0] for e in items])
            ret.extend( items )
        return ret

    def _get_intrahelical_beads(self):
        """ Return [b1,b2] pairs of bonded neighbors: consecutive beads
        within each segment plus the particles on both sides of each
        "intrahelical" connection """
        ret = []
        for s in self.segments:
            ret.extend( s.get_all_consecutive_beads(2) )

        for c,A,B in self.get_connections("intrahelical"):
            # TODO: check that b1,b2 not same
            b1,b2 = [l.particle for l in (A,B)]
            for b in (b1,b2): assert( b is not None )
            ret.append( [b1,b2] )
        return ret

    @staticmethod
    def _get_single_neighbor(segment, bead, before):
        """ Return the bead adjacent to `bead` in `segment` (the preceding
        one if `before`, else the following one), or None if there is none """
        getter = segment.get_beads_before_bead if before else segment.get_beads_after_bead
        try:
            neighbors = getter(bead, 1)
        except Exception:
            ## bead is at the segment boundary (or not in segment.beads)
            return None
        if len(neighbors) != 1 or neighbors[0] is None:
            return None
        return neighbors[0]

    def _get_intrahelical_angle_beads(self):
        """ Return [b1,b2,b3] triplets of consecutive beads: runs of three
        within each segment plus triplets spanning each "intrahelical"
        connection """
        ret = []
        for s in self.segments:
            ret.extend( s.get_all_consecutive_beads(3) )

        for c,A,B in self.get_connections("intrahelical"):
            s1,s2 = [loc.container for loc in (A,B)]
            b1,b2 = [loc.particle  for loc in (A,B)]
            for b in (b1,b2): assert( b is not None )

            ## Neighbors of b1 in s1 form triplets ending at b2 ...
            b0 = self._get_single_neighbor(s1, b1, before=True)
            if b0 is not None: ret.append( [b0,b1,b2] )
            b0 = self._get_single_neighbor(s1, b1, before=False)
            if b0 is not None: ret.append( [b2,b1,b0] )
            ## ... and neighbors of b2 in s2 form triplets starting at b1
            b3 = self._get_single_neighbor(s2, b2, before=True)
            if b3 is not None: ret.append( [b3,b2,b1] )
            b3 = self._get_single_neighbor(s2, b2, before=False)
            if b3 is not None: ret.append( [b1,b2,b3] )
        return ret

    # def _get_intrahelical_bead_pairs_within(self, cutoff):

    #     dist = dict()
    #     for b1,b2 in self._get_intrahelical_beads:
    #         dist(b1,b2)


    #     ret = []
    #     for s in self.segments:
    #         ret.extend( s.get_bead_pairs_within(cutoff) )

    #     for s1 in self.segments:
    #         for c in s1.connections:
    #             if c.A.container != s1: continue
    #             s2 = c.B.container
    #             if c.type_ == "intrahelical":

    #     ret

    def _get_potential(self, type_, kSpring, d, max_potential = None):
        """ Return a cached bonded potential of `type_` ("bond", "angle" or
        "dihedral"), creating it on first request """
        ## max_potential is part of the key so potentials that differ only
        ## in their cap are not confused
        key = (type_,kSpring,d,max_potential)
        if key not in self._bonded_potential:
            if type_ == "bond":
                self._bonded_potential[key] = HarmonicBond(kSpring,d, rRange=(0,500), max_potential=max_potential)
            elif type_ == "angle":
                self._bonded_potential[key] = HarmonicAngle(kSpring,d, max_potential=max_potential)
                # , resolution = 1, maxForce=0.1)
            elif type_ == "dihedral":
                self._bonded_potential[key] = HarmonicDihedral(kSpring,d, max_potential=max_potential)
            else:
                raise Exception("Unhandled potential type '%s'" % type_)
        return self._bonded_potential[key]
    def get_bond_potential(self, kSpring, d, max_potential=None):
        return self._get_potential("bond", kSpring, d, max_potential)
    def get_angle_potential(self, kSpring, d, max_potential=None):
        return self._get_potential("angle", kSpring, d, max_potential)
    def get_dihedral_potential(self, kSpring, d, max_potential=None):
        ## Wrap the rest angle into [-180, 180]
        while d > 180: d-=360
        while d < -180: d+=360
        return self._get_potential("dihedral", kSpring, d, max_potential)


    def _getParent(self, *beads ):
        """ Return the parent shared by all `beads`, or the model itself if
        they do not all have the same parent """
        first = beads[0].parent
        if all(b.parent == first for b in beads[1:]):
            return first
        return self

    def _get_twist_spring_constant(self, sep):
        """ Return a dihedral spring constant for beads `sep` nucleotides
        apart, chosen so that <cos(twist)> = exp(-sep/Lp). The factor
        0.00030461742 = (pi/180)^2 converts per-radian^2 to per-degree^2. """
        kT = 0.58622522         # kcal/mol
        twist_persistence_length = 90  # set semi-arbitrarily as there is a large spread in literature
        ## <cos(q)> = exp(-s/Lp) = integrate( cos[x] exp(-A x^2), {x, 0, pi} ) / integrate( exp(-A x^2), {x, 0, pi} )
        ##   Assume A is small
        ## int[B_] :=  Normal[Integrate[ Series[Cos[x] Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]/
        ##             Integrate[Series[Exp[-B x^2], {B, 0, 1}], {x, 0, \[Pi]}]]

        ## Actually, without assumptions I get fitFun below
        ## From http://www.annualreviews.org/doi/pdf/10.1146/annurev.bb.17.060188.001405
        ##   units "3e-19 erg cm/ 295 k K" "nm" =~ 73
        Lp = twist_persistence_length/0.34

        fitFun = lambda x: np.real(erf( (4*np.pi*x + 1j)/(2*np.sqrt(x)) )) * np.exp(-1/(4*x)) / erf(2*np.sqrt(x)*np.pi) - np.exp(-sep/Lp)
        k = opt.leastsq( fitFun, x0=np.exp(-sep/Lp) )
        return k[0][0] * 2*kT*0.00030461742

    # def _update_segment_positions(self, bead_coordinates):
    #     """ Set new function for each segments functions
    #     contour_to_position and contour_to_orientation """

    #     dsDnaHelixNeighborDist=50
    #     dsDnaAllNeighborDist=30
    #     ssDnaHelixNeighborDist=25
    #     ssDnaAllNeighborDist=25

    #     beads = [b in s.beads for s in self.segments]
    #     positions = np.array([b.position for b in beads])
    #     neighborhood = dict()

    #     ## Assign neighborhood to each bead
    #     for b in beads:
    #         dists = b.position[np.newaxis,:] - positions
    #         dists = np.linalg.norm(dists, axis=-1)
    #         neighborhood[b] = np.where( dists < 50 )

    ## Mapping between different resolution models
    def _clear_beads(self):
        """ Remove all beads from every segment and from the model """
        for s in self.segments:
            s.clear_all()
            s.beads = []
        self.clear_all(keep_children=True)

    def _update_segment_positions(self, bead_coordinates):
        """ Refit each segment's contour_to_position spline (and, when its
        beads have orientation beads, its orientation spline).

        bead_coordinates is indexed by bead.idx along its first axis;
        presumably shape (num_particles, 3) -- TODO confirm with caller. """

        for s in self.segments:
            beads = list(s.beads)
            if len(beads) < 2:
                ## NOTE(review): a spline cannot be fit through fewer than two
                ##               beads; the existing spline is kept
                continue
            ids = [b.idx for b in beads]
            ## splprep requires more points than the spline degree
            k = min(3, len(beads)-1)

            ## Positions
            positions = bead_coordinates[ids,:].T
            contours = [b.contour_position for b in beads]
            tck, u = interpolate.splprep( positions, u=contours, s=0, k=k )

            s.position_spline_params = tck

            ## Twist
            if 'orientation_bead' in beads[0].__dict__:
                tangents = s.contour_to_tangent(contours)
                quats = []
                for b,t in zip(beads,tangents):
                    ## contour_to_tangent is not normalized; the projection
                    ## below requires a unit tangent
                    t = t / np.linalg.norm(t)
                    o = b.orientation_bead
                    ## NOTE(review): uses stored b.position/o.position rather
                    ##               than bead_coordinates -- confirm intended
                    angleVec = o.position - b.position
                    angleVec = angleVec - angleVec.dot(t)*t
                    angleVec = angleVec/np.linalg.norm(angleVec)
                    y = np.cross(t,angleVec)
                    ## NOTE(review): rows (t, t x a, a) form a left-handed
                    ##               frame (det -1); check against the
                    ##               convention quaternion_from_matrix expects
                    quats.append( quaternion_from_matrix( np.array([t,y,angleVec])) )
                quats = np.array(quats)
                tck, u = interpolate.splprep( quats.T, u=contours, s=0, k=k )
                s.quaternion_spline_params = tck


            ## TODO: set twist

    def _generate_bead_model(self,
                             max_basepairs_per_bead = 7,
                             max_nucleotides_per_bead = 4,
                             local_twist=False):
        """ Generate beads for every segment, then add bonded potentials and
        exclusions between them and refresh the particle order. """

        segments = self.segments

        ## Generate beads
        if self.DEBUG: print("Generating beads")
        for s in segments:
            if local_twist:
                s.local_twist = True
            s._generate_beads( self, max_basepairs_per_bead, max_nucleotides_per_bead )

        ## Combine beads at junctions as needed
        ## TODO: not implemented; the loop currently only validates that each
        ##       connection refers back to its segments
        for c,A,B in self.get_connections():
            ...

        self._assign_bead_types()
        dists = self._add_intrahelical_bond_potentials()
        self._add_intrahelical_angle_potentials(local_twist)
        self._add_intrahelical_exclusions(dists)
        if local_twist:
            self._add_twist_potentials(dists)
        self._add_connection_potentials()

        self._updateParticleOrder()

    def _assign_bead_types(self):
        """ Give each bead a shared ParticleType per (kind, num_nts) pair,
        named e.g. "D070" from 10x the type's nucleotide count """
        if self.DEBUG: print("Assigning bead types")
        beadtype_s = dict()
        for segment in self.segments:
            for b in segment:
                b.num_nts = np.around( b.num_nts, decimals=1 )
                key = (b.type_.name[0].upper(), b.num_nts)
                if key in beadtype_s:
                    b.type_ = beadtype_s[key]
                    continue
                t = deepcopy(b.type_)
                if key[0] == "D":
                    ## dsDNA beads count both strands
                    t.__dict__["nts"] = b.num_nts*2
                elif key[0] in ("S", "O"):
                    t.__dict__["nts"] = b.num_nts
                else:
                    raise Exception("TODO")
                ## Round rather than truncate, so float error in 10*nts
                ## (e.g. 69.99999...) cannot shift the type name
                t.name = t.name + "%03d" % int(round(10*t.nts))
                beadtype_s[key] = b.type_ = t

    def _add_intrahelical_bond_potentials(self):
        """ Add harmonic bonds between intrahelical neighbors.

        Returns a dict mapping each bonded bead to a list of [neighbor, sep]
        entries, sep being the separation in nucleotides. """
        if self.DEBUG: print("Adding intrahelical bond potentials")
        dists = dict()          # built for later use
        intra_beads = self._get_intrahelical_beads()
        if self.DEBUG: print("  Adding %d bonds" % len(intra_beads))
        conversion = 0.014393265 # units "pN/AA" kcal_mol/AA^2
        for b1,b2 in intra_beads:
            parent = self._getParent(b1,b2)
            if b1.parent == b2.parent:
                sep = 0.5*(b1.num_nts+b2.num_nts)
            else:
                sep = 1

            if b1.type_.name[0] == "D" and b2.type_.name[0] == "D":
                elastic_modulus = 1000 # pN http://markolab.bmbcb.northwestern.edu/marko/Cocco.CRP.02.pdf
                d = 3.4*sep
            else:
                ## TODO: get better numbers our ssDNA model
                elastic_modulus = 800 # pN http://markolab.bmbcb.northwestern.edu/marko/Cocco.CRP.02.pdf
                d = 5*sep
            k = conversion*elastic_modulus/d

            dists.setdefault(b1, []).append([b2,sep])
            dists.setdefault(b2, []).append([b1,sep])

            bond = self.get_bond_potential(k,d)
            parent.add_bond( b1, b2, bond, exclude=True )
        return dists

    def _add_intrahelical_angle_potentials(self, local_twist):
        """ Add a 180-degree harmonic angle to each intrahelical bead triplet """
        if self.DEBUG: print("Adding intrahelical angle potentials")
        kT = 0.58622522         # kcal/mol
        for b1,b2,b3 in self._get_intrahelical_angle_beads():

            sep = 0
            for x,y in ((b1,b2),(b2,b3)):
                if x.parent == y.parent:
                    sep += 0.5*(x.num_nts+y.num_nts)
                else:
                    sep += 1

            parent = self._getParent(b1,b2,b3)

            if b1.type_.name[0] == "D" and b2.type_.name[0] == "D" and b3.type_.name[0] == "D":
                ## <cos(q)> = exp(-s/Lp) = integrate( x^4 exp(-A x^2) / 2, {x, 0, pi} ) / integrate( x^2 exp(-A x^2), {x, 0, pi} )
                ## <cos(q)> ~ 1 - 3/4A
                ## where A = k_spring / (2 kT)
                k = 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/147))) * 0.00030461742 # kcal_mol/degree^2
                if local_twist:
                    k *= 0.5    # halve because orientation beads have similar springs
            else:
                ## TODO: get correct number from ssDNA model
                k = 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/3))) * 0.00030461742 # kcal_mol/degree^2

            angle = self.get_angle_potential(k,180)
            parent.add_angle( b1, b2, b3, angle )

    def _add_intrahelical_exclusions(self, dists, cutoff=20):
        """ Exclude nonbonded interactions between beads reachable through
        the intrahelical graph `dists` within `cutoff` nucleotides, with
        separations accumulated along the walk. """
        if self.DEBUG: print("Adding intrahelical exclusions")

        def _recursively_get_beads_within(b1, d, done):
            ## Depth-first walk; `done` is shared across the walk so each bead
            ## is reported at most once
            ret = []
            for b2,sep in dists[b1]:
                if b2 in done: continue
                if sep < d:
                    ret.append( b2 )
                    done.add( b2 )
                    ret.extend( _recursively_get_beads_within(b2, d-sep, done) )
            return ret

        exclusions = set()
        for b1 in dists:
            exclusions.update( [(b1,b) for b in _recursively_get_beads_within(b1, cutoff, {b1})] )

        if self.DEBUG: print("Adding %d exclusions" % len(exclusions))
        for b1,b2 in exclusions:
            parent = self._getParent(b1,b2)
            parent.add_exclusion( b1, b2 )

    def _add_twist_potentials(self, dists):
        """ For each bonded pair of beads carrying orientation beads, add
        90-degree angles keeping the orientation beads orthogonal and a
        dihedral restraining the twist between them """
        if self.DEBUG: print("Adding twist potentials")
        kT = 0.58622522         # kcal/mol

        ## TODO: decide whether to add bond here
        # """Add bonds between orientation bead and parent"""
        # for s in self.segments:
        #     for b,o in zip(s.children[::2],s.children[1::2]):
        #         s.add_bond(

        ## NOTE: dists is symmetric, so every pair is visited in both
        ##       directions and these potentials are added once per direction
        for b1 in dists:
            if "orientation_bead" not in b1.__dict__: continue
            for b2,sep in dists[b1]:
                if "orientation_bead" not in b2.__dict__: continue

                p1,p2 = [b.parent for b in (b1,b2)]
                o1,o2 = [b.orientation_bead for b in (b1,b2)]

                parent = self._getParent( b1, b2 )

                ## Add heuristic 90 degree potential to keep orientation bead orthogonal
                k = (1.0/2) * 1.5 * kT * (1.0 / (1-np.exp(-float(sep)/147))) * 0.00030461742 # kcal_mol/degree^2
                pot = self.get_angle_potential(k,90)
                parent.add_angle(o1,b1,b2, pot)
                parent.add_angle(b1,b2,o2, pot)

                ## TODO: improve this
                twist_per_nt = 0.5 * (p1.twist_per_nt + p2.twist_per_nt)
                angle = sep*twist_per_nt
                if angle > 360 or angle < -360:
                    raise Exception("The twist between beads is too large")

                k = self._get_twist_spring_constant(sep)
                pot = self.get_dihedral_potential(k,angle,max_potential=1)
                parent.add_dihedral(o1,b1,b2,o2, pot)

    def _add_connection_potentials(self):
        """ Bond the particles on both sides of each "terminal_crossover" connection """
        for c,A,B in self.get_connections("terminal_crossover"):
            b1,b2 = [loc.particle for loc in (c.A,c.B)]
            pot = self.get_bond_potential(4,18.5)
            self.add_bond(b1,b2, pot)


    # def get_bead(self, location):
    #     if type(location.container) is not list:
    #         s = self.segments.index(location.container)
    #         s.get_bead(location.address)
    #     else:
    #         r
    #         ...