segmentmodel.py 114 KB
Newer Older
cmaffeo2's avatar
cmaffeo2 committed
1
import pdb
2
from pathlib import Path
3
import numpy as np
4
import random
5
6
7
from .model.arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from .coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
from .model.nonbonded import *
8
from copy import copy, deepcopy
9
from .model.nbPot import nbDnaScheme
10

cmaffeo2's avatar
cmaffeo2 committed
11
12
from scipy.special import erf
import scipy.optimize as opt
13
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
14

15
16
from .model.CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from .model.CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
cmaffeo2's avatar
cmaffeo2 committed
17

18
# import pdb
19
"""
TODO:
 + fix handling of crossovers for atomic representation
 + map to atomic representation
    + add nicks
    + transform ssDNA nucleotides
    - shrink ssDNA
    + shrink dsDNA backbone
    + make orientation continuous
    + sequence
    + handle circular dna
 + ensure crossover bead potentials aren't applied twice
 + remove performance bottlenecks
 - test for large systems
 + assign sequence
 + ENM
 - rework Location class
 - remove recursive calls
 - document
 - develop unit test suite
"""
class CircularDnaError(Exception):
    """ Error type for problems involving circular DNA (raised elsewhere in the package). """
class ParticleNotConnectedError(Exception):
    """ Error type signalling that a particle lacks an expected connection. """
class Location():
    """ Site for connection within an object (e.g. a segment).

    Attributes set on construction:
      container     -- the object owning this location
      address       -- position along the contour length of the container
      type_         -- label such as "5prime", "3prime" or "crossover"
      on_fwd_strand -- True when the location is on the forward strand
      connection    -- Connection attached here, or None
    """
    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segment
        self.on_fwd_strand = on_fwd_strand

        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the Location at the other end of this location's connection, or None """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Attach `connection` and record which side of it this location is on """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def get_nt_pos(self):
        """ Return the rounded nucleotide index of this location in its container.

        If the container cannot convert the address (e.g. Segment.contour_to_nt_pos
        asserts that the value is integral when round_nt=True, which fails at the
        exact ends), addresses 0 and 1 fall back to the first and last nucleotide.
        Any other failure is re-raised.
        """
        try:
            pos = self.container.contour_to_nt_pos(self.address, round_nt=True)
        except Exception:
            if self.address == 0:
                pos = 0
            elif self.address == 1:
                pos = self.container.num_nt-1
            else:
                raise
        return pos

    def __repr__(self):
        ## Report the strand by name (the label was previously computed but unused)
        on_fwd = "on_fwd_strand" if self.on_fwd_strand else "on_rev_strand"
        return "<Location {}.{}[{:.2f},{}]>".format( self.container.name, self.type_, self.address, on_fwd)
class Connection():
    """ Abstract base class for connection between two elements (two Locations A and B) """
    def __init__(self, A, B, type_ = None):
        assert( isinstance(A,Location) )
        assert( isinstance(B,Location) )
        self.A = A
        self.B = B
        self.type_ = type_

    def other(self, location):
        """ Return the end opposite to `location` (compared by identity).

        Raises Exception("OutOfBoundsError") if `location` is neither end.
        """
        if location is self.A:
            return self.B
        if location is self.B:
            return self.A
        raise Exception("OutOfBoundsError")

    def delete(self):
        """ Remove this connection from both containers and detach it from both ends """
        self.A.container.connections.remove(self)
        self.B.container.connections.remove(self)
        self.A.connection = None
        self.B.connection = None

    def __repr__(self):
        ## Fixed: the previous format string ended with an unmatched "]"
        return "<Connection {}--{}--{}>".format( self.A, self.type_, self.B )
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that own Locations and Connections """

    def __init__(self, connection_locations=None, connections=None):
        if connection_locations is None: connection_locations = []
        if connections is None: connections = []

        ## TODO decide on names
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Return locations whose type_ equals `type_` (any type when None)
        and is not listed in `exclude`. """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        ## Internal invariant: each location object is registered only once
        assert( len(set(locs)) == len(locs) )
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
        """ Return the existing location at `address` on the requested strand,
        or a new (unregistered) Location of type `new_type` if none exists.

        For single-nucleotide elements the address is ignored and the
        unconnected location on the requested strand is used instead.
        """
        loc = None
        if (self.num_nt == 1):
            ## Assumes that intrahelical connections have been made before crossovers
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other

        Only connections whose type_ equals `connection_type` (any when None)
        and is not in `exclude` are included. Raises Exception if a stored
        connection refers to neither end in this object.
        """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    ## Previously dropped into pdb here; fail loudly instead
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection, in_3prime_direction=None):
        """ Register `connection` with self and other and attach it to both ends.

        When in_3prime_direction is given, B is marked as the 3' side when it is
        True (A when it is False). Both end locations are appended to their
        containers' location lists if not already present.
        """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]
        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction

        A.connection = B.connection = connection
        self.connections.append(connection)
        other.connections.append(connection)

        l = A.container.locations
        if A not in l: l.append(A)
        l = B.container.locations
        if B not in l: l.append(B)
class SegmentParticle(PointParticle):
    """ Bead (PointParticle) that belongs to a segment.

    Besides the PointParticle data it records its contour position along its
    parent segment, its intrahelical and other neighbor beads, and a list of
    associated locations.
    """
    def __init__(self, type_, position, name="A", **kwargs):
        self.name = name
        self.contour_position = None    # assigned later, e.g. by update_position()
        PointParticle.__init__(self, type_, position, name=name, **kwargs)

        self.intrahelical_neighbors = []
        self.other_neighbors = []
        self.locations = []

    def get_intrahelical_above(self):
        """ Returns the first intrahelical neighbor whose contour position
        (measured in self.parent) is larger than self's, or None """
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) > self.contour_position:
                return b

    def get_intrahelical_below(self):
        """ Returns the first intrahelical neighbor whose contour position
        (measured in self.parent) is smaller than self's, or None """
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) < self.contour_position:
                return b

    def _neighbor_should_be_added(self,b):
        """ Decide whether b may become an intrahelical neighbor of self.

        Always True when the parents are of different types. Otherwise the
        existing neighbor b0 on the same side as b is compared: if b is closer
        along the contour, b0 and self are unlinked and True is returned; if
        b0 is at least as close, False is returned.
        """
        if type(self.parent) != type(b.parent):
            return True

        c1 = self.contour_position
        c2 = b.get_contour_position(self.parent)
        if c2 < c1:
            b0 = self.get_intrahelical_below()
        else:
            b0 = self.get_intrahelical_above()

        if b0 is not None:
            c0 = b0.get_contour_position(self.parent)
            if np.abs(c2-c1) < np.abs(c0-c1):
                ## remove b0
                self.intrahelical_neighbors.remove(b0)
                b0.intrahelical_neighbors.remove(self)
                return True
            else:
                return False
        return True

    def make_intrahelical_neighbor(self,b):
        """ Link self and b as intrahelical neighbors if both sides accept it.

        NOTE(review): the acceptance test can unlink an existing neighbor on
        one side even when the other side then rejects b.
        """
        add1 = self._neighbor_should_be_added(b)
        add2 = b._neighbor_should_be_added(self)
        if add1 and add2:
            self.intrahelical_neighbors.append(b)
            b.intrahelical_neighbors.append(self)

    def get_nt_position(self, seg):
        """ Returns the "address" of the nucleotide relative to seg in
        nucleotides, taking the shortest (intrahelical) contour length route to seg
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)
        else:
            pos = self.get_contour_position(seg)
            return seg.contour_to_nt_pos(pos)

    def get_contour_position(self,seg):
        """ Return this bead's contour position expressed in segment `seg`.

        For the parent segment this is self.contour_position. Otherwise a
        depth-first search follows "intrahelical" connections (plus
        "sscrossover" connections that lead into seg when seg is a
        SingleStrandedSegment), accumulating distance in nucleotides and
        pruning paths longer than the current cutoff (initially 90 nt). The
        shortest route wins; the returned value may lie outside [0, 1].
        Raises Exception if seg is not reached.
        """
        if seg == self.parent:
            return self.contour_position
        else:
            cutoff = 30*3
            target_seg = seg

            ## depth-first search
            ## TODO cache distances to nearby locations?
            def descend_search_tree(seg, contour_in_seg, distance=0, visited_segs=None):
                nonlocal cutoff
                if visited_segs is None: visited_segs = []

                if seg == target_seg:
                    ## Found a segment in our target
                    sign = (contour_in_seg == 1) - (contour_in_seg == 0)
                    assert( sign in (1,-1) )
                    if distance < cutoff: # TODO: check if this does anything
                        cutoff = distance
                    return [[distance, contour_in_seg+sign*seg.nt_pos_to_contour(distance)]], [(seg, contour_in_seg, distance)]
                if distance > cutoff:
                    return None,None

                ret_list = []
                hist_list = []
                ## Find intrahelical locations in seg that we might pass through
                conn_locs = seg.get_connections_and_locations("intrahelical")
                if isinstance(target_seg, SingleStrandedSegment):
                    tmp = seg.get_connections_and_locations("sscrossover")
                    conn_locs = conn_locs + list(filter(lambda x: x[2].container == target_seg, tmp))
                for c,A,B in conn_locs:
                    if B.container in visited_segs: continue
                    dx = seg.contour_to_nt_pos( A.address, round_nt=False ) - seg.contour_to_nt_pos( contour_in_seg, round_nt=False)
                    dx = np.abs(dx)
                    results,history = descend_search_tree( B.container, B.address,
                                                   distance+dx, visited_segs + [seg] )
                    if results is not None:
                        ret_list.extend( results )
                        hist_list.extend( history )
                return ret_list,hist_list

            results,history = descend_search_tree(self.parent, self.contour_position)
            if results is None or len(results) == 0:
                raise Exception("Could not find location in segment") # TODO better error
            return sorted(results,key=lambda x:x[0])[0][1]

    def update_position(self, contour_position):
        """ Move the bead to `contour_position` on its parent segment.

        If the bead has an orientation_bead, it is moved too: it is placed
        Segment.orientation_bond.r0 along the first column of the local
        orientation matrix (identity, with a warning, if none is available).
        """
        self.contour_position = contour_position
        self.position = self.parent.contour_to_position(contour_position)
        if 'orientation_bead' in self.__dict__:
            o = self.orientation_bead
            o.contour_position = contour_position
            orientation = self.parent.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            o.position = self.position + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )

    def __repr__(self):
        ## contour_position starts as None; "{:.2f}" would raise TypeError on it
        if self.contour_position is None:
            contour = "None"
        else:
            contour = "{:.2f}".format(self.contour_position)
        return "<SegmentParticle {} on {}[{}]>".format( self.name, self.parent, contour)

## TODO break this class into smaller, better encapsulated pieces
330
331
332
333
334
335
336
337
338
339
340
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    ## Basic particle types shared by all segments
    dsDNA_particle = ParticleType("D", diffusivity=43.5, mass=300, radius=3)

    orientation_particle = ParticleType("O", diffusivity=100, mass=300, radius=1)

    ## Bond between a bead and its orientation bead (earlier choice: HarmonicBond(10,2))
    orientation_bond = HarmonicBond(30, 1.5, rRange=(0, 500))

    ssDNA_particle = ParticleType("S", diffusivity=43.5, mass=150, radius=3)

    def __init__(self, name, num_nt,
                 start_position = None,
                 end_position = None,
                 segment_model = None,
                 **kwargs):
        """ Create a segment `name` containing `num_nt` nucleotides.

        start_position defaults to the origin. end_position defaults to
        start_position displaced along +z by distance_per_nt*num_nt
        (distance_per_nt is presumably provided by subclasses). Remaining
        keyword arguments are forwarded to Group.__init__.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[], **kwargs)
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        if 'segname' not in kwargs:
            self.segname = name

        self.start_orientation = None
        self.twist_per_nt = 0

        ## beads excludes orientation beads
        self.beads = list(self.children)

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.strand_pieces = {'fwd': [], 'rev': []}

        self.num_nt = int(num_nt)
        if end_position is None:
            displacement = np.array((0,0,self.distance_per_nt*num_nt))
            end_position = displacement + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Callbacks used to assign cadnano names to beads / nucleotides
        self._generate_bead_callbacks = []
        self._generate_nucleotide_callbacks = []

        ## Straight-line position interpolation between the two ends
        self._set_splines_from_ends()

        self.sequence = None

    def __repr__(self):
        template = "<{} {}[{:d}]>"
        return template.format(type(self), self.name, self.num_nt)

    def set_splines(self, contours, coords):
        """ Fit a linear (k=1), interpolating (s=0) parametric spline through
        `coords` at parameters `contours`; store (tck, u) in position_spline_params.
        coords is transposed before fitting, so rows are expected to be points. """
        knots, params = interpolate.splprep(coords.T, u=contours, s=0, k=1)
        self.position_spline_params = (knots, params)

    def set_orientation_splines(self, contours, quaternions):
        """ Fit a linear interpolating spline through the rows of `quaternions`
        at parameters `contours`; store (tck, u) in quaternion_spline_params. """
        knots, params = interpolate.splprep(quaternions.T, u=contours, s=0, k=1)
        self.quaternion_spline_params = (knots, params)

    def get_center(self):
        """ Mean of the positions evaluated at the stored position-spline parameters """
        params = self.position_spline_params[1]
        return np.mean(self.contour_to_position(params), axis=0)

    @staticmethod
    def __filter_contours(contours, positions, position_filter, contour_filter):
        """ Return the indices i that pass both filters.

        contour_filter is applied to contours[i] first, then position_filter to
        positions[i,:] for the survivors; a None filter accepts everything.
        Declared a staticmethod because it takes no self (it was previously only
        callable through the class).
        """
        ids = list(range(len(contours)))
        if contour_filter is not None:
            ids = [i for i in ids if contour_filter(contours[i])]
        if position_filter is not None:
            ids = [i for i in ids if position_filter(positions[i,:])]
        return ids

    def translate(self, translation_vector, position_filter=None, contour_filter=None):
        """ Shift the position-spline sample points by translation_vector.

        Only samples selected by contour_filter / position_filter (both optional)
        are moved; the spline is then refit through the updated points.
        """
        shift = np.array(translation_vector)
        _, params = self.position_spline_params
        coords = self.contour_to_position(params)

        selected = Segment.__filter_contours(params, coords, position_filter, contour_filter)
        coords[selected,:] = coords[selected,:] + shift[np.newaxis,:]

        self.set_splines(params, coords)

    def rotate(self, rotation_matrix, about=None, position_filter=None, contour_filter=None):
        """ Rotate (part of) the segment by rotation_matrix.

        Position-spline samples selected by contour_filter / position_filter
        (both optional) are rotated about the point `about` (the origin when
        None) and the spline is refit. If orientation splines exist, the
        orientations at the selected orientation-spline samples are
        left-multiplied by rotation_matrix and refit as quaternions.
        """
        R = np.asarray(rotation_matrix)
        center = np.zeros(3) if about is None else np.array(about)

        tck, u = self.position_spline_params
        r = self.contour_to_position(u)
        ids = Segment.__filter_contours(u, r, position_filter, contour_filter)

        ## Select orientation samples before moving anything, so the position
        ## filter sees un-rotated coordinates; the quaternion spline may use
        ## different contour samples than the position spline.
        has_orientations = self.quaternion_spline_params is not None
        if has_orientations:
            q_tck, q_u = self.quaternion_spline_params
            q_ids = Segment.__filter_contours(q_u, self.contour_to_position(q_u),
                                              position_filter, contour_filter)

        ## Rotate only the selected rows (row vectors: x -> R x becomes X R^T)
        r[ids,:] = (r[ids,:] - center) @ R.T + center
        self.set_splines(u,r)

        if has_orientations:
            ## TODO: performance: don't shift between quaternion and matrix representations so much
            ## TODO: consider enforcing quaternion sign continuity between samples
            orientations = [self.contour_to_orientation(v) for v in q_u]
            for i in q_ids:
                orientations[i] = R.dot(orientations[i])
            quats = np.array([quaternion_from_matrix(o) for o in orientations])
            self.set_orientation_splines(q_u, quats)

    def _set_splines_from_ends(self):
        """ Reset to a straight line from start_position (contour 0) to
        end_position (contour 1), discarding any orientation spline """
        self.quaternion_spline_params = None
        endpoints = np.array([self.start_position, self.end_position])
        self.set_splines([0,1], endpoints)

    def clear_all(self):
        """ Call Group.clear_all, forget beads, and detach particles from this
        segment's connection locations """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for _, own_location, _ in self.get_connections_and_locations():
            own_location.particle = None

    def contour_to_nt_pos(self, contour_pos, round_nt=False):
        """ Convert a contour position to a nucleotide index: contour*num_nt - 0.5.

        With round_nt=True the value must already be (numerically) an integer
        (asserted) and is returned rounded.
        """
        nt = contour_pos*(self.num_nt) - 0.5
        if not round_nt:
            return nt
        assert( np.isclose(np.around(nt),nt) )
        return np.around(nt)

    def nt_pos_to_contour(self,nt_pos):
        """ Inverse of contour_to_nt_pos: (nt_pos + 0.5) / num_nt """
        shifted = nt_pos + 0.5
        return shifted / (self.num_nt)

    def contour_to_position(self,s):
        """ Evaluate the position spline at contour value(s) s.

        scipy's splev returns one entry per spatial dimension; when there is
        more than one, they are stacked and transposed so points become rows.
        """
        coords = interpolate.splev( s, self.position_spline_params[0] )
        return np.array(coords).T if len(coords) > 1 else coords

    def contour_to_tangent(self,s):
        """ Unit tangent of the position spline at contour value(s) s
        (first derivative normalized over the dimension axis, then transposed) """
        deriv = interpolate.splev( s, self.position_spline_params[0], der=1 )
        magnitude = np.linalg.norm(deriv, axis=0)
        return (deriv / magnitude).T

    def contour_to_orientation(self,s):
        """ Return the 3x3 orientation matrix at a single contour position s.

        With an orientation (quaternion) spline, the interpolated quaternion
        is converted to a matrix. Otherwise a base rotation is built from the
        angle between the local tangent and +z (presumably aligning z with the
        tangent; confirm against rotationAboutAxis's convention), and a twist of
        twist_per_nt * (nucleotide position) about the tangent is applied on
        top. Angles passed to rotationAboutAxis are in degrees (theta is
        converted from radians).
        """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version

        if self.quaternion_spline_params is not None:
            q = interpolate.splev( s, self.quaternion_spline_params[0] )
            if len(q) > 1: q = np.array(q).T # TODO: is this needed?
            return quaternion_to_matrix(q)

        tangent = self.contour_to_tangent(s)
        tangent = tangent / np.linalg.norm(tangent)
        rot_axis = np.cross(tangent,np.array((0,0,1)))
        sin_theta = np.linalg.norm(rot_axis)
        z_axis = np.array((0,0,1))

        if sin_theta > 0.001:
            theta = np.arcsin(sin_theta) * 180/np.pi
            if tangent.dot(z_axis) < 0: theta = 180-theta
            base = rotationAboutAxis( rot_axis/sin_theta, theta, normalizeAxis=False ).T
        elif tangent.dot(z_axis) > 0:
            base = np.eye(3)
        else:
            base = rotationAboutAxis( np.array((1,0,0)), 180, normalizeAxis=False )

        twist = rotationAboutAxis( tangent, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=False )
        return twist.dot(base)

    def get_contour_sorted_connections_and_locations(self,type_=None):
        """ Return get_connections_and_locations(type_) sorted by the address
        of the location in this segment (entry[1]).

        type_ now defaults to None (all connection types), matching
        get_connections_and_locations; callers passing a type are unaffected.
        """
        cl = self.get_connections_and_locations(type_)
        return sorted(cl, key=lambda c: c[1].address)

    def randomize_unset_sequence(self):
        """ Fill in missing bases with random picks from seqComplement's keys.

        With no sequence, a fresh random sequence of num_nt bases is created;
        otherwise (length must equal num_nt) only None entries are replaced in place.
        """
        bases = list(seqComplement.keys())
        if self.sequence is None:
            self.sequence = [random.choice(bases) for _ in range(self.num_nt)]
            return

        assert(len(self.sequence) == self.num_nt) # TODO move
        for idx, base in enumerate(self.sequence):
            if base is None:
                self.sequence[idx] = random.choice(bases)

    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        """ Abstract: subclasses must report how many beads to generate """
        raise NotImplementedError()

    def _generate_one_bead(self, contour_position, nts):
        """ Abstract: subclasses must create a bead at contour_position """
        raise NotImplementedError()

    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale, strand_segment):
        """ Build an all-atom nucleotide at contour_position and add it to strand_segment.

        Seq should include modifications like 5T, T3 Tsinglet; direction matters too.

        Parameters:
          contour_position -- location along this segment's contour
          is_fwd           -- use canonicalNtFwd (True) or canonicalNtRev (False) templates
          seq              -- key into the template dictionary
          scale            -- if not None and not 1, rescales atom positions (see below)
          strand_segment   -- object whose add() receives the generated atoms
        Returns the generated atoms after running _generate_nucleotide_callbacks.
        """
        pos = self.contour_to_position(contour_position)
        orientation = self.contour_to_orientation(contour_position)

        nt_dict = canonicalNtFwd if is_fwd else canonicalNtRev
        atoms = nt_dict[ seq ].generate() # TODO: clone?
        atoms.orientation = orientation.dot(atoms.orientation)

        if isinstance(self, SingleStrandedSegment):
            ## ssDNA: scale every atom position, then put atom C1' at pos
            if scale is not None and scale != 1:
                for a in atoms:
                    a.position = scale*a.position
                    a.beta = 0
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsedPosition()
        else:
            ## dsDNA: scale only atoms whose names end in "'", "P" or "T",
            ## about atom N9 (sequence A/G) or N1 (otherwise)
            if scale is not None and scale != 1:
                if atoms.sequence in ("A","G"):
                    r0 = atoms.atoms_by_name["N9"].position
                else:
                    r0 = atoms.atoms_by_name["N1"].position
                for a in atoms:
                    if a.name[-1] in ("'","P","T"):
                        a.position = scale*(a.position-r0) + r0
                        a.beta = 0
            atoms.position = pos

        atoms.contour_position = contour_position
        strand_segment.add(atoms)

        for callback in self._generate_nucleotide_callbacks:
            callback(atoms)

        return atoms

    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Create a Location of type type_ at nucleotide nt and append it to
        self.locations (the contour address must fall within [0, 1]) """
        contour = self.nt_pos_to_contour(nt)
        assert(0 <= contour <= 1)
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        new_location = Location( self, address=contour, type_=type_, on_fwd_strand=on_fwd_strand )
        self.locations.append(new_location)

    ## TODO? Replace with abstract strand-based model?

    def add_nick(self, nt, on_fwd_strand=True):
        """ Break the strand after nucleotide nt: 3' end at nt, 5' end at nt+1 """
        three_prime_nt, five_prime_nt = nt, nt+1
        self.add_3prime(three_prime_nt, on_fwd_strand)
        self.add_5prime(five_prime_nt, on_fwd_strand)

    def add_5prime(self, nt, on_fwd_strand=True):
        """ Register a 5' end at nucleotide nt (single-stranded segments always
        use the forward strand) """
        strand = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt, "5prime", strand)

    def add_3prime(self, nt, on_fwd_strand=True):
        """ Register a 3' end at nucleotide nt (single-stranded segments always
        use the forward strand) """
        strand = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt, "3prime", strand)

    def get_3prime_locations(self):
        """ All "3prime" locations, ordered by increasing address """
        ends = self.get_locations("3prime")
        return sorted(ends, key=lambda loc: loc.address)

    def get_5prime_locations(self):
        """ All "5prime" locations, ordered by increasing address """
        ## TODO? ensure that data is consistent before _build_model calls
        ends = self.get_locations("5prime")
        return sorted(ends, key=lambda loc: loc.address)

    def iterate_connections_and_locations(self, reverse=False):
        """ Yield [connection, location_in_self, location_in_other] for every
        connection to other segments, ordered by address in this segment
        (descending when reverse is True). """
        ## connections to other segments; the sorter requires a type argument,
        ## and None selects every connection type (previously omitted -> TypeError)
        cl = self.get_contour_sorted_connections_and_locations(None)
        if reverse:
            cl = cl[::-1]
        yield from cl

    ## TODO rename
    def _add_strand_piece(self, strand_piece):
        """ Registers a strand segment within this object.

        strand_piece must provide is_fwd, start and end. Pieces are stored per
        direction ('fwd'/'rev') sorted by start. An AssertionError is raised if
        the new piece overlaps an already registered piece of that direction.
        """
        ## TODO use weakref
        d = 'fwd' if strand_piece.is_fwd else 'rev'

        ## Validate strand_piece (ensure no clashes). Closed intervals [l,h] and
        ## [nl,nh] overlap unless one lies entirely beyond the other; unlike an
        ## endpoint-only check, this also rejects a piece enclosing an existing one.
        nl, nh = sorted((strand_piece.start, strand_piece.end))
        for s in self.strand_pieces[d]:
            l,h = sorted((s.start,s.end))
            assert( nh < l or nl > h )

        ## Add strand_piece in correct order
        self.strand_pieces[d].append(strand_piece)
        self.strand_pieces[d] = sorted(self.strand_pieces[d],
                                       key = lambda x: x.start)

    ## TODO rename
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
        """ Walks through locations, checking for crossovers.

        Starting at nucleotide nt_pos, visits this segment's locations in order
        of increasing nucleotide index when is_fwd (decreasing otherwise) and
        returns at the first stopping point:
          - a "3prime" location on the walked strand:
              (pos, None, None, None, None)
          - a connected location on the walked strand:
              (pos, other_container, other_nt_pos, other_on_fwd_strand, 0.5)
          - a "crossover" on the opposite strand (other than at nt_pos):
              (pos, self, pos +/- 1, is_fwd, 0)
        Raises Exception("Shouldn't be here") if no stopping point is found.

        NOTE(review): the move_at_least argument is currently ignored (it is
        overwritten with 0 below); kept for interface compatibility.
        """
        move_at_least = 0

        def loc_rank(l):
            nt = l.get_nt_pos()
            ## optionally add logic about type of connection
            return (nt, not l.on_fwd_strand)
        locations = sorted(self.locations, key=loc_rank, reverse=(not is_fwd))

        for l in locations:
            # TODOTODO probably okay
            if l.address == 0:
                pos = 0.0
            elif l.address == 1:
                pos = self.num_nt-1
            else:
                pos = self.contour_to_nt_pos(l.address, round_nt=True)

            ## Skip locations encountered before our strand
            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
            ## TODO: remove move_at_least
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                return pos, None, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                Bpos = B.get_nt_pos()
                return pos, B.container, Bpos, B.on_fwd_strand, 0.5

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                if nt_pos == pos: continue
                return pos, l.container, pos+(2*is_fwd-1), is_fwd, 0

        ## Made it to the end of the segment without finding a connection
        raise Exception("Shouldn't be here")

    def get_nearest_bead(self, contour_position):
        """ Return the bead in ``self.beads`` whose ``contour_position`` is
        closest to ``contour_position``.

        Returns None when the segment has no beads. Ties resolve to the bead
        that appears first in ``self.beads``.
        """
        if len(self.beads) == 0: return None
        ## TODO: cache contour positions; include beads in connections?
        cs = np.array([b.contour_position for b in self.beads])
        ## |d| selects the same bead as d**2 without the precision loss of squaring
        i = int(np.argmin(np.abs(cs - contour_position)))
        return self.beads[i]
730

731
732
733
734
735
736
737
738
739
    def _get_atomic_nucleotide(self, nucleotide_idx, is_fwd=True):
        d = 'fwd' if is_fwd else 'rev'
        for s in self.strand_pieces[d]:
            try:
                return s.get_nucleotide(nucleotide_idx)
            except:
                pass
        raise Exception("Could not find nucleotide in {} at {}.{}".format( self, nucleotide_idx, d ))

740
741
    def get_all_consecutive_beads(self, number):
        """ Return every run of ``number`` consecutive beads from ``self.beads``.

        Each run is a list; runs are ordered by starting index. If there are
        fewer than ``number`` beads, an empty list is returned.
        """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        beads = self.beads
        return [beads[i:i+number] for i in range(len(beads)-number+1)]
748

749
750
751
    def _add_bead(self,b,set_contour=False):
        if set_contour:
            b.contour_position = b.get_contour_position(self)
752
        
753
754
755
        # assert(b.parent is None)
        if b.parent is not None:
            b.parent.children.remove(b)
756
        self.add(b)
757
758
759
760
761
762
        self.beads.append(b) # don't add orientation bead
        if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
            o = b.orientation_bead
            o.contour_position = b.contour_position
            if o.parent is not None:
                o.parent.children.remove(o)
763
            self.add(o)
764
765
766
767
768
769
770
771
            self.add_bond(b,o, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        # print("_rebuild_children on %s" % self.name)
        old_children = self.children
        old_beads = self.beads
        self.children = []
        self.beads = []
772
773

        if True:
774
775
            ## TODO: remove this if duplicates are never found 
            # print("Searching for duplicate particles...")
776
            ## Remove duplicates, preserving order
777
778
779
780
            tmp = []
            for c in new_children:
                if c not in tmp:
                    tmp.append(c)
781
                else:
782
                    print("  DUPLICATE PARTICLE FOUND!")
783
784
            new_children = tmp

785
786
787
788
789
        for b in new_children:
            self.beads.append(b)
            self.children.append(b)
            if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
                self.children.append(b.orientation_bead)
790
791
792
793
794
            
        # tmp = [c for c in self.children if c not in old_children]
        # assert(len(tmp) == 0)
        # tmp = [c for c in old_children if c not in self.children]
        # assert(len(tmp) == 0)
795
796
        assert(len(old_children) == len(self.children))
        assert(len(old_beads) == len(self.beads))
797

798

cmaffeo2's avatar
cmaffeo2 committed
799
    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions

        Beads already attached to this segment's locations are kept. End beads
        are created at nucleotide 0 and ``num_nt-1`` when missing. New beads
        are then spaced evenly between each pair of consecutive existing beads
        (count from ``self._get_num_beads``), each chained to the previous one
        with ``make_intrahelical_neighbor``. Finally the children are rebuilt
        in contour order and every ``self._generate_bead_callbacks`` entry is
        called with this segment.
        """
        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        ## First find points between-which beads must be generated
        existing_beads0 = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( existing_beads0, key=lambda b: b.get_contour_position(self) )

        for b in existing_beads:
            assert(b.parent is not None)

        ## Add ends if they don't exist yet
        ## TODOTODO: test 1 nt segments?
        if len(existing_beads) == 0 or existing_beads[0].get_nt_position(self) >= 0.5:
            b = self._generate_one_bead( self.nt_pos_to_contour(0), 0)
            existing_beads = [b] + existing_beads

        if existing_beads[-1].get_nt_position(self)-(self.num_nt-1) < -0.5 or len(existing_beads)==1:
            b = self._generate_one_bead( self.nt_pos_to_contour(self.num_nt-1), 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
        last = None

        for eb1, eb2 in zip(existing_beads[:-1], existing_beads[1:]):
            assert( eb1 is not eb2 )

            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )

            ## Ensure there is a ssDNA bead between dsDNA beads
            if num_beads == 0 and isinstance(self,SingleStrandedSegment) and isinstance(eb1.parent,DoubleStrandedSegment) and isinstance(eb2.parent,DoubleStrandedSegment):
                num_beads = 1
            ## TODO similarly ensure there is a dsDNA bead between ssDNA beads

            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nt
            ## Each endpoint bead takes half of the spacing's nucleotide weight
            eb1.num_nt += 0.5*nts
            eb2.num_nt += 0.5*nts

            ## Add beads
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
                last.make_intrahelical_neighbor(eb1)
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)
                last.make_intrahelical_neighbor(b)
                last = b
                tmp_children.append(b)

        ## The loop ran at least once (len(existing_beads) > 1), so last and
        ## eb2 (the final existing bead) are bound here
        last.make_intrahelical_neighbor(eb2)

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)

        for callback in self._generate_bead_callbacks:
            callback(self)

891
892
893
894
895
896
897
898
899
    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

900
    def __init__(self, name, num_bp, start_position = None,
                 end_position = None,
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None,
                 twist_persistence_length = 90,
                 **kwargs):
        """ Create a double-stranded segment of ``num_bp`` basepairs.

        Parameters
        ----------
        name : str
            Passed to ``Segment.__init__``.
        num_bp : int
            Number of basepairs; passed to ``Segment.__init__`` as the
            nucleotide count, and ``self.num_bp`` mirrors ``self.num_nt``.
        start_position : array-like, optional
            Defaults to a new ``np.array((0,0,0))`` for each segment.
        end_position, segment_model, **kwargs
            Forwarded unchanged to ``Segment.__init__``.
        local_twist : bool
            Stored as ``self.local_twist``.
        num_turns : float, optional
            Total helical turns; defaults to ``num_bp / self.helical_rise``.
        start_orientation : array, optional
            Defaults to ``np.eye(3)``.
        twist_persistence_length : float
            Stored as ``self.twist_persistence_length``.
        """
        if start_position is None:
            ## Build a fresh array per call: an np.array default in the
            ## signature would be a single object shared by every segment
            start_position = np.array((0,0,0))

        ## Used below as basepairs per helical turn (num_turns = num_bp / helical_rise)
        self.helical_rise = 10.44
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_bp,
                         start_position,
                         end_position,
                         segment_model,
                         **kwargs)
        self.num_bp = self.num_nt

        self.local_twist = local_twist
        if num_turns is None:
            num_turns = float(num_bp) / self.helical_rise
        self.twist_per_nt = float(360 * num_turns) / num_bp

        if start_orientation is None:
            start_orientation = np.eye(3) # np.array(((1,0,0),(0,1,0),(0,0,1)))
        self.start_orientation = start_orientation
        self.twist_persistence_length = twist_persistence_length

        self.nicks = []

        self.start = self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )

        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )

        ## Set up linear (k=1) interpolation of position between the start
        ## (u=0) and end (u=1) positions
        a = np.array([self.start_position,self.end_position]).T
        tck, u = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = (tck,u)

        ## TODO: initialize sensible spline for orientation

945
    ## Convenience methods
946
    ## TODO: add errors if unrealistic connections are made
947
    ## TODO: make connections automatically between unconnected strands
948
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect this segment's ``start5`` location to ``end3``.

        If a SingleStrandedSegment is passed, its ``end3`` location is used.
        ``type_`` and ``force_connection`` are forwarded to ``_connect_ends``.
        """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )
    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect this segment's ``start3`` location to ``end5``.

        If a SingleStrandedSegment is passed, its ``start5`` location is used.
        ``type_`` and ``force_connection`` are forwarded to ``_connect_ends``.
        """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.start3, end5, type_, force_connection = force_connection )
    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
957
        if isinstance(end5, SingleStrandedSegment):
958
            end5 = end5.start5
959
960
        self._connect_ends( self.end3, end5, type_,