segmentmodel.py 78.9 KB
Newer Older
cmaffeo2's avatar
cmaffeo2 committed
1
import pdb
2
import numpy as np
3
import random
4
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
5
from coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
6
7
8
9
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme

cmaffeo2's avatar
cmaffeo2 committed
10
11
from scipy.special import erf
import scipy.optimize as opt
12
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
13

cmaffeo2's avatar
cmaffeo2 committed
14
15
16
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC

17
# import pdb
18
"""
cmaffeo2's avatar
cmaffeo2 committed
19
TODO:
cmaffeo2's avatar
cmaffeo2 committed
20
 + fix handling of crossovers for atomic representation
cmaffeo2's avatar
cmaffeo2 committed
21
 + map to atomic representation
22
23
    + add nicks
    - transform ssDNA nucleotides 
cmaffeo2's avatar
cmaffeo2 committed
24
25
    - shrink ssDNA
    - shrink dsDNA backbone
26
    + make orientation continuous
cmaffeo2's avatar
cmaffeo2 committed
27
    - sequence
28
    - handle circular dna
29
 + ensure crossover bead potentials aren't applied twice 
30
 + remove performance bottlenecks
31
32
 - test for large systems
 - assign sequence
33
 - ENM
34
35
 - rework Location class 
 - remove recursive calls
36
 - document
37
 - add unit test of helices connected to themselves
38
39
"""

cmaffeo2's avatar
cmaffeo2 committed
40
41
42
class ParticleNotConnectedError(Exception):
    """ Exception type signalling that a particle is not connected.

    NOTE(review): nothing in the visible part of this file raises it;
    confirm the intended use against the callers. """

43
44
class Location():
    """ Site for connection within an object.

    A Location records where on its container a connection site sits
    (``address``), on which strand (``on_fwd_strand``) and of what kind
    (``type_``).  The remaining attributes are bookkeeping slots that start
    out as None and are filled in by other code.
    """

    def __init__(self, container, address, type_, on_fwd_strand = True):
        """ Create an unconnected location.

        container:     object that owns this location (its ``name`` is used by __repr__)
        address:       position along contour length in segment
        type_:         label for the kind of site (e.g. "end3", "end5", "crossover")
        on_fwd_strand: True if the site is on the forward strand
        """
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segment
        # assert( type_ in ("end3","end5") ) # TODO remove or make conditional
        self.on_fwd_strand = on_fwd_strand
        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the location at the other end of this location's
        connection, or None if there is no connection """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Attach a connection and record which side of it this location is on """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def __repr__(self):
        ## The strand flag is printed as an integer (1 = forward, 0 = reverse);
        ## a previously computed-but-unused label string has been dropped.
        return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, self.on_fwd_strand)
        
78
79
80
81
82
83
84
85
86
class Connection():
    """ Abstract base class for connection between two elements.

    Holds the two endpoint Locations ``A`` and ``B`` and an optional
    ``type_`` label. """

    def __init__(self, A, B, type_ = None):
        ## Both endpoints must be Location objects
        for endpoint in (A, B):
            assert isinstance(endpoint, Location)
        self.A, self.B = A, B
        self.type_ = type_

    def other(self, location):
        """ Given one endpoint (compared by identity), return the opposite
        one; raises Exception("OutOfBoundsError") for anything else """
        if location is self.A:
            return self.B
        if location is self.B:
            return self.A
        raise Exception("OutOfBoundsError")

    def __repr__(self):
        template = "<Connection {}--{}--{}]>"
        return template.format(self.A, self.type_, self.B)
        

99
        
100
101
102
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that own connection sites
    (Locations) and the Connections made at them.

    Attributes:
        locations, connection_locations: two names for the same list of Locations
        connections: list of Connection objects that involve this element
    """

    def __init__(self, connection_locations=None, connections=None):
        """ Lists passed in are stored as-is (not copied); when omitted, each
        instance gets its own fresh empty list so that instances never share
        state through a default argument. """
        ## TODO decide on names
        if connection_locations is None:
            connection_locations = []
        if connections is None:
            connections = []
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Return locations whose type_ matches `type_` (any type when None)
        and is not listed in `exclude`.  Asserts that no location is listed
        more than once. """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        ## Every location may appear only once
        assert( len(set(locs)) == len(locs) )
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
        """ Return the existing location matching `address` and strand, or a
        newly created Location of type `new_type` when none matches.

        For elements with num_nts == 1 the address is ignored and the single
        unconnected location on the requested strand is used instead.  A
        newly created Location is returned without being appended to
        self.locations.  Reads self.num_nts, which this class does not set. """
        loc = None
        if (self.num_nts == 1):
            ## Assumes that intrahelical connections have been made before crossovers
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
            # assert( loc is not None )
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other

        Only connections whose type_ matches `connection_type` (any when
        None) and is not in `exclude` are returned.  Raises Exception when a
        stored connection has neither endpoint contained in self. """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection, in_3prime_direction=None):
        """ Register `connection` on both self and `other`.

        When `in_3prime_direction` is given, connection.B is flagged with it
        as its is_3prime_side_of_connection and connection.A with the
        opposite.  Both endpoints are added to their containers' location
        lists if not already present.

        NOTE(review): when `other` is self the connection is appended to the
        same list twice; confirm whether that is intended. """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]
        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction

        A.connection = B.connection = connection
        self.connections.append(connection)
        other.connections.append(connection)

        for endpoint in (A, B):
            l = endpoint.container.locations
            if endpoint not in l: l.append(endpoint)
        

173
174
    # def _find_connections(self, loc):
    #     return [c for c in self.connections if c.A == loc or c.B == loc]
175
176
177

class SegmentParticle(PointParticle):
    def __init__(self, type_, position, name="A", segname="A", **kwargs):
        """ Create a particle that can be placed along a segment's contour.

        All arguments are forwarded to PointParticle.__init__; `name` is
        additionally stored on the instance before that call. """
        self.name = name
        # Position along the parent's contour; None until assigned by other code
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, segname=segname, **kwargs)
        # Neighboring particles, filled in by make_intrahelical_neighbor
        self.intrahelical_neighbors = []
        self.other_neighbors = []
        # Starts empty; get_nt_position iterates over it and reads each entry's `container`
        self.locations = []
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198

    def get_intrahelical_above(self):
        """ Returns bead directly above self, i.e. the first intrahelical
        neighbor whose contour position (measured in self.parent) is larger
        than self's; None when there is no such neighbor """
        neighbors = self.intrahelical_neighbors
        assert len(neighbors) <= 2
        return next(
            (nbr for nbr in neighbors
             if nbr.get_contour_position(self.parent) > self.contour_position),
            None)

    def get_intrahelical_below(self):
        """ Returns bead directly below self, i.e. the first intrahelical
        neighbor whose contour position (measured in self.parent) is smaller
        than self's; None when there is no such neighbor """
        neighbors = self.intrahelical_neighbors
        assert len(neighbors) <= 2
        return next(
            (nbr for nbr in neighbors
             if nbr.get_contour_position(self.parent) < self.contour_position),
            None)

cmaffeo2's avatar
cmaffeo2 committed
199
200
201
202
203
    def _neighbor_should_be_added(self,b):
        c1 = self.contour_position
        c2 = b.get_contour_position(self.parent)
        if c2 < c1:
            b0 = self.get_intrahelical_below()
204
        else:
cmaffeo2's avatar
cmaffeo2 committed
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
            b0 = self.get_intrahelical_above()

        if b0 is not None:            
            c0 = b0.get_contour_position(self.parent)
            if np.abs(c2-c1) < np.abs(c0-c1):
                ## remove b0
                self.intrahelical_neighbors.remove(b0)
                b0.intrahelical_neighbors.remove(self)
                return True
            else:
                return False
        return True
        
    def make_intrahelical_neighbor(self,b):
        """ Link self and `b` as intrahelical neighbors when both sides agree.

        Each side is asked via _neighbor_should_be_added; both checks always
        run because each may unlink an existing, more distant neighbor. """
        accepted_here = self._neighbor_should_be_added(b)
        accepted_there = b._neighbor_should_be_added(self)
        if not (accepted_here and accepted_there):
            return

        ## After the checks each bead has room for one more neighbor
        for bead in (b, self):
            assert len(bead.intrahelical_neighbors) <= 1
        self.intrahelical_neighbors.append(b)
        b.intrahelical_neighbors.append(self)
226

cmaffeo2's avatar
cmaffeo2 committed
227
228
229
230
231
232
        
    # def get_nt_position(self,seg):
    #     if seg == self.parent:
    #         return seg.contour_to_nt_pos(self.contour_position)
    #     else:
    #         cl = [e for e in self.parent.get_connections_and_locations() if e[2].container is seg]
233

cmaffeo2's avatar
cmaffeo2 committed
234
235
236
237
238
    #         dc = [(self.contour_position - A.address)**2 for c,A,B in cl]

    #         if len(dc) == 0:
    #             import pdb
    #             pdb.set_trace()
239

cmaffeo2's avatar
cmaffeo2 committed
240
241
242
243
244
245
246
247
248
249
250
    #         i = np.argmin(dc)
    #         c,A,B = cl[i]
    #         ## TODO: generalize, removing np.abs and conditional 
    #         delta_nt = np.abs( A.container.contour_to_nt_pos(self.contour_position - A.address) )
    #         B_nt_pos = seg.contour_to_nt_pos(B.address)
    #         if B.address < 0.5:
    #             return B_nt_pos-delta_nt
    #         else:
    #             return B_nt_pos+delta_nt

    def get_nt_position(self,seg):
        """ Return this particle's nucleotide position expressed in `seg`.

        If `seg` is the particle's parent, the contour position is converted
        directly.  Otherwise the position is mapped through the connection
        between the parent and `seg` whose parent-side address is closest to
        the particle.  If the two are not directly connected, one
        intermediate segment (the container of one of self.locations) is
        tried; the first position found that way is returned.

        Raises AssertionError when no route to `seg` is found.
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)

        def get_nt_pos(contour1, seg1, seg2):
            """ Map contour position `contour1` in seg1 to an nt position in
            seg2 through their nearest connection; None if not connected """
            cl = [e for e in seg1.get_connections_and_locations() if e[2].container is seg2]
            dc = [(contour1 - A.address)**2 for c,A,B in cl]
            if len(dc) == 0: return None

            i = np.argmin(dc)
            c,A,B = cl[i]

            ## TODO: generalize, removing np.abs and conditional
            delta_nt = np.abs( seg1.contour_to_nt_pos(contour1 - A.address) )
            B_nt_pos = seg2.contour_to_nt_pos(B.address)
            ## Step away from whichever half of seg2 the connection sits in
            if B.address < 0.5:
                return B_nt_pos-delta_nt
            else:
                return B_nt_pos+delta_nt

        pos = get_nt_pos(self.contour_position, self.parent, seg)
        if pos is None:
            ## Particle is not directly connected
            positions = []
            for l in self.locations:
                if l.container == self.parent: continue
                pos0 = get_nt_pos(self.contour_position, self.parent, l.container)
                assert(pos0 is not None)
                pos0 = l.container.nt_pos_to_contour(pos0)
                pos = get_nt_pos( pos0, l.container, seg )
                if pos is not None:
                    positions.append( pos )
            assert( len(positions) > 0 )
            ## NOTE(review): when several routes exist the first one wins;
            ## an interactive debugger breakpoint used to sit here.
            pos = positions[0]
        return pos
290
291


292
293
294
295
    def get_contour_position(self,seg):
        """ Return this particle's contour position expressed in `seg`.

        For the particle's own parent this is the stored contour_position;
        for any other segment the nucleotide position from get_nt_position
        is converted with seg.nt_pos_to_contour. """
        if seg == self.parent:
            return self.contour_position
        return seg.nt_pos_to_contour(self.get_nt_position(seg))
298
299

## TODO break this class into smaller, better encapsulated pieces
300
301
302
303
304
305
306
307
308
309
310
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    """Define basic particle types"""
    # Particle type named "D"; by its name, used for double-stranded DNA beads
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,                 
                              )
    # Particle type named "O" for orientation beads (see _add_bead)
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )

    # Bond applied between a bead and its orientation bead in _add_bead
    # orientation_bond = HarmonicBond(10,2)
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )

    # Particle type named "S"; by its name, used for single-stranded DNA beads
    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,                 
                              )

    def __init__(self, name, num_nts, 
                 start_position = np.array((0,0,0)),
                 end_position = None, 
                 segment_model = None):
        """ Set up an empty segment running from start_position to end_position.

        name:           passed to Group.__init__ and stored as resname
        num_nts:        number of nucleotides (stored as int)
        start_position: coordinates of the start of the segment
        end_position:   coordinates of the end; when None it is placed along
                        +z at distance_per_nt*num_nts from the start
        segment_model:  stored on the instance

        self.distance_per_nt is read here but not assigned, so it must exist
        before this constructor runs when end_position is None.
        """
        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        self.resname = name
        self.start_orientation = None
        self.twist_per_nt = 0

        ## self.beads will not contain orientation beads
        self.beads = list(self.children)

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = int(num_nts)
        if end_position is None:
            rise = np.array((0,0,self.distance_per_nt*num_nts))
            end_position = rise + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Set up interpolation for positions: a degree-1 spline through the
        ## two end points, parameterized so that u=0 is the start and u=1 the end
        endpoints = np.array([self.start_position,self.end_position]).T
        spline, _ = interpolate.splprep( endpoints, u=[0,1], s=0, k=1)
        self.position_spline_params = spline

        self.sequence = None
355

356
357
358
359
360
    def clear_all(self):
        """ Clear the group via Group.clear_all, empty the bead list and
        detach particles from the locations of this segment's connections """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for _, own_location, _ in self.get_connections_and_locations():
            own_location.particle = None
361

362
    def contour_to_nt_pos(self, contour_pos, round_nt=False):
        """ Convert a contour position into a nucleotide position
        (contour_pos*num_nts - 0.5).

        With round_nt=True the result is rounded, after asserting that it
        already lies close to a whole number. """
        nt = contour_pos*(self.num_nts) - 0.5
        if not round_nt:
            return nt
        nearest = np.around(nt)
        assert np.isclose(nearest, nt)
        return nearest

369
    def nt_pos_to_contour(self,nt_pos):
        """ Convert a nucleotide position into a contour position,
        (nt_pos+0.5)/num_nts; inverse of contour_to_nt_pos """
        shifted = nt_pos+0.5
        return shifted/(self.num_nts)
371

372
373
374
375
376
377
378
    def contour_to_position(self,s):
        """ Evaluate the position spline at contour position(s) `s`.

        The per-coordinate results from splev are stacked and transposed
        whenever the spline has more than one coordinate. """
        coords = interpolate.splev( s, self.position_spline_params )
        if len(coords) > 1:
            return np.array(coords).T
        return coords

    def contour_to_tangent(self,s):
        """ Return the normalized first derivative of the position spline at
        contour position(s) `s` """
        derivative = interpolate.splev( s, self.position_spline_params, der=1 )
        length = np.linalg.norm(derivative,axis=0)
        unit = (derivative / length)
        return unit.T
381
382
383
        

    def contour_to_orientation(self,s):
        """ Return the orientation matrix at contour position `s`, or None
        when the segment has no start_orientation.

        Without quaternion spline parameters the orientation is a rotation
        about the local tangent by twist_per_nt times the nucleotide
        position; otherwise the quaternion spline is evaluated and converted
        to a matrix.  Only a single contour position is supported.
        """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version
        orientation = None
        if self.start_orientation is not None:
            # axis = self.start_orientation.dot( np.array((0,0,1)) )
            ## Segment.__init__ never assigns quaternion_spline_params, so a
            ## missing attribute is treated the same as None
            quaternion_spline_params = getattr(self, "quaternion_spline_params", None)
            if quaternion_spline_params is None:
                axis = self.contour_to_tangent(s)
                orientation = rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=True )
            else:
                q = interpolate.splev( s, quaternion_spline_params )
                if len(q) > 1: q = np.array(q).T # TODO: is this needed?
                orientation = quaternion_to_matrix(q)
        return orientation
396

cmaffeo2's avatar
cmaffeo2 committed
397
    def get_contour_sorted_connections_and_locations(self,type_=None):
        """ Return the [connection, location_in_self, location_in_other]
        entries from get_connections_and_locations, sorted by the address of
        the location in this segment.

        type_: connection type to select; None (the default) selects all.
               The default lets callers that pass no argument work.
        """
        cl = self.get_connections_and_locations(type_)
        return sorted(cl, key=lambda entry: entry[1].address)
401
402
403
    
    def randomize_unset_sequence(self):
        """ Fill in missing sequence information with random bases.

        A segment without any sequence gets a fully random one of length
        num_nts; otherwise only the entries that are None are replaced.
        Bases are drawn from the keys of seqComplement. """
        bases = list(seqComplement.keys())
        # bases = ['T']        ## FOR DEBUG
        if self.sequence is None:
            self.sequence = [random.choice(bases) for _ in range(self.num_nts)]
            return

        assert(len(self.sequence) == self.num_nts) # TODO move
        for index, base in enumerate(self.sequence):
            if base is None:
                self.sequence[index] = random.choice(bases)
412

cmaffeo2's avatar
cmaffeo2 committed
413
414
415
    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        """ Abstract: number of beads to generate; must be provided by subclasses.

        NOTE(review): _generate_beads calls this with three arguments (a
        contour span first), which does not match this two-argument
        signature; confirm against the subclass implementations. """
        raise NotImplementedError

416
    def _generate_one_bead(self, contour_position, nts):
        """ Abstract: create one bead at `contour_position`; must be provided
        by subclasses.  _generate_beads passes the number of nucleotides
        attributed to the bead as `nts`. """
        raise NotImplementedError

419
    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale):
        """ Build the atoms of one nucleotide placed at `contour_position`.

        Seq should include modifications like 5T, T3 Tsinglet; direction matters too

        contour_position: where along the segment the nucleotide sits
        is_fwd:           selects which canonical nucleotide table is used
        seq:              key into that table
        scale:            if not None and not 1, atom positions are scaled
                          (see below) and their beta is set to 0

        Raises NotImplementedError when self.local_twist is false.
        """

        # print("Generating nucleotide at {}".format(contour_position))
        
        pos = self.contour_to_position(contour_position)
        if self.local_twist:
            orientation = self.contour_to_orientation(contour_position)
            ## TODO: move this code (?)
            if orientation is None:
                ## No orientation available: build a frame whose third axis is
                ## the tangent and whose first axis is a unit vector
                ## perpendicular to it
                axis = self.contour_to_tangent(contour_position)
                angleVec = np.array([1,0,0])
                # Pick another reference vector when x is nearly parallel to the tangent
                if axis.dot(angleVec) > 0.9: angleVec = np.array([0,1,0])
                angleVec = angleVec - angleVec.dot(axis)*axis
                angleVec = angleVec/np.linalg.norm(angleVec)
                y = np.cross(axis,angleVec)
                orientation = np.array([angleVec,y,axis]).T
                ## TODO: improve placement of ssDNA
                # rot = rotationAboutAxis( axis, contour_position*self.twist_per_nt*self.num_nts, normalizeAxis=True )
                # orientation = rot.dot(orientation)
            else:
                # No-op: the orientation returned above is used unchanged
                orientation = orientation
                            
        else:
            raise NotImplementedError

        # key = self.sequence
        # if self.ntAt5prime is None and self.ntAt3prime is not None: key = "5"+key
        # if self.ntAt5prime is not None and self.ntAt3prime is None: key = key+"3"
        # if self.ntAt5prime is None and self.ntAt3prime is None: key = key+"singlet"

        key = seq
        # NOTE(review): the forward table is used when is_fwd is false and the
        # reverse table when it is true; confirm this pairing is intended
        if not is_fwd:
            nt_dict = canonicalNtFwd
        else:
            nt_dict = canonicalNtRev
        atoms = nt_dict[ key ].generate() # TODO: clone?
                        
        atoms.orientation = orientation.dot(atoms.orientation)
        if isinstance(self, SingleStrandedSegment):
            # Single-stranded: scale every atom about the origin, then shift
            # the group by the collapsed position of its C1' atom
            if scale is not None and scale != 1:
                for a in atoms:
                    a.position = scale*a.position
                    a.beta = 0
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsed_position()
        else:
            if scale is not None and scale != 1:
                # Scale about N9 for sequences "A"/"G" and about N1 otherwise
                if atoms.sequence in ("A","G"):
                    r0 = atoms.atoms_by_name["N9"].position
                else:
                    r0 = atoms.atoms_by_name["N1"].position
                for a in atoms:
                    # Only atoms whose name ends in ', P or T are scaled
                    # (presumably sugar/phosphate backbone atoms -- TODO confirm)
                    if a.name[-1] in ("'","P","T"):
                        a.position = scale*(a.position-r0) + r0
                        a.beta = 0
            atoms.position = pos

        return atoms
477

478
479
    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Create a Location of kind `type_` at nucleotide position `nt` on
        the given strand and append it to self.locations.  Asserts that the
        corresponding contour position lies within [0, 1]. """
        contour = self.nt_pos_to_contour(nt)
        assert 0 <= contour <= 1
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        self.locations.append(
            Location( self, address=contour, type_=type_, on_fwd_strand=on_fwd_strand ) )

    ## TODO? Replace with abstract strand-based model?
    def add_5prime(self, nt, on_fwd_strand=True):
        """ Add a location of type "5prime" at nucleotide `nt` on the given strand """
        self.add_location(nt,"5prime",on_fwd_strand)
489
490

    def add_3prime(self, nt, on_fwd_strand=True):
        """ Add a location of type "3prime" at nucleotide `nt` on the given strand """
        self.add_location(nt,"3prime",on_fwd_strand)
492

493
494
495
    def get_3prime_locations(self):
        """ Return the locations of type "3prime" in this segment """
        return self.get_locations("3prime")
    
cmaffeo2's avatar
cmaffeo2 committed
496
    def get_5prime_locations(self):
        """ Return the locations of type "5prime" in this segment """
        ## TODO? ensure that data is consistent before _build_model calls
        return self.get_locations("5prime")
cmaffeo2's avatar
cmaffeo2 committed
499

500
    def iterate_connections_and_locations(self, reverse=False):
        """ Yield [connection, location_in_self, location_in_other] entries
        for all connections of this segment, ordered by contour address
        (descending when `reverse` is true). """
        ## connections to other segments; None selects every connection type
        ## (the getter's type_ argument used to be omitted, which raised TypeError)
        cl = self.get_contour_sorted_connections_and_locations(None)
        if reverse:
            cl = cl[::-1]

        yield from cl
cmaffeo2's avatar
cmaffeo2 committed
508

509
    ## TODO rename
510
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
        """ Walks through locations, checking for crossovers

        Starting at nucleotide `nt_pos` and moving along the strand selected
        by `is_fwd`, find the next location where the walk must stop.

        Returns a 5-tuple whose first element is the nucleotide position
        where the walk stopped:
          (pos, None, None, None, None)             at a "3prime" end on this strand
          (pos, container, nt_pos, on_fwd, 0.5)     when following a connection
                                                    on this strand to another location
          (pos, self, next_nt_pos, is_fwd, 0)       when pausing at a "crossover"
                                                    on the other strand
        NOTE(review): the last element looks like the move_at_least value
        for the next call; confirm against the caller.

        Raises Exception("Shouldn't be here") when no stopping point is found.
        """
        ## NOTE: the move_at_least argument is currently ignored
        move_at_least = 0

        ## Iterate through locations in the direction of travel
        locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))

        for l in locations:
            # TODOTODO probably okay
            ## Locations at the very ends map onto the first/last nucleotide
            if l.address == 0:
                pos = 0.0
            elif l.address == 1:
                pos = self.num_nts-1
            else:
                pos = self.contour_to_nt_pos(l.address, round_nt=True)

            ## Skip locations encountered before our strand
            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
            ## TODO: remove move_at_least
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                print("  found end at",l)
                return pos, None, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                print("  passing through",l)
                print("from {}, connection {} to {}".format(nt_pos,l,B))
                try:
                    Bpos = B.container.contour_to_nt_pos(B.address, round_nt=True)
                except Exception:
                    ## Addresses exactly at the ends do not round to a whole
                    ## nucleotide; map them onto the first/last nucleotide
                    if B.address == 0:
                        Bpos = 0
                    elif B.address == 1:
                        Bpos = B.container.num_nts-1
                    else:
                        raise
                return pos, B.container, Bpos, B.on_fwd_strand, 0.5

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                if nt_pos == pos: continue
                print("  pausing at",l)
                return pos, l.container, pos+(2*is_fwd-1), is_fwd, 0

        ## Made it to the end of the segment without finding a connection
        raise Exception("Shouldn't be here")

584
585
586
    def get_nearest_bead(self, contour_position):
        """ Return the bead in self.beads whose contour_position is closest
        to `contour_position`, or None when the segment has no beads """
        beads = self.beads
        if len(beads) < 1:
            return None
        # TODO: cache
        # TODO: include beads in connections?
        offsets = np.array([bead.contour_position for bead in beads]) - contour_position
        nearest = np.argmin(offsets**2)
        return beads[nearest]
591
592
593

    def get_all_consecutive_beads(self, number):
        """ Return every run of `number` consecutive beads from self.beads,
        as a list of lists (empty when there are fewer than `number` beads) """
        assert number >= 1
        ## Assume that consecutive beads in self.beads are bonded
        beads = self.beads
        return [beads[start:start+number]
                for start in range(len(beads)-number+1)]
600

601
602
603
    def _add_bead(self,b,set_contour=False):
        if set_contour:
            b.contour_position = b.get_contour_position(self)
604
        
605
606
607
        # assert(b.parent is None)
        if b.parent is not None:
            b.parent.children.remove(b)
608
        self.add(b)
609
610
611
612
613
614
        self.beads.append(b) # don't add orientation bead
        if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
            o = b.orientation_bead
            o.contour_position = b.contour_position
            if o.parent is not None:
                o.parent.children.remove(o)
615
            self.add(o)
616
617
618
619
620
621
622
623
            self.add_bond(b,o, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        # print("_rebuild_children on %s" % self.name)
        old_children = self.children
        old_beads = self.beads
        self.children = []
        self.beads = []
624
625
626

        if True:
            print("WARNING: DEBUG")
627
            ## Remove duplicates, preserving order
628
629
630
631
            tmp = []
            for c in new_children:
                if c not in tmp:
                    tmp.append(c)
632
633
                else:
                    print("  duplicate particle found!")
634
635
            new_children = tmp

636
637
638
639
640
        for b in new_children:
            self.beads.append(b)
            self.children.append(b)
            if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
                self.children.append(b.orientation_bead)
641
642
643
644
645
            
        # tmp = [c for c in self.children if c not in old_children]
        # assert(len(tmp) == 0)
        # tmp = [c for c in old_children if c not in self.children]
        # assert(len(tmp) == 0)
646
647
        assert(len(old_children) == len(self.children))
        assert(len(old_beads) == len(self.beads))
648

649

cmaffeo2's avatar
cmaffeo2 committed
650
    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):

        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions

        Beads already attached to this segment's locations are kept as fixed
        points; end beads are created when missing, and additional beads are
        generated evenly between each consecutive pair of fixed points.
        Consecutive beads are linked as intrahelical neighbors and the
        segment's children are rebuilt in contour order.
        """
        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        ## First find points between-which beads must be generated
        # conn_locs = self.get_contour_sorted_connections_and_locations()
        # locs = [A for c,A,B in conn_locs]
        # existing_beads = [l.particle for l in locs if l.particle is not None]
        # Unique particles already assigned to locations, sorted along this segment's contour
        existing_beads = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( list(existing_beads), key=lambda b: b.get_contour_position(self) )
        
        # Debugging guard; the list was built from a set, so duplicates are not expected
        if len(existing_beads) != len(set(existing_beads)):
            pdb.set_trace()
        for b in existing_beads:
            assert(b.parent is not None)

        # if self.name == "1-1":
        #     import pdb
        #     pdb.set_trace()

        ## Add ends if they don't exist yet
        ## TODOTODO: test 1 nt segments?
        # A start bead is needed unless the first existing bead lies within 0.5 nt of nucleotide 0
        if len(existing_beads) == 0 or existing_beads[0].get_nt_position(self) > 0.5:
            # if len(existing_beads) > 0:            
            #     assert(existing_beads[0].get_nt_position(self) >= 0.5)
            b = self._generate_one_bead(0, 0)
            existing_beads = [b] + existing_beads

        # Likewise an end bead, unless the last bead lies within 0.5 nt of the last nucleotide
        if existing_beads[-1].get_nt_position(self)-(self.num_nts-1) < -0.5:
            b = self._generate_one_bead(1, 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
        last = None
        for I in range(len(existing_beads)-1):
            eb1,eb2 = [existing_beads[i] for i in (I,I+1)]
            assert( eb1 is not eb2 )

            # if np.isclose(eb1.position[2], eb2.position[2]):
            #     import pdb
            #     pdb.set_trace()

            print(" %s working on %d to %d" % (self.name, eb1.position[2], eb2.position[2]))
            # Contour span between the two fixed beads, divided evenly among
            # the intervals created by num_beads new beads
            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )
            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nts
            # Each fixed bead takes half an interval's worth of nucleotides
            eb1.num_nts += 0.5*nts
            eb2.num_nts += 0.5*nts

            ## Add beads
            # Only beads owned by this segment are listed as its children
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
                last.make_intrahelical_neighbor(eb1)
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)

                last.make_intrahelical_neighbor(b)
                last = b
                tmp_children.append(b)

        # After the loop eb2 is the final fixed bead; link it to the last bead placed
        last.make_intrahelical_neighbor(eb2)

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)
725
726
727
728
729
730
731
732
733
734
735
736
737

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        # NOTE(review): in this view the body is only a bare Ellipsis
        # placeholder, so the method does nothing and returns None
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None,
                 twist_persistence_length = 90 ):
        """ Initialize geometry, twist parameters and end Locations.

        name, num_nts, start_position, end_position and segment_model
        are forwarded to Segment.__init__.

        start_position: defaults to the origin when None
        local_twist: if True, generated beads get an orientation bead
        num_turns: total helical turns; when None it is derived as
            num_nts / self.helical_rise
        start_orientation: 3x3 matrix; defaults to the identity
        twist_persistence_length: stored unchanged on the instance
        """
        ## Build the default per call so instances never share (and
        ## cannot accidentally mutate) one default array
        if start_position is None:
            start_position = np.array((0,0,0))

        ## Assigned before Segment.__init__ is called; presumably the
        ## base initializer reads distance_per_nt -- TODO confirm
        self.helical_rise = 10.44  # used below as nucleotides per turn
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        self.local_twist = local_twist
        if num_turns is None:
            num_turns = float(num_nts) / self.helical_rise
        ## 360 per turn, so this is expressed in degrees per nucleotide
        self.twist_per_nt = float(360 * num_turns) / num_nts

        if start_orientation is None:
            start_orientation = np.eye(3)
        self.start_orientation = start_orientation
        self.twist_persistence_length = twist_persistence_length

        self.nicks = []

        ## Four end locations: address 0 is the start, address 1 the
        ## end; the second location at each address is on the reverse strand
        self.start = self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )

        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )
        # for l in (self.start5,self.start3,self.end3,self.end5):
        #     self.locations.append(l)

        ## Set up interpolation for positions: a degree-1 spline
        ## through the start and end positions, parameterized on [0,1]
        a = np.array([self.start_position,self.end_position]).T
        tck, _ = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck

        ## TODO: initialize sensible spline for orientation
        self.quaternion_spline_params = None

    ## Convenience methods
    ## TODO: add errors if unrealistic connections are made
    ## TODO: make connections automatically between unconnected strands
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect self.start5 to end3 (a Location, or a
        SingleStrandedSegment whose end3 is then used) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )

    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect self.start3 to end5 (a Location, or a
        SingleStrandedSegment whose start5 is then used) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.start3, end5, type_, force_connection = force_connection )

    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect self.end3 to end5 (a Location, or a
        SingleStrandedSegment whose start5 is then used) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.end3, end5, type_, force_connection = force_connection )

    def connect_end5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect self.end5 to end3 (a Location, or a
        SingleStrandedSegment whose end3 is then used) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.end5, end3, type_, force_connection = force_connection )

    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False), nt_on_5prime=True):
        """ Add a crossover between two helices

        nt, other_nt: nucleotide positions in self and other
        strands_fwd: pair of flags (self, other) passed to
            get_location_at to select the strand in each segment
        nt_on_5prime: if True, the location in self is flagged as not
            being the 3' side of the connection and the one in other
            as the 3' side; reversed if False
        """
        ## Validate other, nt, other_nt
        ##   TODO

        if isinstance(other,SingleStrandedSegment):
            ## Delegate to the ssDNA implementation with roles swapped
            other.add_crossover(other_nt, self, nt, strands_fwd[::-1], not nt_on_5prime)
        else:
            ## Create locations, connections and add to segments
            c = self.nt_pos_to_contour(nt)
            assert(c >= 0 and c <= 1)
            loc = self.get_location_at(c, strands_fwd[0])

            c = other.nt_pos_to_contour(other_nt)
            # TODOTODO: may need to subtract or add a little depending on 3prime/5prime
            assert(c >= 0 and c <= 1)
            other_loc = other.get_location_at(c, strands_fwd[1])
            self._connect(other, Connection( loc, other_loc, type_="crossover" ))

            loc.is_3prime_side_of_connection = not nt_on_5prime
            other_loc.is_3prime_side_of_connection = bool(nt_on_5prime)

    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection):
        """ Connect two end Locations of opposite type ("end3"/"end5").

        The Connection is always created from the "end3" location to
        the "end5" location and registered on the container of the
        "end3" location. force_connection is accepted but not used here.
        """
        ## TODO remove self?
        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )

        ## Create and add connection
        if end2.type_ == "end5":
            end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ), in_3prime_direction=True )
        else:
            end2.container._connect( end1.container, Connection( end2, end1, type_=type_ ), in_3prime_direction=True )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a stretch spanning the given contour
        fraction; only max_basepairs_per_bead is used for dsDNA """
        return int(contour*self.num_nts // max_basepairs_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create a dsDNA bead at contour_position representing nts
        nucleotides, register it with _add_bead and return it.

        With local_twist, an orientation bead is also created, offset
        from the bead by Segment.orientation_bond.r0 along the first
        axis of the local orientation.
        """
        pos = self.contour_to_position(contour_position)
        if self.local_twist:
            orientation = self.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            opos = pos + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )
            ## NOTE(review): nts is also passed as the third positional
            ## argument; confirm against the SegmentParticle signature
            ## that this is intended
            o = SegmentParticle( Segment.orientation_particle, opos, nts,
                                 num_nts=nts, parent=self )
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    orientation_bead=o,
                                    contour_position=contour_position )

        else:
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    contour_position=contour_position )
        self._add_bead(bead)
        return bead
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878

class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Initialize the segment and its two end Locations.

        All arguments are forwarded to Segment.__init__;
        start_position defaults to the origin when None.
        """
        ## Build the default per call so instances never share (and
        ## cannot accidentally mutate) one default array
        if start_position is None:
            start_position = np.array((0,0,0))

        ## Assigned before Segment.__init__ is called; presumably the
        ## base initializer reads it -- TODO confirm
        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        ## A single strand only has a 5' end (address 0) and a 3' end
        ## (address 1)
        self.start = self.start5 = Location( self, address=0, type_= "end5" ) # TODO change type_?
        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        # for l in (self.start5,self.end3):
        #     self.locations.append(l)

    def connect_end3(self, end5, force_connection=False):
        """ Connect this segment's 3' end to the Location end5 """
        self._connect_end( end5,  _5_to_3 = True, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False): # TODO: change name or possibly deprecate
        """ Connect the Location end3 to this segment's 5' end """
        self._connect_end( end3,  _5_to_3 = False, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Create an "intrahelical" connection between one end of
        this segment and the Location other.

        _5_to_3: if True, connect self.end3 -> other; otherwise
            connect other -> self.start5. In both cases the
            Connection runs in the 3' direction.
        force_connection is accepted but not used here.
        """
        assert( isinstance(other, Location) )
        if _5_to_3:
            my_end = self.end3
            # assert( other.type_ == "end5" )
            ## compare by value; identity comparison of strings is unreliable
            if other.type_ != "end5":
                print("Warning: code does not prevent connecting 3prime to 3prime, etc")
            conn = Connection( my_end, other, type_="intrahelical" )
            self._connect( other.container, conn, in_3prime_direction=True )
        else:
            ## This class only defines start5/end3; the 5' end is start5
            my_end = self.start5
            # assert( other.type_ == "end3" )
            if other.type_ != "end3":
                print("Warning: code does not prevent connecting 3prime to 3prime, etc")
            conn = Connection( other, my_end, type_="intrahelical" )
            other.container._connect( self, conn, in_3prime_direction=True )

    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False), nt_on_5prime=True):
        """ Add a crossover between two helices

        nt: nucleotide in this segment; must be 0 or num_nts-1
        other, other_nt: the other segment and the nucleotide in it
        strands_fwd: only strands_fwd[1] is used, to select the
            strand in other when it is not single-stranded
        nt_on_5prime: if True the connection runs from this segment
            to other, otherwise from other to this segment

        Raises ValueError if nt is not at an end of this segment.
        """
        ## Validate other, nt, other_nt
        ##   TODO

        ## TODO: fix direction

        if nt == 0:
            c1 = 0
        elif nt == self.num_nts-1:
            c1 = 1
        else:
            raise ValueError("Crossovers can only be at the ends of an ssDNA segment")

        loc = self.get_location_at(c1, True)

        if other_nt == 0:
            c2 = 0
        elif other_nt == other.num_nts-1:
            c2 = 1
        else:
            c2 = other.nt_pos_to_contour(other_nt)

        if isinstance(other,SingleStrandedSegment):
            ## Ensure connections occur at the ends of the other
            ## segment (checked against its own length)
            assert(np.isclose(other_nt,0) or np.isclose(other_nt,other.num_nts-1))
            other_loc = other.get_location_at( c2, True )
            if nt_on_5prime:
                self.connect_end3( other_loc )
            else:
                ## connect_end3 requires a Location, not a segment
                other.connect_end3( loc )

        else:
            assert(c2 >= 0 and c2 <= 1)
            other_loc = other.get_location_at( c2, strands_fwd[1] )
            if nt_on_5prime:
                self._connect(other, Connection( loc, other_loc, type_="sscrossover" ), in_3prime_direction=True )
            else:
                other._connect(self, Connection( other_loc, loc, type_="sscrossover" ), in_3prime_direction=True )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Number of beads for a stretch spanning the given contour
        fraction; only max_nucleotides_per_bead is used for ssDNA """
        return int(contour*self.num_nts // max_nucleotides_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create an ssDNA bead at contour_position representing nts
        nucleotides, register it with _add_bead and return it """
        pos = self.contour_to_position(contour_position)
        b = SegmentParticle( Segment.ssDNA_particle, pos,
                             name="NAS",
                             num_nts=nts, parent=self,
                             contour_position=contour_position )
        self._add_bead(b)
        return b
964
965

    
cmaffeo2's avatar
cmaffeo2 committed
966
967
968
969
class StrandInSegment(Group):
    """ Class that holds atomic model, maps to segment """
    
    def __init__(self, segment, start, end, is_fwd):
        """ Record the stretch of `segment` covered by this strand.

        start and end are nucleotide coordinates within the segment;
        is_fwd flags the strand direction. The strand covers
        |end-start|+1 nucleotides, which must be (close to) an integer.
        """
        Group.__init__(self)
        self.num_nts = 0
        # self.sequence = []
        self.segment, self.start, self.end = segment, start, end
        self.is_fwd = is_fwd

        ## Inclusive count of nucleotides between start and end
        span = np.abs(end-start)+1
        rounded_span = int(round(span))
        self.num_nts = rounded_span
        assert np.isclose(rounded_span, span)

        # print(" Creating {}-nt StrandInSegment in {} from {} to {} {}".format(self.num_nts, segment.name, start, end, is_fwd))
984
985
986
987
988
989
990
991
992
    
    def _nucleotide_ids(self):
        """ Return a range over the nucleotide indices of this strand,
        beginning at self.start and stepping +1 for a forward strand
        or -1 otherwise, for self.num_nts entries """
        raw_start = self.start # seg.contour_to_nt_pos(self.start)
        ## start must already be (numerically) an integer index
        assert np.abs(raw_start - round(raw_start)) < 1e-5
        first = int(round(raw_start))
        ## a forward strand may not run from a higher to a lower index
        assert (not self.is_fwd) or (self.end-self.start) >= 0

        step = 2*self.is_fwd-1
        stop = first + step*self.num_nts
        return range(first, stop, step)
993
994
995

    def get_sequence(self):
        """ return 5-to-3 """
996
        # TODOTODO test
997
998
        seg = self.segment
        if self.is_fwd:
999
            return [seg.sequence[nt] for nt in self._nucleotide_ids()]
1000
        else:
For faster browsing, not all history is shown. View entire blame