segmentmodel.py 80.3 KB
Newer Older
cmaffeo2's avatar
cmaffeo2 committed
1
import pdb
2
import numpy as np
3
import random
4
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
5
from coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
6
7
8
9
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme

cmaffeo2's avatar
cmaffeo2 committed
10
11
from scipy.special import erf
import scipy.optimize as opt
12
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
13

cmaffeo2's avatar
cmaffeo2 committed
14
15
16
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC

17
# import pdb
18
"""
TODO:
 + fix handling of crossovers for atomic representation
 + map to atomic representation
    + add nicks
    - transform ssDNA nucleotides
    - shrink ssDNA
    - shrink dsDNA backbone
    + make orientation continuous
    - sequence
    - handle circular dna
 + ensure crossover bead potentials aren't applied twice
 + remove performance bottlenecks
 - test for large systems
 - assign sequence
 - ENM
 - rework Location class
 - remove recursive calls
 - document
 - add unit test of helices connected to themselves
"""

cmaffeo2's avatar
cmaffeo2 committed
40
41
42
class ParticleNotConnectedError(Exception):
    """ Error for a particle that is not connected (presumably to a requested segment) """

43
44
class Location():
    """ Site for connection within an object

    A Location stores the object it belongs to (`container`), its position
    along that object (`address`), the strand it sits on (`on_fwd_strand`),
    a free-form `type_` label and, once connected, the Connection that
    links it to another Location.
    """
    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segment
        # assert( type_ in ("end3","end5") ) # TODO remove or make conditional
        self.on_fwd_strand = on_fwd_strand
        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the Location at the other end of our connection, or None if unconnected """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Attach `connection` and record whether we are on its 3' side """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def get_nt_pos(self):
        """ Return the nucleotide index of this location within its container

        The container's contour_to_nt_pos(..., round_nt=True) is tried
        first.  If it fails and the address is exactly 0 or 1, the first
        (0) or last (num_nts-1) nucleotide is returned instead; any other
        failure is re-raised.
        """
        try:
            pos = self.container.contour_to_nt_pos(self.address, round_nt=True)
        except Exception:
            # Only ordinary errors trigger the end-of-segment fallback;
            # KeyboardInterrupt/SystemExit must propagate untouched.
            if self.address == 0:
                pos = 0
            elif self.address == 1:
                pos = self.container.num_nts-1
            else:
                raise
        return pos

    def __repr__(self):
        # on_fwd_strand is rendered as an integer flag (1 = forward strand)
        return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, self.on_fwd_strand)
        
90
91
92
93
94
95
96
97
98
class Connection():
    """ Abstract base class for connection between two elements """
    def __init__(self, A, B, type_ = None):
        # Both endpoints must be Location objects
        for endpoint in (A, B):
            assert( isinstance(endpoint,Location) )
        self.A, self.B = A, B
        self.type_ = type_

    def other(self, location):
        """ Given one endpoint (compared by identity), return the opposite endpoint """
        if location is self.A:
            return self.B
        if location is self.B:
            return self.A
        raise Exception("OutOfBoundsError")

    def __repr__(self):
        template = "<Connection {}--{}--{}]>"
        return template.format( self.A, self.type_, self.B )
        

111
        
112
113
114
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class """
115
116
117
118
    ## TODO: eliminate mutable default arguments
    def __init__(self, connection_locations=[], connections=[]):
        ## TODO decide on names
        self.locations = self.connection_locations = connection_locations
119
120
        self.connections = connections

121
122
123
124
125
126
127
128
129
    def get_locations(self, type_=None, exclude=[]):
        locs = [l for l in self.connection_locations if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        counter = dict()
        for l in locs:
            if l in counter:
                counter[l] += 1
            else:
                counter[l] = 1
        assert( np.all( [counter[l] == 1 for l in locs] ) )
130
131
132
133
134
135
136
137
138
139
140
141
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
        loc = None
        if (self.num_nts == 1):
            # import pdb
            # pdb.set_trace()
            ## Assumes that intrahelical connections have been made before crossovers
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
cmaffeo2's avatar
cmaffeo2 committed
142
            # assert( loc is not None )
143
144
145
146
147
148
149
150
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc
151
152

    def get_connections_and_locations(self, connection_type=None, exclude=[]):
153
154
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other """
155
        type_ = connection_type
156
157
        ret = []
        for c in self.connections:
158
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
159
                if   c.A.container is self:
160
                    ret.append( [c, c.A, c.B] )
161
                elif c.B.container is self:
162
163
                    ret.append( [c, c.B, c.A] )
                else:
164
165
                    import pdb
                    pdb.set_trace()
166
167
168
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

169
    def _connect(self, other, connection, in_3prime_direction=None):
170
171
        ## TODO fix circular references        
        A,B = [connection.A, connection.B]
172
173
174
175
        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction
            
176
        A.connection = B.connection = connection
177
178
        self.connections.append(connection)
        other.connections.append(connection)
179
180
181
182
183
184
        l = A.container.locations
        if A not in l: l.append(A)
        l = B.container.locations
        if B not in l: l.append(B)
        

185
186
    # def _find_connections(self, loc):
    #     return [c for c in self.connections if c.A == loc or c.B == loc]
187
188
189

class SegmentParticle(PointParticle):
    def __init__(self, type_, position, name="A", segname="A", **kwargs):
        """ Create a bead with no contour position, neighbors or locations assigned yet """
        self.name = name
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, segname=segname, **kwargs)
        # Neighbor/location bookkeeping starts out empty
        self.intrahelical_neighbors = []
        self.other_neighbors = []
        self.locations = []
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210

    def get_intrahelical_above(self):
        """ Returns bead directly above self """
        neighbors = self.intrahelical_neighbors
        assert( len(neighbors) <= 2 )
        # First neighbor whose contour position (in our parent's frame) exceeds ours; None if there is none
        return next( (nbr for nbr in neighbors
                      if nbr.get_contour_position(self.parent) > self.contour_position),
                     None )

    def get_intrahelical_below(self):
        """ Returns bead directly below self """
        neighbors = self.intrahelical_neighbors
        assert( len(neighbors) <= 2 )
        # First neighbor whose contour position (in our parent's frame) is smaller than ours; None if there is none
        return next( (nbr for nbr in neighbors
                      if nbr.get_contour_position(self.parent) < self.contour_position),
                     None )

cmaffeo2's avatar
cmaffeo2 committed
211
212
213
214
215
    def _neighbor_should_be_added(self,b):
        c1 = self.contour_position
        c2 = b.get_contour_position(self.parent)
        if c2 < c1:
            b0 = self.get_intrahelical_below()
216
        else:
cmaffeo2's avatar
cmaffeo2 committed
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
            b0 = self.get_intrahelical_above()

        if b0 is not None:            
            c0 = b0.get_contour_position(self.parent)
            if np.abs(c2-c1) < np.abs(c0-c1):
                ## remove b0
                self.intrahelical_neighbors.remove(b0)
                b0.intrahelical_neighbors.remove(self)
                return True
            else:
                return False
        return True
        
    def make_intrahelical_neighbor(self,b):
        """ Link self and `b` as intrahelical neighbors if both sides accept the link """
        # Both checks always run: each may unlink a farther neighbor as a side effect
        self_accepts = self._neighbor_should_be_added(b)
        other_accepts = b._neighbor_should_be_added(self)
        if not (self_accepts and other_accepts):
            return
        assert(len(b.intrahelical_neighbors) <= 1)
        assert(len(self.intrahelical_neighbors) <= 1)
        self.intrahelical_neighbors.append(b)
        b.intrahelical_neighbors.append(self)
238

cmaffeo2's avatar
cmaffeo2 committed
239
240
241
242
243
244
        
    # def get_nt_position(self,seg):
    #     if seg == self.parent:
    #         return seg.contour_to_nt_pos(self.contour_position)
    #     else:
    #         cl = [e for e in self.parent.get_connections_and_locations() if e[2].container is seg]
245

cmaffeo2's avatar
cmaffeo2 committed
246
247
248
249
250
    #         dc = [(self.contour_position - A.address)**2 for c,A,B in cl]

    #         if len(dc) == 0:
    #             import pdb
    #             pdb.set_trace()
251

cmaffeo2's avatar
cmaffeo2 committed
252
253
254
255
256
257
258
259
260
261
262
    #         i = np.argmin(dc)
    #         c,A,B = cl[i]
    #         ## TODO: generalize, removing np.abs and conditional 
    #         delta_nt = np.abs( A.container.contour_to_nt_pos(self.contour_position - A.address) )
    #         B_nt_pos = seg.contour_to_nt_pos(B.address)
    #         if B.address < 0.5:
    #             return B_nt_pos-delta_nt
    #         else:
    #             return B_nt_pos+delta_nt

    def get_nt_position(self,seg):
        """ Return this bead's (possibly fractional) nucleotide position expressed in `seg`

        If `seg` is our parent, the parent's contour-to-nucleotide mapping
        is used directly.  Otherwise the position is extrapolated across
        the connection between parent and `seg` whose parent-side address
        is nearest to the bead.  If there is no direct connection, one hop
        through the container of each of this bead's locations is tried and
        the first successful result is returned.

        Raises AssertionError when no route to `seg` is found.
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)

        def get_nt_pos(contour1, seg1, seg2):
            """ Map contour position `contour1` in seg1 onto a nucleotide position in seg2, or None if they share no connection """
            cl = [e for e in seg1.get_connections_and_locations() if e[2].container is seg2]
            dc = [(contour1 - A.address)**2 for c,A,B in cl]
            if len(dc) == 0: return None

            # Use the connection whose seg1-side address is closest to contour1
            i = np.argmin(dc)
            c,A,B = cl[i]

            ## TODO: generalize, removing np.abs and conditional
            delta_nt = np.abs( seg1.contour_to_nt_pos(contour1 - A.address) )
            B_nt_pos = seg2.contour_to_nt_pos(B.address)
            # Extrapolate away from seg2: below its start or beyond its end
            if B.address < 0.5:
                return B_nt_pos-delta_nt
            else:
                return B_nt_pos+delta_nt

        pos = get_nt_pos(self.contour_position, self.parent, seg)
        if pos is None:
            ## Particle is not directly connected
            positions = []
            for l in self.locations:
                if l.container == self.parent: continue
                pos0 = get_nt_pos(self.contour_position, self.parent, l.container)
                assert(pos0 is not None)
                pos0 = l.container.nt_pos_to_contour(pos0)
                pos = get_nt_pos( pos0, l.container, seg )
                if pos is not None:
                    positions.append( pos )
            assert( len(positions) > 0 )
            # When several routes exist the first one found is used
            pos = positions[0]
        return pos
302
303


304
305
306
307
    def get_contour_position(self,seg):
        """ Return this bead's contour position expressed in segment `seg` """
        if seg == self.parent:
            return self.contour_position
        # Foreign segment: go through its nucleotide coordinate
        return seg.nt_pos_to_contour( self.get_nt_position(seg) )
310
311

## TODO break this class into smaller, better encapsulated pieces
312
313
314
315
316
317
318
319
320
321
322
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    """Define basic particle types"""
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,                 
                              )
cmaffeo2's avatar
cmaffeo2 committed
323
324
325
326
327
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )
328

cmaffeo2's avatar
cmaffeo2 committed
329
    # orientation_bond = HarmonicBond(10,2)
330
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )
331
332
333
334
335
336
337
338
339
340
341
342
343

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,                 
                              )

    def __init__(self, name, num_nts, 
                 start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None):
        """ Set up an empty segment of `num_nts` nucleotides

        When `end_position` is None the segment is laid out along +z from
        `start_position` with length distance_per_nt*num_nts; this reads
        self.distance_per_nt, which is not assigned here and therefore has
        to exist before this initializer runs.
        """
        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        self.resname = name
        self.start_orientation = None
        self.twist_per_nt = 0

        self.beads = list(self.children) # self.beads will not contain orientation beads

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = int(num_nts)

        if end_position is None:
            z_extent = self.distance_per_nt*num_nts
            end_position = np.array((0,0,z_extent)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Set up interpolation for positions (linear spline through the two end points)
        endpoints = np.array([self.start_position,self.end_position]).T
        spline_params, _ = interpolate.splprep( endpoints, u=[0,1], s=0, k=1)
        self.position_spline_params = spline_params

        self.sequence = None
367

368
369
370
371
372
    def clear_all(self):
        """ Clear the Group contents, forget all beads and detach particles from connected locations """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        # Only locations that take part in a connection are reset here
        for _, own_location, _ in self.get_connections_and_locations():
            own_location.particle = None
373

374
    def contour_to_nt_pos(self, contour_pos, round_nt=False):
        """ Convert a contour position to a nucleotide index

        Inverse of nt_pos_to_contour: nt = contour_pos*num_nts - 0.5.  With
        round_nt=True the result must already be (close to) an integer and
        is returned rounded; otherwise an AssertionError is raised.
        """
        nt = contour_pos*(self.num_nts) - 0.5
        if not round_nt:
            return nt
        assert( np.isclose(np.around(nt),nt) )
        return np.around(nt)

381
    def nt_pos_to_contour(self,nt_pos):
        """ Convert a nucleotide index to a contour position: (nt_pos + 0.5) / num_nts """
        shifted = nt_pos+0.5
        return shifted/(self.num_nts)
383

384
385
386
387
388
389
390
    def contour_to_position(self,s):
        """ Evaluate the position spline at contour parameter(s) `s` """
        coords = interpolate.splev( s, self.position_spline_params )
        # splev yields one entry per spline dimension; stack and transpose when there are several
        return np.array(coords).T if len(coords) > 1 else coords

    def contour_to_tangent(self,s):
        """ Return the normalized first derivative of the position spline at `s` """
        derivative = interpolate.splev( s, self.position_spline_params, der=1 )
        length = np.linalg.norm(derivative,axis=0)
        unit = derivative / length
        return unit.T
393
394
395
        

    def contour_to_orientation(self,s):
        """ Return the orientation matrix at contour position `s`, or None

        None is returned when the segment has no start_orientation.  When
        no quaternion spline is available, the orientation is a rotation
        about the local tangent by twist_per_nt times the nucleotide
        position; otherwise the quaternion spline is evaluated at `s` and
        converted to a matrix.

        Only scalar `s` (or a length-1 sequence) is supported.
        """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version
        if self.start_orientation is None:
            return None

        # The attribute is not assigned in Segment.__init__, so treat a
        # missing spline like an unset one instead of raising AttributeError.
        quaternion_params = getattr(self, "quaternion_spline_params", None)
        if quaternion_params is None:
            # axis = self.start_orientation.dot( np.array((0,0,1)) )
            axis = self.contour_to_tangent(s)
            return rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=True )

        q = interpolate.splev( s, quaternion_params )
        if len(q) > 1: q = np.array(q).T # TODO: is this needed?
        return quaternion_to_matrix(q)
408

cmaffeo2's avatar
cmaffeo2 committed
409
    def get_contour_sorted_connections_and_locations(self,type_=None):
        """ Return get_connections_and_locations(type_) sorted by the address of the location in self

        `type_` defaults to None (all connection types), so the method can
        also be called without arguments.
        """
        entries = self.get_connections_and_locations(type_)
        # entry = [connection, location_in_self, location_in_other]
        return sorted(entries, key=lambda entry: entry[1].address)
413
414
415
    
    def randomize_unset_sequence(self):
        """ Fill every unset (None) sequence entry with a base drawn at random from the keys of seqComplement """
        bases = list(seqComplement.keys())
        # bases = ['T']        ## FOR DEBUG
        if self.sequence is None:
            # No sequence at all: draw one base per nucleotide
            self.sequence = [random.choice(bases) for _ in range(self.num_nts)]
            return

        assert(len(self.sequence) == self.num_nts) # TODO move
        for idx, base in enumerate(self.sequence):
            if base is None:
                self.sequence[idx] = random.choice(bases)
424

cmaffeo2's avatar
cmaffeo2 committed
425
426
427
    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        """ Abstract hook; always raises NotImplementedError in this base class """
        # NOTE(review): _generate_beads calls this hook with three arguments
        # (a contour span first) — confirm the intended signature against subclasses.
        raise NotImplementedError

428
    def _generate_one_bead(self, contour_position, nts):
429
430
        raise NotImplementedError

431
    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale):
        """ Seq should include modifications like 5T, T3 Tsinglet; direction matters too

        Builds the atoms of one nucleotide from a canonical template,
        orients them at `contour_position` and returns them.  A `scale`
        other than None/1 rescales atom positions and zeroes their beta
        field: every atom for single-stranded segments, otherwise only atoms
        whose name ends in ', P or T (about N9 for A/G, N1 otherwise).

        Raises NotImplementedError when self.local_twist is false.
        """
        # print("Generating nucleotide at {}".format(contour_position))

        pos = self.contour_to_position(contour_position)
        if not self.local_twist:
            raise NotImplementedError

        orientation = self.contour_to_orientation(contour_position)
        ## TODO: move this code (?)
        if orientation is None:
            # No stored orientation: build a right-handed frame around the tangent
            axis = self.contour_to_tangent(contour_position)
            angleVec = np.array([1,0,0])
            if axis.dot(angleVec) > 0.9: angleVec = np.array([0,1,0])
            angleVec = angleVec - angleVec.dot(axis)*axis
            angleVec = angleVec/np.linalg.norm(angleVec)
            y = np.cross(axis,angleVec)
            orientation = np.array([angleVec,y,axis]).T
            ## TODO: improve placement of ssDNA
            # rot = rotationAboutAxis( axis, contour_position*self.twist_per_nt*self.num_nts, normalizeAxis=True )
            # orientation = rot.dot(orientation)

        # key = self.sequence
        # if self.ntAt5prime is None and self.ntAt3prime is not None: key = "5"+key
        # if self.ntAt5prime is not None and self.ntAt3prime is None: key = key+"3"
        # if self.ntAt5prime is None and self.ntAt3prime is None: key = key+"singlet"

        # NOTE(review): forward strands use the *Rev* templates and vice
        # versa; this looks inverted — confirm before relying on it.
        nt_dict = canonicalNtRev if is_fwd else canonicalNtFwd
        atoms = nt_dict[ seq ].generate() # TODO: clone?

        atoms.orientation = orientation.dot(atoms.orientation)

        rescale = scale is not None and scale != 1
        if isinstance(self, SingleStrandedSegment):
            if rescale:
                for a in atoms:
                    a.position = scale*a.position
                    a.beta = 0
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsedPosition()
            return atoms

        if rescale:
            pivot_name = "N9" if atoms.sequence in ("A","G") else "N1"
            r0 = atoms.atoms_by_name[pivot_name].position
            for a in atoms:
                if a.name[-1] in ("'","P","T"):
                    a.position = scale*(a.position-r0) + r0
                    a.beta = 0
        atoms.position = pos
        return atoms
489

490
491
    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Create a Location at nucleotide `nt` and append it to self.locations """
        contour = self.nt_pos_to_contour(nt)
        assert( 0 <= contour <= 1 )
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        new_location = Location( self, address=contour, type_=type_, on_fwd_strand=on_fwd_strand )
        self.locations.append(new_location)

    ## TODO? Replace with abstract strand-based model?
    def add_5prime(self, nt, on_fwd_strand=True):
        """ Add a "5prime" location at nucleotide `nt`; single-stranded segments always use the forward strand """
        strand_is_fwd = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt,"5prime",strand_is_fwd)
503
504

    def add_3prime(self, nt, on_fwd_strand=True):
        """ Add a "3prime" location at nucleotide `nt`; single-stranded segments always use the forward strand """
        strand_is_fwd = True if isinstance(self,SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt,"3prime",strand_is_fwd)
508

509
510
511
    def get_3prime_locations(self):
        """ Return all locations of type "3prime" """
        return self.get_locations(type_="3prime")
    
cmaffeo2's avatar
cmaffeo2 committed
512
    def get_5prime_locations(self):
        """ Return all locations of type "5prime" """
        ## TODO? ensure that data is consistent before _build_model calls
        return self.get_locations(type_="5prime")
cmaffeo2's avatar
cmaffeo2 committed
515

516
    def iterate_connections_and_locations(self, reverse=False):
        """ Yield [connection, location_in_self, location_in_other] entries in contour order

        Entries of every connection type are produced, ordered by the
        address of the location in self; `reverse=True` walks them from
        the highest address down.
        """
        ## connections to other segments
        # The sorter's type_ argument is passed explicitly (None = all types)
        cl = self.get_contour_sorted_connections_and_locations(None)
        if reverse:
            cl = cl[::-1]
        yield from cl
cmaffeo2's avatar
cmaffeo2 committed
524

525
    ## TODO rename
526
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
        """ Walks through locations, checking for crossovers

        Starting from nucleotide `nt_pos` and moving in the strand's
        direction (increasing nt when `is_fwd`, decreasing otherwise),
        return a 5-tuple describing where this stretch of strand stops:

          (pos, None, None, None, None)
              a "3prime" location on our strand was reached at `pos`
          (pos, other_container, other_nt_pos, other_on_fwd_strand, 0.5)
              a connection on our strand leads into another location
          (pos, own_container, pos +/- 1, is_fwd, 0)
              paused at a crossover on the opposite strand

        NOTE: `move_at_least` is currently ignored (forced to 0 below).
        NOTE(review): the last tuple element looks like the move_at_least
        for the follow-up call — confirm against callers.

        Raises Exception if no stopping point is found.
        """
        # if self.name in ("6-1","1-1"):
        #     import pdb
        #     pdb.set_trace()
        move_at_least = 0

        ## Iterate through locations
        def loc_rank(l):
            nt = l.get_nt_pos()
            ## optionally add logic about type of connection
            return (nt, not l.on_fwd_strand)
        # locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
        locations = sorted(self.locations, key=loc_rank, reverse=(not is_fwd))
        # print(locations)

        for l in locations:
            # TODOTODO probably okay
            if l.address == 0:
                pos = 0.0
            elif l.address == 1:
                pos = self.num_nts-1
            else:
                pos = self.contour_to_nt_pos(l.address, round_nt=True)

            ## Skip locations encountered before our strand
            # tol = 0.1
            # if is_fwd:
            #     if pos-nt_pos <= tol: continue
            # elif   nt_pos-pos <= tol: continue
            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
            ## TODO: remove move_at_least
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                print("  found end at",l)
                return pos, None, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                print("  passing through",l)
                print("from {}, connection {} to {}".format(nt_pos,l,B))
                Bpos = B.get_nt_pos()
                return pos, B.container, Bpos, B.on_fwd_strand, 0.5

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                if nt_pos == pos: continue
                print("  pausing at",l)
                return pos, l.container, pos+(2*is_fwd-1), is_fwd, 0

        ## Made it to the end of the segment without finding a connection
        raise Exception("Shouldn't be here")

595
596
597
    def get_nearest_bead(self, contour_position):
        """ Return the bead in self.beads whose contour position is closest to `contour_position` (None when there are no beads) """
        if len(self.beads) < 1:
            return None
        bead_contours = np.array([bead.contour_position for bead in self.beads]) # TODO: cache
        # TODO: include beads in connections?
        squared_distance = (bead_contours - contour_position)**2
        return self.beads[np.argmin(squared_distance)]
602
603
604

    def get_all_consecutive_beads(self, number):
        """ Return every run of `number` consecutive beads from self.beads, as a list of lists """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        last_start = len(self.beads)-number
        return [ [self.beads[start+offset] for offset in range(number)]
                 for start in range(last_start+1) ]
611

612
613
614
    def _add_bead(self,b,set_contour=False):
        if set_contour:
            b.contour_position = b.get_contour_position(self)
615
        
616
617
618
        # assert(b.parent is None)
        if b.parent is not None:
            b.parent.children.remove(b)
619
        self.add(b)
620
621
622
623
624
625
        self.beads.append(b) # don't add orientation bead
        if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
            o = b.orientation_bead
            o.contour_position = b.contour_position
            if o.parent is not None:
                o.parent.children.remove(o)
626
            self.add(o)
627
628
629
630
631
632
633
634
            self.add_bond(b,o, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        # print("_rebuild_children on %s" % self.name)
        old_children = self.children
        old_beads = self.beads
        self.children = []
        self.beads = []
635
636
637

        if True:
            print("WARNING: DEBUG")
638
            ## Remove duplicates, preserving order
639
640
641
642
            tmp = []
            for c in new_children:
                if c not in tmp:
                    tmp.append(c)
643
644
                else:
                    print("  duplicate particle found!")
645
646
            new_children = tmp

647
648
649
650
651
        for b in new_children:
            self.beads.append(b)
            self.children.append(b)
            if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
                self.children.append(b.orientation_bead)
652
653
654
655
656
            
        # tmp = [c for c in self.children if c not in old_children]
        # assert(len(tmp) == 0)
        # tmp = [c for c in old_children if c not in self.children]
        # assert(len(tmp) == 0)
657
658
        assert(len(old_children) == len(self.children))
        assert(len(old_beads) == len(self.beads))
659

660

cmaffeo2's avatar
cmaffeo2 committed
661
    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):
662

663
        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions """
cmaffeo2's avatar
cmaffeo2 committed
664
        ## TODO: decide whether to remove bead_model argument
665
        ##       (currently unused)
cmaffeo2's avatar
cmaffeo2 committed
666

667
        ## First find points between-which beads must be generated
668
669
670
671
672
673
674
675
        # conn_locs = self.get_contour_sorted_connections_and_locations()
        # locs = [A for c,A,B in conn_locs]
        # existing_beads = [l.particle for l in locs if l.particle is not None]
        existing_beads = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( list(existing_beads), key=lambda b: b.get_contour_position(self) )
        
        if len(existing_beads) != len(set(existing_beads)):
            pdb.set_trace()
676
677
678
        for b in existing_beads:
            assert(b.parent is not None)

cmaffeo2's avatar
cmaffeo2 committed
679
680
681
682
        # if self.name == "1-1":
        #     import pdb
        #     pdb.set_trace()

683
        ## Add ends if they don't exist yet
684
        ## TODOTODO: test 1 nt segments?
cmaffeo2's avatar
cmaffeo2 committed
685
686
687
        if len(existing_beads) == 0 or existing_beads[0].get_nt_position(self) > 0.5:
            # if len(existing_beads) > 0:            
            #     assert(existing_beads[0].get_nt_position(self) >= 0.5)
688
689
            b = self._generate_one_bead(0, 0)
            existing_beads = [b] + existing_beads
cmaffeo2's avatar
cmaffeo2 committed
690
691

        if existing_beads[-1].get_nt_position(self)-(self.num_nts-1) < -0.5:
692
693
694
695
696
697
            b = self._generate_one_bead(1, 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
698
        last = None
699
700
        for I in range(len(existing_beads)-1):
            eb1,eb2 = [existing_beads[i] for i in (I,I+1)]
701
            assert( eb1 is not eb2 )
702

cmaffeo2's avatar
cmaffeo2 committed
703
704
705
706
707
            # if np.isclose(eb1.position[2], eb2.position[2]):
            #     import pdb
            #     pdb.set_trace()

            print(" %s working on %d to %d" % (self.name, eb1.position[2], eb2.position[2]))
708
709
710
711
712
713
714
715
716
717
718
719
720
            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )
            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nts
            eb1.num_nts += 0.5*nts
            eb2.num_nts += 0.5*nts

            ## Add beads
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
cmaffeo2's avatar
cmaffeo2 committed
721
                last.make_intrahelical_neighbor(eb1)
722
723
724
725
726
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)

cmaffeo2's avatar
cmaffeo2 committed
727
                last.make_intrahelical_neighbor(b)
728
729
730
                last = b
                tmp_children.append(b)

cmaffeo2's avatar
cmaffeo2 committed
731
        last.make_intrahelical_neighbor(eb2)
732
733
734
735

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)
736
737
738
739
740
741
742
743
744
745
746
747
748

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None,
                 twist_persistence_length = 90 ):
        """ Set up a duplex segment.

        name, num_nts, start_position, end_position and segment_model
        are handed unchanged to Segment.__init__.  When num_turns is
        None it is taken as num_nts / helical_rise; when
        start_orientation is None the 3x3 identity matrix is used. """

        ## Assigned ahead of Segment.__init__, as the original ordering does
        ## (presumably the base initializer reads them -- TODO confirm)
        self.helical_rise = 10.44
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_nts, start_position, end_position, segment_model)

        self.local_twist = local_twist

        ## 360 degrees per turn, spread evenly over the nts
        turns = float(num_nts) / self.helical_rise if num_turns is None else num_turns
        self.twist_per_nt = float(360 * turns) / num_nts

        self.start_orientation = np.eye(3) if start_orientation is None else start_orientation
        self.twist_persistence_length = twist_persistence_length
        self.nicks = []

        ## One Location per strand at each contour extreme (address 0 and 1);
        ## start/end are aliases for the forward-strand sites
        fwd_start = Location( self, address=0, type_= "end5" )
        self.start = fwd_start
        self.start5 = fwd_start
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )

        fwd_end = Location( self, address=1, type_ = "end3" )
        self.end = fwd_end
        self.end3 = fwd_end
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )

        ## Degree-1 (k=1) interpolating spline through the two end positions
        endpoints = np.array([self.start_position,self.end_position]).T
        self.position_spline_params, _ = interpolate.splprep( endpoints, u=[0,1], s=0, k=1)

        ## TODO: initialize sensible spline for orientation
        self.quaternion_spline_params = None

    ## Convenience methods
    ## TODO: add errors if unrealistic connections are made
    ## TODO: make connections automatically between unconnected strands
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect self.start5 to end3; a SingleStrandedSegment is replaced by its end3 """
        partner = end3.end3 if isinstance(end3, SingleStrandedSegment) else end3
        self._connect_ends( self.start5, partner, type_, force_connection = force_connection )

    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect self.start3 to end5; a SingleStrandedSegment is replaced by its start5 """
        partner = end5.start5 if isinstance(end5, SingleStrandedSegment) else end5
        self._connect_ends( self.start3, partner, type_, force_connection = force_connection )

    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect self.end3 to end5; a SingleStrandedSegment is replaced by its start5 """
        partner = end5.start5 if isinstance(end5, SingleStrandedSegment) else end5
        self._connect_ends( self.end3, partner, type_, force_connection = force_connection )

    def connect_end5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect self.end5 to end3; a SingleStrandedSegment is replaced by its end3 """
        partner = end3.end3 if isinstance(end3, SingleStrandedSegment) else end3
        self._connect_ends( self.end5, partner, type_, force_connection = force_connection )

    def add_crossover(self, nt, other, other_nt, strands_fwd=[True,False], nt_on_5prime=True):
        """ Add a crossover between two helices

        nt / other_nt are nt positions on self / other; strands_fwd
        selects the strand used on (self, other).  When other is a
        SingleStrandedSegment the call is forwarded to it with the
        roles swapped. """
        ## Validate other, nt, other_nt
        ##   TODO

        if isinstance(other,SingleStrandedSegment):
            other.add_crossover(other_nt, self, nt, strands_fwd[::-1], not nt_on_5prime)
            return

        ## Create locations, connections and add to segments
        my_contour = self.nt_pos_to_contour(nt)
        assert(my_contour >= 0 and my_contour <= 1)
        loc = self.get_location_at(my_contour, strands_fwd[0])

        # TODOTODO: may need to subtract or add a little depending on 3prime/5prime
        other_contour = other.nt_pos_to_contour(other_nt)
        assert(other_contour >= 0 and other_contour <= 1)
        other_loc = other.get_location_at(other_contour, strands_fwd[1])

        self._connect(other, Connection( loc, other_loc, type_="crossover" ))

        ## Exactly one of the two sites is flagged as the 3prime side
        loc.is_3prime_side_of_connection = not nt_on_5prime
        other_loc.is_3prime_side_of_connection = bool(nt_on_5prime)

    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection):
        """ Connect two end Locations of opposite type ("end3" / "end5").

        The Connection is always created with the "end3" site first and
        registered on that site's container.  force_connection is
        accepted but not used here. """
        ## TODO remove self?
        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )

        ## Create and add connection
        three_prime, five_prime = (end1, end2) if end2.type_ == "end5" else (end2, end1)
        three_prime.container._connect( five_prime.container, Connection( three_prime, five_prime, type_=type_ ), in_3prime_direction=True )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a stretch spanning the contour fraction
        `contour`; max_nucleotides_per_bead is ignored here """
        basepairs = contour*self.num_nts
        return int(basepairs // max_basepairs_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create a dsDNA bead at contour_position representing nts,
        add it through self._add_bead and return it.

        With local_twist set, an orientation bead is created too,
        displaced by Segment.orientation_bond.r0 along the local x axis. """
        pos = self.contour_to_position(contour_position)

        if not self.local_twist:
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    contour_position=contour_position )
        else:
            orientation = self.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)

            offset = orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )
            opos = pos + offset

            ## NOTE(review): nts is also passed positionally as the third
            ## argument; confirm that matches SegmentParticle's signature
            o = SegmentParticle( Segment.orientation_particle, opos, nts,
                                 num_nts=nts, parent=self )
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self, 
                                    orientation_bead=o,
                                    contour_position=contour_position )

        self._add_bead(bead)
        return bead
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889

class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None):
        """ Set up a single-stranded segment.

        All arguments are handed unchanged to Segment.__init__.  Two
        end sites are created: start/start5 at address 0 and end/end3
        at address 1. """

        ## Assigned ahead of Segment.__init__, as the original ordering does
        ## (presumably the base initializer reads it -- TODO confirm)
        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts, 
                         start_position,
                         end_position, 
                         segment_model)

        self.start = self.start5 = Location( self, address=0, type_= "end5" ) # TODO change type_?
        self.end = self.end3 = Location( self, address=1, type_ = "end3" )

    def connect_end3(self, end5, force_connection=False):
        """ Connect this strand's end3 to end5.

        end5 may be a Location or, like the DoubleStrandedSegment
        connect_* helpers allow, a SingleStrandedSegment (its start5 is
        used). """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_end( end5,  _5_to_3 = True, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False): # TODO: change name or possibly deprecate
        """ Connect this strand's start5 to end3.

        end3 may be a Location or a SingleStrandedSegment (its end3 is
        used). """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_end( end3,  _5_to_3 = False, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Create an "intrahelical" Connection between one of this
        strand's ends and the Location `other`.

        _5_to_3 True:  self.end3 -> other, registered on self.
        _5_to_3 False: other -> self.start5, registered on other's container.
        A mismatched site type only prints a warning.  force_connection
        is accepted but not used here. """
        assert( isinstance(other, Location) )
        if _5_to_3:
            my_end = self.end3
            # assert( other.type_ == "end5" )
            ## compare by value: identity tests on str literals are unreliable
            if other.type_ != "end5":
                print("Warning: code does not prevent connecting 3prime to 3prime, etc")
            conn = Connection( my_end, other, type_="intrahelical" )
            self._connect( other.container, conn, in_3prime_direction=True )
        else:
            ## this class defines no end5 attribute; its 5prime site is start5
            my_end = self.start5
            # assert( other.type_ == "end3" )
            if other.type_ != "end3":
                print("Warning: code does not prevent connecting 3prime to 3prime, etc")
            conn = Connection( other, my_end, type_="intrahelical" )
            other.container._connect( self, conn, in_3prime_direction=True )

    def add_crossover(self, nt, other, other_nt, strands_fwd=[True,False], nt_on_5prime=True):
        """ Add a crossover between two helices

        nt must be the first (0) or last (num_nts-1) nt of this
        segment, otherwise an Exception is raised.  If other is also a
        SingleStrandedSegment the two strands are joined end to end;
        otherwise an "sscrossover" Connection is made to the strand of
        other selected by strands_fwd[1]. """
        ## Validate other, nt, other_nt
        ##   TODO

        ## TODO: fix direction

        ## Ensure connections occur at ends, otherwise the structure doesn't make sense
        if nt == 0:
            c1 = 0
        elif nt == self.num_nts-1:
            c1 = 1
        else:
            raise Exception("Crossovers can only be at the ends of an ssDNA segment")

        loc = self.get_location_at(c1, True)

        if other_nt == 0:
            c2 = 0
        elif other_nt == other.num_nts-1:
            c2 = 1
        else:
            c2 = other.nt_pos_to_contour(other_nt)

        if isinstance(other,SingleStrandedSegment):
            ## Ensure connections occur at opposing ends; other_nt indexes
            ## `other`, so it is checked against other's length
            assert(np.isclose(other_nt,0) or np.isclose(other_nt,other.num_nts-1))
            other_loc = other.get_location_at( c2, True )
            if nt_on_5prime:
                self.connect_end3( other_loc )
            else:
                ## mirror of the branch above: hand over a Location, which
                ## is what _connect_end asserts it receives
                other.connect_end3( loc )

        else:
            assert(c2 >= 0 and c2 <= 1)
            other_loc = other.get_location_at( c2, strands_fwd[1] )
            if nt_on_5prime:
                self._connect(other, Connection( loc, other_loc, type_="sscrossover" ), in_3prime_direction=True )
            else:
                other._connect(self, Connection( other_loc, loc, type_="sscrossover" ), in_3prime_direction=True )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a stretch spanning the contour fraction
        `contour`; max_basepairs_per_bead is ignored here """
        return int(contour*self.num_nts // max_nucleotides_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create an ssDNA bead at contour_position representing nts,
        add it through self._add_bead and return it """
        pos = self.contour_to_position(contour_position)
        b = SegmentParticle( Segment.ssDNA_particle, pos, 
                             name="NAS",
                             num_nts=nts, parent=self,
                             contour_position=contour_position )
        self._add_bead(b)
        return b
975
976

    
cmaffeo2's avatar
cmaffeo2 committed
977
978
979
980
class StrandInSegment(Group):
    """ Class that holds atomic model, maps to segment """

    def __init__(self, segment, start, end, is_fwd):
        """ start/end should be provided expressed in nt coordinates, is_fwd tuples """
        Group.__init__(self)
        # self.sequence = []
        self.segment, self.start, self.end, self.is_fwd = segment, start, end, is_fwd

        ## Both end nts are included, hence the +1; the span must be integral
        span = np.abs(end-start)+1
        count = int(round(span))
        assert( np.isclose(count,span) )
        self.num_nts = count

        # print(" Creating {}-nt StrandInSegment in {} from {} to {} {}".format(self.num_nts, segment.name, start, end, is_fwd))

    def _nucleotide_ids(self):
        """ range of num_nts nt indices beginning at start, counting up
        on the forward strand and down otherwise """
        first = self.start # seg.contour_to_nt_pos(self.start)
        assert( np.abs(first - round(first)) < 1e-5 )
        first = int(round(first))
        ## a forward strand may not run towards smaller nt positions
        assert( (self.end-self.start) >= 0 or not self.is_fwd )

        step = (2*self.is_fwd-1)    # +1 when is_fwd, -1 otherwise
        stop = first + step*self.num_nts
        return range(first, stop, step)

    def get_sequence(self):
        """ return 5-to-3 """
        # TODOTODO test
        ids = self._nucleotide_ids()
        bases = self.segment.sequence
        if not self.is_fwd:
            ## the segment stores one sequence; the other strand is its complement
            return [seqComplement[bases[i]] for i in ids]
        return [bases[i] for i in ids]

    def get_contour_points(self):
        """ num_nts evenly spaced contour values running from start to end """
        to_contour = self.segment.nt_pos_to_contour
        return np.linspace(to_contour(self.start), to_contour(self.end), self.num_nts)
            
cmaffeo2's avatar
cmaffeo2 committed
1018
1019
class Strand(Group):
    """ Class that holds atomic model, maps to segments """
1020
    def __init__(self, segname = None):
        """ Start an empty strand; segname is stored as given """
        Group.__init__(self)
        ## children and strand_segments are two names for the same list
        pieces = []
        self.children = pieces
        self.strand_segments = pieces
        self.segname = segname
        self.num_nts = 0

1026
    ## TODO disambiguate names of functions
cmaffeo2's avatar
cmaffeo2 committed
1027
    def add_dna(self, segment, start, end, is_fwd):
        """ start/end should be provided expressed as contour_length, is_fwd tuples

        Wraps the given stretch of segment in a StrandInSegment, adds it
        to this strand and increases num_nts by the new piece's num_nts.
        Prints a warning when the stretch converts to 0.9 nts or fewer,
        and prints "  CIRCULAR DNA" when a piece already on this strand
        (same segment and direction) shares an endpoint with the new one.

        NOTE(review): StrandInSegment documents start/end as nt
        coordinates, while the size check here converts their difference
        with contour_to_nt_pos; confirm the intended units. """
        # TODOTODO use nt pos ? 
        span_nts = segment.contour_to_nt_pos(np.abs(start-end))
        if not (span_nts > 0.9):
            print( "WARNING: segment constructed with a very small number of nts ({})".format(span_nts) )

        for existing in self.strand_segments:
            if existing.segment == segment and existing.is_fwd == is_fwd:
                # assert( existing.start not in (start,end) )
                # assert( existing.end not in (start,end) )
                if existing.start in (start,end) or existing.end in (start,end):
                    ## Circular DNA is not handled yet: report it and carry
                    ## on instead of stopping in a debugger, which hangs or
                    ## aborts non-interactive runs
                    print("  CIRCULAR DNA")

        piece = StrandInSegment( segment, start, end, is_fwd )
        self.add( piece )
        self.num_nts += piece.num_nts

New Tbgl User's avatar
New Tbgl User committed
1047
    def set_sequence(self,sequence): # , set_complement=True):
1048
        ## validate input
1049
        assert( len(sequence) >= self.num_nts )
New Tbgl User's avatar
New Tbgl User committed
1050
        assert( np.all( [i in ('A','T','C','G') for i in sequence] ) )