segmentmodel.py 116 KB
Newer Older
cmaffeo2's avatar
cmaffeo2 committed
1
import pdb
2
from pathlib import Path
3
import numpy as np
4
import random
5
6
7
from .model.arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from .coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
from .model.nonbonded import *
8
from copy import copy, deepcopy
9
from .model.nbPot import nbDnaScheme
10

cmaffeo2's avatar
cmaffeo2 committed
11
12
from scipy.special import erf
import scipy.optimize as opt
13
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
14

15
16
from .model.CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from .model.CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
cmaffeo2's avatar
cmaffeo2 committed
17

18
# import pdb
19
"""
cmaffeo2's avatar
cmaffeo2 committed
20
TODO:
cmaffeo2's avatar
cmaffeo2 committed
21
 + fix handling of crossovers for atomic representation
cmaffeo2's avatar
cmaffeo2 committed
22
 + map to atomic representation
23
    + add nicks
cmaffeo2's avatar
cmaffeo2 committed
24
    + transform ssDNA nucleotides 
cmaffeo2's avatar
cmaffeo2 committed
25
    - shrink ssDNA
cmaffeo2's avatar
cmaffeo2 committed
26
    + shrink dsDNA backbone
27
    + make orientation continuous
cmaffeo2's avatar
cmaffeo2 committed
28
    + sequence
29
    + handle circular dna
30
 + ensure crossover bead potentials aren't applied twice 
31
 + remove performance bottlenecks
32
 - test for large systems
cmaffeo2's avatar
cmaffeo2 committed
33
 + assign sequence
34
 + ENM
35
36
 - rework Location class 
 - remove recursive calls
37
 - document
38
 - develop unit test suite
39
"""
cmaffeo2's avatar
cmaffeo2 committed
40
41
class CircularDnaError(Exception):
    """ Exception type declared for errors involving circular DNA """
42

cmaffeo2's avatar
cmaffeo2 committed
43
44
45
class ParticleNotConnectedError(Exception):
    """ Exception type declared for particles that cannot be reached through connections """

46
47
class Location():
    """ Site for connection within an object

    Attributes set by the constructor:
        container:     object that owns this location
        address:       position along the contour length of the container
        type_:         label of the location (the visible code uses
                       "3prime", "5prime" and "crossover")
        on_fwd_strand: whether the site sits on the forward strand
    The remaining attributes (particle, connection,
    is_3prime_side_of_connection, prev_in_strand, next_in_strand, combine)
    start as None and are filled in by other code.
    """
    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segment
        # assert( type_ in ("end3","end5") ) # TODO remove or make conditional
        self.on_fwd_strand = on_fwd_strand
        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Returns the location at the other end of our connection, or None """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Stores the connection and which side of it this location is on """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def get_nt_pos(self):
        """ Returns the nucleotide index of this location in its container

        If the container cannot convert the address to a whole nucleotide,
        addresses of exactly 0 and 1 fall back to the first (0) and last
        (num_nt-1) nucleotide; for any other address the error propagates.
        """
        try:
            pos = self.container.contour_to_nt_pos(self.address, round_nt=True)
        except Exception:
            ## Only ordinary errors are handled here so that KeyboardInterrupt
            ## and SystemExit are never swallowed by the fallback
            if self.address == 0:
                pos = 0
            elif self.address == 1:
                pos = self.container.num_nt-1
            else:
                raise
        return pos

    def __repr__(self):
        ## on_fwd_strand is printed with {:d}, i.e. as 1 (fwd) or 0 (rev)
        return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, self.on_fwd_strand)
        
93
94
95
96
97
98
99
100
101
class Connection():
    """ Abstract base class for connection between two elements

    A and B are the two Location endpoints; type_ is a free-form label
    (the visible code uses "intrahelical", "crossover" and "sscrossover").
    """
    def __init__(self, A, B, type_ = None):
        assert( isinstance(A,Location) )
        assert( isinstance(B,Location) )
        self.A = A
        self.B = B
        self.type_ = type_

    def other(self, location):
        """ Returns the endpoint opposite to location (compared by identity)

        Raises ValueError (an Exception subclass, so existing handlers
        still match) if location is neither endpoint.
        """
        if location is self.A:
            return self.B
        elif location is self.B:
            return self.A
        else:
            raise ValueError("OutOfBoundsError")

    def delete(self):
        """ Unregisters the connection from both containers and both endpoints """
        ## One removal per endpoint container; this mirrors
        ## ConnectableElement._connect, which appends once per side
        self.A.container.connections.remove(self)
        self.B.container.connections.remove(self)
        self.A.connection = None
        self.B.connection = None

    def __repr__(self):
        return "<Connection {}--{}--{}>".format( self.A, self.type_, self.B )
        

120
        
121
122
123
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class

    Holds a list of Location objects (available under both the names
    `locations` and `connection_locations`, which refer to the same list)
    and a list of Connection objects involving this element.
    """
    def __init__(self, connection_locations=None, connections=None):
        if connection_locations is None: connection_locations = []
        if connections is None: connections = []

        ## TODO decide on names
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Returns locations, optionally restricted to one type_ and
        omitting any whose type_ is listed in exclude """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        ## A location must not be registered more than once
        assert( len(set(locs)) == len(locs) )
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
        """ Returns the location at address on the requested strand

        If no matching location exists, a new Location of type new_type is
        created and returned; the new object is NOT appended to
        self.locations here.  Reads self.num_nt, which this base class does
        not define.
        """
        loc = None
        if (self.num_nt == 1):
            ## Assumes that intrahelical connections have been made before crossovers
            ## For a single-nt element the address is ignored and the only
            ## unconnected location on the requested strand is used
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
            # assert( loc is not None )
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection, in_3prime_direction=None):
        """ Registers connection with self and other and with both endpoints

        When in_3prime_direction is given, connection.B is flagged with it
        and connection.A with its negation.
        """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]
        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction

        A.connection = B.connection = connection
        self.connections.append(connection)
        other.connections.append(connection)

        ## Make sure each endpoint is listed in its own container
        l = A.container.locations
        if A not in l: l.append(A)
        l = B.container.locations
        if B not in l: l.append(B)
        

196
197
    # def _find_connections(self, loc):
    #     return [c for c in self.connections if c.A == loc or c.B == loc]
198
199

class SegmentParticle(PointParticle):
    """ Bead associated with a contour position of a segment

    The methods below read self.parent, which is not assigned in this
    class (presumably set when the bead is added to a segment -- TODO
    confirm).
    """
    def __init__(self, type_, position, name="A", **kwargs):
        self.name = name
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, **kwargs)
        self.intrahelical_neighbors = []
        self.other_neighbors = []
        self.locations = []

    def get_intrahelical_above(self):
        """ Returns bead directly above self """
        # assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) > self.contour_position:
                return b
        return None

    def get_intrahelical_below(self):
        """ Returns bead directly below self """
        # assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) < self.contour_position:
                return b
        return None

    def _neighbor_should_be_added(self,b):
        """ Decides whether b should become an intrahelical neighbor of self

        If self already has a neighbor on the same side as b that is farther
        away than b, that neighbor is unlinked (in both directions) and True
        is returned; if the existing neighbor is at least as close, False
        is returned.
        """
        if type(self.parent) != type(b.parent):
            return True

        c1 = self.contour_position
        c2 = b.get_contour_position(self.parent)
        if c2 < c1:
            b0 = self.get_intrahelical_below()
        else:
            b0 = self.get_intrahelical_above()

        if b0 is not None:
            c0 = b0.get_contour_position(self.parent)
            if np.abs(c2-c1) < np.abs(c0-c1):
                ## remove b0
                self.intrahelical_neighbors.remove(b0)
                b0.intrahelical_neighbors.remove(self)
                return True
            else:
                return False
        return True

    def make_intrahelical_neighbor(self,b):
        """ Links self and b as intrahelical neighbors if both sides agree """
        ## Both checks are always evaluated because each may unlink an
        ## existing, more distant neighbor as a side effect
        add1 = self._neighbor_should_be_added(b)
        add2 = b._neighbor_should_be_added(self)
        if add1 and add2:
            # assert(len(b.intrahelical_neighbors) <= 1)
            # assert(len(self.intrahelical_neighbors) <= 1)
            self.intrahelical_neighbors.append(b)
            b.intrahelical_neighbors.append(self)

    def get_nt_position(self, seg):
        """ Returns the "address" of the nucleotide relative to seg in
        nucleotides, taking the shortest (intrahelical) contour length route to seg
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)
        else:
            pos = self.get_contour_position(seg)
            return seg.contour_to_nt_pos(pos)

    def get_contour_position(self,seg):
        """ Returns the contour position of this bead expressed relative to seg

        For the parent segment this is simply self.contour_position.
        Otherwise connected segments are searched depth-first, starting from
        the parent, for the route with the smallest accumulated distance (in
        nucleotides); the returned value is extrapolated past the end of seg
        where the route enters, so it may lie outside [0,1].

        Raises ParticleNotConnectedError (an Exception subclass) if seg
        cannot be reached.
        """
        if seg == self.parent:
            return self.contour_position
        else:
            cutoff = 30*3       # routes longer than this many nt are abandoned
            target_seg = seg

            ## depth-first search
            ## TODO cache distances to nearby locations?
            def descend_search_tree(seg, contour_in_seg, distance=0, visited_segs=None):
                nonlocal cutoff
                if visited_segs is None: visited_segs = []

                if seg == target_seg:
                    ## Found a segment in our target
                    ## The target must be entered at one of its ends
                    sign = (contour_in_seg == 1) - (contour_in_seg == 0)
                    assert( sign in (1,-1) )
                    if distance < cutoff: # TODO: check if this does anything
                        cutoff = distance
                    return [[distance, contour_in_seg+sign*seg.nt_pos_to_contour(distance)]], [(seg, contour_in_seg, distance)]
                if distance > cutoff:
                    return None,None

                ret_list = []
                hist_list = []
                ## Find intrahelical locations in seg that we might pass through
                conn_locs = seg.get_connections_and_locations("intrahelical")
                if isinstance(target_seg, SingleStrandedSegment):
                    tmp = seg.get_connections_and_locations("sscrossover")
                    conn_locs = conn_locs + list(filter(lambda x: x[2].container == target_seg, tmp))
                for c,A,B in conn_locs:
                    if B.container in visited_segs: continue
                    dx = seg.contour_to_nt_pos( A.address, round_nt=False ) - seg.contour_to_nt_pos( contour_in_seg, round_nt=False)
                    dx = np.abs(dx)
                    results,history = descend_search_tree( B.container, B.address,
                                                   distance+dx, visited_segs + [seg] )
                    if results is not None:
                        ret_list.extend( results )
                        hist_list.extend( history )
                return ret_list,hist_list

            results,history = descend_search_tree(self.parent, self.contour_position)
            if not results:
                raise ParticleNotConnectedError("Could not find location in segment")
            ## Each result is [distance, contour]; take the closest route
            return min(results, key=lambda x:x[0])[1]

            # nt_pos = self.get_nt_position(seg)
            # return seg.nt_pos_to_contour(nt_pos)

    def update_position(self, contour_position):
        """ Moves the bead (and its orientation bead, if it has one) to
        contour_position on the parent segment """
        self.contour_position = contour_position
        self.position = self.parent.contour_to_position(contour_position)
        if 'orientation_bead' in self.__dict__:
            o = self.orientation_bead
            o.contour_position = contour_position
            orientation = self.parent.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            ## Offset the orientation bead by the bond rest length along the local x axis
            o.position = self.position + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )

    def __repr__(self):
        ## contour_position stays None until it is assigned, and None
        ## cannot be formatted with {:.2f}
        if self.contour_position is None:
            return "<SegmentParticle {} on {}[None]>".format( self.name, self.parent )
        return "<SegmentParticle {} on {}[{:.2f}]>".format( self.name, self.parent, self.contour_position)

328
329

## TODO break this class into smaller, better encapsulated pieces
330
331
332
333
334
335
336
337
338
339
340
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    """Define basic particle types"""
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,                 
                              )
cmaffeo2's avatar
cmaffeo2 committed
341
342
343
344
345
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )
346

cmaffeo2's avatar
cmaffeo2 committed
347
    # orientation_bond = HarmonicBond(10,2)
348
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )
349
350
351
352
353
354
355

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,                 
                              )

356
    def __init__(self, name, num_nt, 
                 start_position = None,
                 end_position = None, 
                 segment_model = None,
                 **kwargs):
        """ Sets up an empty segment of num_nt nucleotides

        start_position defaults to the origin.  When end_position is not
        given, the end is placed along +z from the start at a distance of
        self.distance_per_nt*num_nt (distance_per_nt is not defined in this
        method; presumably supplied by subclasses -- TODO confirm).
        Remaining keyword arguments are forwarded to Group.__init__.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[], **kwargs)
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        if 'segname' not in kwargs:
            self.segname = name
        # self.resname = name

        self.start_orientation = None
        self.twist_per_nt = 0

        self.beads = list(self.children) # self.beads will not contain orientation beads

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        ## Registered strand pieces, one list per strand direction
        self.strand_pieces = {'fwd': [], 'rev': []}

        self.num_nt = int(num_nt)
        if end_position is None:
            end_position = np.array((0,0,self.distance_per_nt*num_nt)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Used to assign cadnano names to beads
        self._generate_bead_callbacks = []
        self._generate_nucleotide_callbacks = []

        ## Set up interpolation for positions
        self._set_splines_from_ends()

        self.sequence = None

397
398
399
    def __repr__(self):
        return "<{} {}[{:d}]>".format( type(self), self.name, self.num_nt )

cmaffeo2's avatar
cmaffeo2 committed
400
    def set_splines(self, contours, coords):
401
        tck, u = interpolate.splprep( coords.T, u=contours, s=0, k=1)
402
        self.position_spline_params = (tck,u)
403

404
405
    def set_orientation_splines(self, contours, quaternions):
        tck, u = interpolate.splprep( quaternions.T, u=contours, s=0, k=1)
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
        self.quaternion_spline_params = (tck,u)

    def get_center(self):
        tck, u = self.position_spline_params
        return np.mean(self.contour_to_position(u), axis=0)

    def __filter_contours(contours, positions, position_filter, contour_filter):
        u = contours
        r = positions

        ## Filter
        ids = list(range(len(u)))
        if contour_filter is not None:
            ids = list(filter(lambda i: contour_filter(u[i]), ids))
        if position_filter is not None:
            ids = list(filter(lambda i: position_filter(r[i,:]), ids))
        return ids

424
    def translate(self, translation_vector, position_filter=None, contour_filter=None):
425
426
427
428
429
430
431
432
        dr = np.array(translation_vector)
        tck, u = self.position_spline_params
        r = self.contour_to_position(u)

        ids = Segment.__filter_contours(u, r, position_filter, contour_filter)

        ## Translate
        r[ids,:] = r[ids,:] + dr[np.newaxis,:]
433
        self.set_splines(u,r)
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448

    def rotate(self, rotation_matrix, about=None, position_filter=None, contour_filter=None):
        tck, u = self.position_spline_params
        r = self.contour_to_position(u)

        ids = Segment.__filter_contours(u, r, position_filter, contour_filter)

        if about is None:
            ## TODO: do this more efficiently
            r[ids,:] = np.array([rotation_matrix.dot(r[i,:]) for i in range(r.shape[0])])
        else:
            dr = np.array(about)
            ## TODO: do this more efficiently
            r[ids,:] = np.array([rotation_matrix.dot(r[i,:]-dr) + dr for  i in range(r.shape[0])])

449
        self.set_splines(u,r)
450
451
452
453

        if self.quaternion_spline_params is not None:
            ## TODO: performance: don't shift between quaternion and matrix representations so much
            tck, u = self.quaternion_spline_params
454
            orientations = [self.contour_to_orientation(v) for v in u]
455
456
457
            for i in ids:
                orientations[i,:] = rotation_matrix.dot(orientations[i])
            quats = [quaternion_from_matrix(o) for o in orientations]
458
            self.set_orientation_splines(u, quats)
459

460
    def _set_splines_from_ends(self):
461
        self.quaternion_spline_params = None
cmaffeo2's avatar
cmaffeo2 committed
462
463
        coords = np.array([self.start_position, self.end_position])
        self.set_splines([0,1], coords)
464

465
466
467
468
469
    def clear_all(self):
        """ Clears the group, empties self.beads and detaches particles
        from every location of this segment that takes part in a connection """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for _conn, own_loc, _other_loc in self.get_connections_and_locations():
            own_loc.particle = None
470

471
    def contour_to_nt_pos(self, contour_pos, round_nt=False):
472
        nt = contour_pos*(self.num_nt) - 0.5
473
        if round_nt:
cmaffeo2's avatar
cmaffeo2 committed
474
            assert( np.isclose(np.around(nt),nt) )
475
476
477
            nt = np.around(nt)
        return nt

478
    def nt_pos_to_contour(self,nt_pos):
479
        return (nt_pos+0.5)/(self.num_nt)
480

481
    def contour_to_position(self,s):
482
        p = interpolate.splev( s, self.position_spline_params[0] )
483
484
485
486
        if len(p) > 1: p = np.array(p).T
        return p

    def contour_to_tangent(self,s):
487
        t = interpolate.splev( s, self.position_spline_params[0], der=1 )
488
489
        t = (t / np.linalg.norm(t,axis=0))
        return t.T
490
491
492
        

    def contour_to_orientation(self,s):
        """ Returns the orientation matrix at contour position s

        Without an orientation spline the orientation is built from the
        local tangent, the optional start_orientation and a twist of
        twist_per_nt per nucleotide about the tangent.  With an orientation
        spline the interpolated quaternion is converted to a matrix.
        Only scalar s (or length-1 sequences) are supported.
        """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version

        if self.quaternion_spline_params is None:
            axis = self.contour_to_tangent(s)
            axis = axis / np.linalg.norm(axis)
            zAxis = np.array((0,0,1))
            rotAxis = np.cross(axis,zAxis)
            rotAxisL = np.linalg.norm(rotAxis)

            if rotAxisL > 0.001:
                ## Round-off can push the norm marginally above 1, where
                ## arcsin would return NaN, so clamp before taking the angle
                theta = np.arcsin(min(rotAxisL, 1.0)) * 180/np.pi   # degrees
                if axis.dot(zAxis) < 0: theta = 180-theta
                orientation0 = rotationAboutAxis( rotAxis/rotAxisL, theta, normalizeAxis=False ).T
            else:
                ## Tangent is (anti)parallel to z: identity, or a half turn about x
                orientation0 = np.eye(3) if axis.dot(zAxis) > 0 else \
                               rotationAboutAxis( np.array((1,0,0)), 180, normalizeAxis=False )
            if self.start_orientation is not None:
                orientation0 = orientation0.dot(self.start_orientation)

            orientation = rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=False )
            orientation = orientation.dot(orientation0)
        else:
            q = interpolate.splev( s, self.quaternion_spline_params[0] )
            if len(q) > 1: q = np.array(q).T # TODO: is this needed?
            orientation = quaternion_to_matrix(q)

        return orientation
520

cmaffeo2's avatar
cmaffeo2 committed
521
    def get_contour_sorted_connections_and_locations(self,type_):
cmaffeo2's avatar
cmaffeo2 committed
522
        sort_fn = lambda c: c[1].address
cmaffeo2's avatar
cmaffeo2 committed
523
        cl = self.get_connections_and_locations(type_)
cmaffeo2's avatar
cmaffeo2 committed
524
        return sorted(cl, key=sort_fn)
525
526
527
    
    def randomize_unset_sequence(self):
        """ Fills in a random base wherever the sequence is not set

        A missing sequence is replaced by a list of num_nt random bases;
        otherwise only the entries that are None are replaced in place.
        """
        bases = list(seqComplement.keys())
        # bases = ['T']        ## FOR DEBUG
        if self.sequence is None:
            self.sequence = [random.choice(bases) for i in range(self.num_nt)]
            return

        assert(len(self.sequence) == self.num_nt) # TODO move
        for idx, base in enumerate(self.sequence):
            if base is None:
                self.sequence[idx] = random.choice(bases)
536

cmaffeo2's avatar
cmaffeo2 committed
537
538
539
    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        raise NotImplementedError

540
    def _generate_one_bead(self, contour_position, nts):
541
542
        raise NotImplementedError

cmaffeo2's avatar
cmaffeo2 committed
543
    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale, strand_segment):
        """ Seq should include modifications like 5T, T3 Tsinglet; direction matters too

        Builds the atoms of one nucleotide from the canonical template for
        seq, orients and places them at contour_position, adds them to
        strand_segment, runs the registered nucleotide callbacks and returns
        the atoms.  When scale is given and differs from 1, atom positions
        are rescaled and the rescaled atoms are flagged with fixed = 1.
        """
        # print("Generating nucleotide at {}".format(contour_position))

        pos = self.contour_to_position(contour_position)
        orientation = self.contour_to_orientation(contour_position)

        templates = canonicalNtFwd if is_fwd else canonicalNtRev
        atoms = templates[ seq ].generate() # TODO: clone?
        atoms.orientation = orientation.dot(atoms.orientation)

        rescale = scale is not None and scale != 1
        if isinstance(self, SingleStrandedSegment):
            ## ssDNA: every atom is rescaled about the template origin
            if rescale:
                for a in atoms:
                    a.position = scale*a.position
                    a.fixed = 1
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsedPosition()
        else:
            ## dsDNA: only atoms whose name ends in ', P or T are rescaled,
            ## about the N9 (A/G) or N1 (other bases) atom
            if rescale:
                anchor = "N9" if atoms.sequence in ("A","G") else "N1"
                r0 = atoms.atoms_by_name[anchor].position
                for a in atoms:
                    if a.name[-1] in ("'","P","T"):
                        a.position = scale*(a.position-r0) + r0
                        a.fixed = 1
            atoms.position = pos

        atoms.contour_position = contour_position
        strand_segment.add(atoms)

        for callback in self._generate_nucleotide_callbacks:
            callback(atoms)

        return atoms
599

600
601
    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Creates a Location of the given type_ at nucleotide nt and
        appends it to self.locations """
        ## Create location if needed, add to segment
        contour = self.nt_pos_to_contour(nt)
        assert(0 <= contour <= 1)
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        new_loc = Location( self, address=contour, type_=type_, on_fwd_strand=on_fwd_strand )
        self.locations.append(new_loc)

    ## TODO? Replace with abstract strand-based model?
609
610
611
612
613

    def add_nick(self, nt, on_fwd_strand=True):
        self.add_3prime(nt,on_fwd_strand)
        self.add_5prime(nt+1,on_fwd_strand)

614
    def add_5prime(self, nt, on_fwd_strand=True):
        """ Adds a "5prime" location at nt; single-stranded segments always
        use the forward strand """
        strand_flag = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt,"5prime",strand_flag)
618
619

    def add_3prime(self, nt, on_fwd_strand=True):
        """ Adds a "3prime" location at nt; single-stranded segments always
        use the forward strand """
        strand_flag = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt,"3prime",strand_flag)
623

624
    def get_3prime_locations(self):
cmaffeo2's avatar
cmaffeo2 committed
625
        return sorted(self.get_locations("3prime"),key=lambda x: x.address)
626
    
cmaffeo2's avatar
cmaffeo2 committed
627
    def get_5prime_locations(self):
628
        ## TODO? ensure that data is consistent before _build_model calls
cmaffeo2's avatar
cmaffeo2 committed
629
        return sorted(self.get_locations("5prime"),key=lambda x: x.address)
cmaffeo2's avatar
cmaffeo2 committed
630

631
    def iterate_connections_and_locations(self, reverse=False):
cmaffeo2's avatar
cmaffeo2 committed
632
633
        ## connections to other segments
        cl = self.get_contour_sorted_connections_and_locations()
634
        if reverse:
cmaffeo2's avatar
cmaffeo2 committed
635
            cl = cl[::-1]
636
637
638
            
        for c in cl:
            yield c
cmaffeo2's avatar
cmaffeo2 committed
639

640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
    ## TODO rename
    def _add_strand_piece(self, strand_piece):
        """ Registers a strand segment within this object """

        ## TODO use weakref
        d = 'fwd' if strand_piece.is_fwd else 'rev'

        ## Validate strand_piece (ensure no clashes)
        for s in self.strand_pieces[d]:
            l,h = sorted((s.start,s.end))
            for value in (strand_piece.start,strand_piece.end):
                assert( value < l or value > h )

        ## Add strand_piece in correct order
        self.strand_pieces[d].append(strand_piece)
        self.strand_pieces[d] = sorted(self.strand_pieces[d],
                                       key = lambda x: x.start)

658
    ## TODO rename
659
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
        """ Walks through locations, checking for crossovers

        Starting from nucleotide position nt_pos, self.locations are visited
        in the direction of travel (increasing position when is_fwd is true,
        decreasing otherwise) until one of them ends this stretch of strand.

        move_at_least is accepted for backward compatibility but is ignored:
        it is overwritten with 0 before use.

        Always returns a 5-tuple whose first element is the nucleotide
        position at which the walk stopped:
          (pos, None, None, None, None)
              an unconnected "3prime" location on this strand was reached
          (pos, B.container, B.get_nt_pos(), B.on_fwd_strand, 0.5)
              a connected location on this strand was reached; B is the
              location on the other side of that connection
          (pos, l.container, pos+1 or pos-1, is_fwd, 0)
              a "crossover" connection on the opposite strand was reached

        Raises Exception if every location is passed without stopping.
        """
        ## TODO: remove move_at_least
        move_at_least = 0

        ## +1 when walking fwd, -1 when walking rev
        direction = 2*is_fwd-1

        ## Iterate through locations
        def loc_rank(l):
            nt = l.get_nt_pos()
            ## optionally add logic about type of connection
            return (nt, not l.on_fwd_strand)
        locations = sorted(self.locations, key=loc_rank, reverse=(not is_fwd))

        for l in locations:
            # TODOTODO probably okay
            if l.address == 0:
                pos = 0.0
            elif l.address == 1:
                pos = self.num_nt-1
            else:
                pos = self.contour_to_nt_pos(l.address, round_nt=True)

            ## Skip locations encountered before our strand
            if (pos-nt_pos)*direction < move_at_least: continue
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime" and l.connection is None:
                return pos, None, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                Bpos = B.get_nt_pos()
                return pos, B.container, Bpos, B.on_fwd_strand, 0.5

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                if nt_pos == pos: continue
                return pos, l.container, pos+direction, is_fwd, 0

        ## Every strand is expected to terminate at a 3prime end or a
        ## connection, so running out of locations indicates a broken model
        raise Exception("Shouldn't be here")

726
727
728
    def get_nearest_bead(self, contour_position):
        """ Return the bead in self.beads whose contour_position is closest
        to the requested contour_position, or None if there are no beads """
        beads = self.beads
        if len(beads) < 1:
            return None

        # TODO: cache
        # TODO: include beads in connections?
        contours = np.array([bead.contour_position for bead in beads])
        squared_offsets = (contours - contour_position)**2
        nearest = np.argmin(squared_offsets)
        return beads[nearest]
733

734
735
736
737
738
739
740
741
742
    def _get_atomic_nucleotide(self, nucleotide_idx, is_fwd=True):
        """ Return the result of get_nucleotide(nucleotide_idx) from the first
        strand piece on the requested strand that provides it

        Strand pieces in self.strand_pieces['fwd'] (or ['rev'] when is_fwd is
        false) are tried in order; a piece whose get_nucleotide call raises
        an Exception is skipped.

        Raises Exception if no strand piece returns a nucleotide.
        """
        d = 'fwd' if is_fwd else 'rev'
        for s in self.strand_pieces[d]:
            try:
                return s.get_nucleotide(nucleotide_idx)
            except Exception:
                ## Presumably this piece does not contain the index; try the
                ## next one. Only Exception is caught so that
                ## KeyboardInterrupt and SystemExit still propagate.
                continue
        raise Exception("Could not find nucleotide in {} at {}.{}".format( self, nucleotide_idx, d ))

743
744
    def get_all_consecutive_beads(self, number):
        """ Return every run of `number` adjacent entries of self.beads, as a
        list of lists in order of their first bead """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        beads = self.beads
        num_windows = len(beads) - number + 1
        return [[beads[first+offset] for offset in range(number)]
                for first in range(num_windows)]
751

752
753
754
    def _add_bead(self,b,set_contour=False):
        """ Attach bead b to this object and record it in self.beads

        If set_contour is true, b.contour_position is first refreshed from
        b.get_contour_position(self). A bead that already has a parent is
        removed from that parent's children before being added here. If b
        carries an orientation_bead, that bead is given the same
        contour_position, attached to this object too (but not listed in
        self.beads) and bonded to b with Segment.orientation_bond.
        """
        def _adopt(particle):
            ## Detach from any previous parent, then add to this object
            previous = particle.parent
            if previous is not None:
                previous.children.remove(particle)
            self.add(particle)

        if set_contour:
            b.contour_position = b.get_contour_position(self)

        _adopt(b)
        self.beads.append(b) # don't add orientation bead

        # TODO: think of a cleaner approach
        if "orientation_bead" not in b.__dict__:
            return

        orientation = b.orientation_bead
        orientation.contour_position = b.contour_position
        _adopt(orientation)
        self.add_bond(b,orientation, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        """ Replace self.beads and self.children with the particles in
        new_children, in the order given

        Repeated entries in new_children are dropped (keeping the first
        occurrence) and reported on stdout. Each bead's orientation_bead, if
        it has one, is placed in self.children directly after the bead but
        is not added to self.beads. Asserts that the number of children and
        the number of beads are the same as before the call.
        """
        previous_children = self.children
        previous_beads = self.beads
        self.children = []
        self.beads = []

        ## TODO: remove this if duplicates are never found
        ## Remove duplicates, preserving order
        unique = []
        for particle in new_children:
            if particle in unique:
                print("  DUPLICATE PARTICLE FOUND!")
                continue
            unique.append(particle)

        for bead in unique:
            self.beads.append(bead)
            self.children.append(bead)
            # TODO: think of a cleaner approach
            if "orientation_bead" in bead.__dict__:
                self.children.append(bead.orientation_bead)

        assert(len(previous_children) == len(self.children))
        assert(len(previous_beads) == len(self.beads))
800

801

cmaffeo2's avatar
cmaffeo2 committed
802
    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):

        """ Generate the beads of this segment and link them along the helix

        Beads already attached to this segment's locations are kept as fixed
        points. A bead is created at each end of the segment if none is
        within half a nucleotide of it, and evenly spaced beads are created
        between each consecutive pair of fixed points. Neighbouring beads
        are linked with make_intrahelical_neighbor, the children of this
        segment are rebuilt in contour order, and finally every callback in
        self._generate_bead_callbacks is called with this segment.

        bead_model is not used by this method. max_basepairs_per_bead and
        max_nucleotides_per_bead are passed unchanged to self._get_num_beads,
        which decides how many beads go in each gap.
        """
        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        ## First find points between-which beads must be generated
        ## (a set is used so a particle shared by several locations is only
        ## counted once; the result is ordered by contour position)
        existing_beads0 = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( list(existing_beads0), key=lambda b: b.get_contour_position(self) )

        for b in existing_beads:
            assert(b.parent is not None)

        ## Add ends if they don't exist yet
        ## TODOTODO: test 1 nt segments?
        ## A start bead is needed when there are no beads at all, or when the
        ## first one is half a nucleotide or more from position 0
        if len(existing_beads) == 0 or existing_beads[0].get_nt_position(self) >= 0.5:
            b = self._generate_one_bead( self.nt_pos_to_contour(0), 0)
            existing_beads = [b] + existing_beads

        ## Likewise an end bead is needed when the last bead is more than half
        ## a nucleotide short of the final position, or when only one bead
        ## exists so far (at least two are required below)
        if existing_beads[-1].get_nt_position(self)-(self.num_nt-1) < -0.5 or len(existing_beads)==1:
            b = self._generate_one_bead( self.nt_pos_to_contour(self.num_nt-1), 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
        last = None             # most recently placed bead, to be linked to the next one

        for I in range(len(existing_beads)-1):
            eb1,eb2 = [existing_beads[i] for i in (I,I+1)]
            assert( eb1 is not eb2 )

            ## Contour length of the gap between the two fixed beads
            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )

            ## Ensure there is a ssDNA bead between dsDNA beads
            if num_beads == 0 and isinstance(self,SingleStrandedSegment) and isinstance(eb1.parent,DoubleStrandedSegment) and isinstance(eb2.parent,DoubleStrandedSegment):
                num_beads = 1
            ## TODO similarly ensure there is a dsDNA bead between ssDNA beads

            ## Contour spacing between adjacent beads in this gap, and the
            ## same spacing expressed in nucleotides
            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nt
            ## Each fixed bead gains half a spacing from the gap on this side
            eb1.num_nt += 0.5*nts
            eb2.num_nt += 0.5*nts

            ## Add beads
            ## (a fixed bead is only listed as a child if this segment is its parent)
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
                last.make_intrahelical_neighbor(eb1)
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)

                last.make_intrahelical_neighbor(b)
                last = b
                tmp_children.append(b)

        ## eb2 still refers to the final fixed bead after the loop; it is
        ## linked and listed here because the loop only handles eb1
        last.make_intrahelical_neighbor(eb2)

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)

        for callback in self._generate_bead_callbacks:
            callback(self)

894
895
896
897
898
899
900
901
902
    def _regenerate_beads(self, max_nts_per_bead=4, ):
        """ Placeholder: the body is empty, so calling this does nothing and
        returns None. max_nts_per_bead is currently unused. """
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

903
    def __init__(self, name, num_bp, start_position = np.array((0,0,0)),
904
905
                 end_position = None, 
                 segment_model = None,
cmaffeo2's avatar
cmaffeo2 committed
906
907
                 local_twist = False,
                 num_turns = None,
cmaffeo2's avatar
cmaffeo2 committed
908
                 start_orientation = None,
cmaffeo2's avatar
cmaffeo2 committed
909
910
                 twist_persistence_length = 90,
                 **kwargs):
cmaffeo2's avatar
cmaffeo2 committed
911
912
913
        
        self.helical_rise = 10.44
        self.distance_per_nt = 3.4
914
        Segment.__init__(self, name, num_bp,
915
916
                         start_position,
                         end_position, 
cmaffeo2's avatar
cmaffeo2 committed
917
918
                         segment_model,
                         **kwargs)
919
        self.num_bp = self.num_nt
920

cmaffeo2's avatar
cmaffeo2 committed
921
922
        self.local_twist = local_twist
        if num_turns is None:
923
924
            num_turns = float(num_bp) / self.helical_rise
        self.twist_per_nt = float(360 * num_turns) / num_bp
cmaffeo2's avatar
cmaffeo2 committed
925
926

        if start_orientation is None:
927
            start_orientation = np.eye(3) # np.array(((1,0,0),(0,1,0),(0,0,1)))
cmaffeo2's avatar
cmaffeo2 committed
928
        self.start_orientation = start_orientation
cmaffeo2's avatar
cmaffeo2 committed
929
        self.twist_persistence_length = twist_persistence_length
cmaffeo2's avatar
cmaffeo2 committed
930

931
932
        self.nicks = []

933
        self.start = self.start5 = Location( self, address=0, type_= "end5" )
934
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )
935

936
937
        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )
cmaffeo2's avatar
cmaffeo2 committed
938
939
        # for l in (self.start5,self.start3,self.end3,self.end5):
        #     self.locations.append(l)
940

941
942
943
        ## Set up interpolation for azimuthal angles 
        a = np.array([self.start_position,self.end_position]).T
        tck, u = interpolate.splprep( a, u=[0,1], s=0, k=1)
944
        self.position_spline_params = (tck,u)
945
946
947
        
        ## TODO: initialize sensible spline for orientation

948
    ## Convenience methods
949
    ## TODO: add errors if unrealistic connections are made
950
    ## TODO: make connections automatically between unconnected strands
951
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
952
953
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
954
955
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )
    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
956
        if isinstance(end5, SingleStrandedSegment):
957
            end5 = end5.start5