segmentmodel.py 73.2 KB
Newer Older
1
import numpy as np
2
import random
3
from arbdmodel import PointParticle, ParticleType, Group, ArbdModel
4
from coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
5
6
7
8
from nonbonded import *
from copy import copy, deepcopy
from nbPot import nbDnaScheme

cmaffeo2's avatar
cmaffeo2 committed
9
10
from scipy.special import erf
import scipy.optimize as opt
11
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
12

cmaffeo2's avatar
cmaffeo2 committed
13
14
15
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC

16
# import pdb
17
"""
cmaffeo2's avatar
cmaffeo2 committed
18
TODO:
cmaffeo2's avatar
cmaffeo2 committed
19
 + fix handling of crossovers for atomic representation
cmaffeo2's avatar
cmaffeo2 committed
20
 + map to atomic representation
21
22
    + add nicks
    - transform ssDNA nucleotides 
cmaffeo2's avatar
cmaffeo2 committed
23
24
    - shrink ssDNA
    - shrink dsDNA backbone
25
    + make orientation continuous
cmaffeo2's avatar
cmaffeo2 committed
26
    - sequence
27
    - handle circular dna
28
 + ensure crossover bead potentials aren't applied twice 
29
 + remove performance bottlenecks
30
31
 - test for large systems
 - assign sequence
32
 - ENM
33
34
 - rework Location class 
 - remove recursive calls
35
 - document
36
 - add unit test of helices connected to themselves
37
38
39
40
"""

class Location():
    """ Site for a connection within an object (e.g. a Segment).

    Attributes:
      container      -- object owning this site; its ``name`` is used by __repr__
      address        -- position along the container's contour length
                        (Segment stores this as a fraction in [0, 1])
      type_          -- label such as "end5", "end3", "5prime" or "3prime"
      on_fwd_strand  -- True when the site lies on the forward strand
      particle       -- bead attached to this site, or None
      connection     -- Connection joining this site to another, or None
    """

    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segments
        # assert( type_ in ("end3","end5") ) # TODO remove or make conditional
        self.on_fwd_strand = on_fwd_strand
        self.type_ = type_
        self.particle = None
        self.connection = None

        self.prev_in_strand = None
        self.next_in_strand = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the Location at the other end of this site's connection,
        or None when the site is unconnected. """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self,connection):
        """ Attach ``connection`` to this site, replacing any previous one. """
        self.connection = connection # TODO weakref?

    def __repr__(self):
        # The strand flag is rendered as an integer (1 = forward, 0 = reverse);
        # int() keeps the 'd' format valid for any bool-like value.
        return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, int(self.on_fwd_strand))
class Connection():
    """ Abstract base class for connection between two elements.

    Joins Location ``A`` to Location ``B``. ``type_`` is a free-form label;
    elsewhere in this module it is compared against values such as
    "crossover".
    """
    def __init__(self, A, B, type_ = None):
        assert( isinstance(A,Location) )
        assert( isinstance(B,Location) )
        self.A = A
        self.B = B
        self.type_ = type_

    def other(self, location):
        """ Return the endpoint opposite ``location``.

        Endpoints are matched by identity. Raises Exception("OutOfBoundsError")
        when ``location`` is neither endpoint.
        """
        if location is self.A:
            return self.B
        if location is self.B:
            return self.A
        raise Exception("OutOfBoundsError")
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that own Locations and Connections """

    def __init__(self, connection_locations=None, connections=None):
        ## Fresh lists per instance: a shared mutable default would leak
        ## locations/connections between unrelated elements.
        if connection_locations is None:
            connection_locations = []
        if connections is None:
            connections = []
        ## TODO decide on names
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Return this element's locations, optionally filtered.

        type_   -- when not None, keep only locations whose ``type_`` matches
        exclude -- collection of ``type_`` values to drop

        Asserts that no location appears more than once in the result.
        """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        assert( len(set(locs)) == len(locs) )
        return locs

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            [connection, location_in_self, location_in_other]

        connection_type -- when not None, keep only connections of this type_
        exclude         -- collection of connection type_ values to drop

        Raises Exception if a stored connection has neither endpoint in self.
        """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection):
        """ Register ``connection`` between self and ``other``.

        Sets ``connection`` on both endpoint Locations, appends it to both
        elements' connection lists, and adds each endpoint to its
        container's locations if it is not already there.
        """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]
        A.connection = B.connection = connection
        self.connections.append(connection)
        other.connections.append(connection)

        for loc in (A, B):
            l = loc.container.locations
            if loc not in l: l.append(loc)

class SegmentParticle(PointParticle):
    """ Point particle (bead) belonging to a Segment, its ``parent``.

    Tracks its contour position within the parent segment and its
    intrahelical neighbor beads.
    """
    def __init__(self, type_, position, name="A", segname="A", **kwargs):
        self.name = name
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, segname=segname, **kwargs)
        self.intrahelical_neighbors = []
        self.other_neighbors = []

    def get_intrahelical_above(self):
        """ Returns bead directly above self (larger contour position in the
        parent segment), or None if there is none """
        assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) > self.contour_position:
                return b

    def get_intrahelical_below(self):
        """ Returns bead directly below self (smaller contour position in the
        parent segment), or None if there is none """
        assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent) < self.contour_position:
                return b

    def get_nt_position(self,seg):
        """ Return this bead's nucleotide position expressed in segment ``seg``.

        For the parent segment this is a direct conversion. Otherwise, the
        parent's connection into ``seg`` whose local address is closest to
        this bead is used: the nucleotide offset from that connection is
        subtracted from (address < 0.5) or added to the connected location's
        nucleotide position in ``seg``.

        Raises Exception if the parent has no connection into ``seg``.
        """
        if seg == self.parent:
            return seg.contour_to_nt_pos(self.contour_position)

        cl = [e for e in self.parent.get_connections_and_locations() if e[2].container is seg]
        if len(cl) == 0:
            raise Exception("Did not find location for particle {} in Segment {}".format(self,seg))

        dc = [(self.contour_position - A.address)**2 for c,A,B in cl]
        c,A,B = cl[int(np.argmin(dc))]

        ## TODO: generalize, removing np.abs and conditional
        delta_nt = np.abs( A.container.contour_to_nt_pos(self.contour_position - A.address) )
        B_nt_pos = seg.contour_to_nt_pos(B.address)
        if B.address < 0.5:
            return B_nt_pos-delta_nt
        else:
            return B_nt_pos+delta_nt

    def get_contour_position_old(self,seg):
        """ Deprecated: use get_contour_position.

        The previous body referenced undefined names (``B`` and ``e``) and
        could only raise NameError whenever ``seg`` was not the parent; the
        nearest-connection mapping it intended is what get_contour_position
        computes, so this now delegates to it.
        """
        return self.get_contour_position(seg)

    def get_contour_position(self,seg):
        """ Return this bead's contour position expressed in segment ``seg`` """
        if seg == self.parent:
            return self.contour_position
        else:
            nt_pos = self.get_nt_position(seg)
            return seg.nt_pos_to_contour(nt_pos)

## TODO break this class into smaller, better encapsulated pieces
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    ## Define basic particle types
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,
                              )

    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )

    # orientation_bond = HarmonicBond(10,2)
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,
                              )

    def __init__(self, name, num_nts,
                 start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Create an empty segment of ``num_nts`` nucleotides.

        start_position -- start coordinate; defaults to the origin
        end_position   -- end coordinate; when None, the segment extends along
                          +z by ``distance_per_nt*num_nts`` (subclasses must
                          set ``distance_per_nt`` before calling this)
        segment_model  -- owning model, stored as-is
        """
        if start_position is None:
            ## Fresh array per call; a shared default array could be mutated
            ## through one segment's start_position and affect others
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[])
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        self.resname = name
        self.start_orientation = None
        self.twist_per_nt = 0

        self.beads = [c for c in self.children] # self.beads will not contain orientation beads

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.num_nts = int(num_nts)
        if end_position is None:
            end_position = np.array((0,0,self.distance_per_nt*num_nts)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        ## Set up linear (k=1) spline interpolation for positions over contour [0,1]
        a = np.array([self.start_position,self.end_position]).T
        tck, _ = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck

        self.sequence = None

    def clear_all(self):
        """ Clear children, forget beads, and detach particles from this
        segment's connected locations """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for c,loc,other in self.get_connections_and_locations():
            loc.particle = None

    def contour_to_nt_pos(self,contour_pos):
        """ Map a contour position in [0,1] to a nucleotide position in [0,num_nts-1] """
        return contour_pos*(self.num_nts-1)

    def nt_pos_to_contour(self,nt_pos):
        """ Inverse of contour_to_nt_pos """
        return nt_pos/(self.num_nts-1)

    def contour_to_position(self,s):
        """ Interpolated coordinate(s) at contour position(s) ``s`` """
        p = interpolate.splev( s, self.position_spline_params )
        if len(p) > 1: p = np.array(p).T
        return p

    def contour_to_tangent(self,s):
        """ Unit tangent vector(s) of the position spline at contour ``s`` """
        t = interpolate.splev( s, self.position_spline_params, der=1 )
        t = (t / np.linalg.norm(t,axis=0))
        return t.T

    def contour_to_orientation(self,s):
        """ Orientation matrix at scalar contour position ``s``, or None when
        the segment has no start_orientation """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version
        orientation = None
        if self.start_orientation is not None:
            ## quaternion_spline_params is not initialized in Segment.__init__;
            ## treat a missing attribute like None instead of raising
            quaternion_spline_params = getattr(self, "quaternion_spline_params", None)
            if quaternion_spline_params is None:
                axis = self.contour_to_tangent(s)
                orientation = rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=True )
            else:
                q = interpolate.splev( s, quaternion_spline_params )
                if len(q) > 1: q = np.array(q).T # TODO: is this needed?
                orientation = quaternion_to_matrix(q)
        return orientation

    def get_contour_sorted_connections_and_locations(self,type_=None):
        """ Connection entries [connection, loc_in_self, loc_in_other] of the
        given type (all types when None), sorted by address in self """
        sort_fn = lambda c: c[1].address
        cl = self.get_connections_and_locations(type_)
        return sorted(cl, key=sort_fn)

    def randomize_unset_sequence(self):
        """ Fill unset nucleotides of ``self.sequence`` with random bases.

        Creates a sequence of length num_nts when none is set; otherwise
        replaces only the ``None`` entries in place.
        """
        bases = list(seqComplement.keys())
        ## NOTE(review): debug override -- every randomized base is currently
        ## 'T'; remove the next line to sample from all bases
        bases = ['T']        ## FOR DEBUG
        if self.sequence is None:
            self.sequence = [random.choice(bases) for i in range(self.num_nts)]
        else:
            assert(len(self.sequence) == self.num_nts) # TODO move
            for i in range(len(self.sequence)):
                if self.sequence[i] is None:
                    self.sequence[i] = random.choice(bases)

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead ):
        """ Number of beads to place within a contour interval of length
        ``contour``; implemented by subclasses (signature matches the call in
        _generate_beads) """
        raise NotImplementedError

    def _generate_one_bead(self, contour_position, nts):
        """ Create one bead at ``contour_position``; implemented by subclasses """
        raise NotImplementedError

    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale):
        """ Seq should include modifications like 5T, T3 Tsinglet; direction matters too

        Builds atoms from the canonical nucleotide template for ``seq``,
        orients them along the segment and places them at the contour
        position. When ``scale`` is given and != 1, atom positions are scaled
        (all atoms for ssDNA; backbone/sugar atoms about N9/N1 otherwise).
        """
        pos = self.contour_to_position(contour_position)
        if self.local_twist:
            orientation = self.contour_to_orientation(contour_position)
            ## TODO: move this code (?)
            if orientation is None:
                ## No stored orientation: build an arbitrary right-handed frame
                ## whose third column is the local tangent
                axis = self.contour_to_tangent(contour_position)
                angleVec = np.array([1,0,0])
                if axis.dot(angleVec) > 0.9: angleVec = np.array([0,1,0])
                angleVec = angleVec - angleVec.dot(axis)*axis
                angleVec = angleVec/np.linalg.norm(angleVec)
                y = np.cross(axis,angleVec)
                orientation = np.array([angleVec,y,axis]).T
                ## TODO: improve placement of ssDNA
        else:
            raise NotImplementedError

        ## NOTE(review): reverse-strand nucleotides use canonicalNtFwd and vice
        ## versa; confirm against the template conventions before changing
        key = seq
        if not is_fwd:
            nt_dict = canonicalNtFwd
        else:
            nt_dict = canonicalNtRev
        atoms = nt_dict[ key ].generate() # TODO: clone?
        atoms.orientation = orientation.dot(atoms.orientation)

        if isinstance(self, SingleStrandedSegment):
            if scale is not None and scale != 1:
                for a in atoms:
                    a.position = scale*a.position
                    a.beta = 0
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsed_position()
        else:
            if scale is not None and scale != 1:
                if atoms.sequence in ("A","G"):
                    r0 = atoms.atoms_by_name["N9"].position
                else:
                    r0 = atoms.atoms_by_name["N1"].position
                for a in atoms:
                    if a.name[-1] in ("'","P","T"):
                        a.position = scale*(a.position-r0) + r0
                        a.beta = 0
            atoms.position = pos

        return atoms

    def add_location(self, nt, type_, on_fwd_strand=True):
        """ Create a Location at nucleotide ``nt`` and add it to this segment """
        c = self.nt_pos_to_contour(nt)
        assert(c >= 0 and c <= 1)
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        loc = Location( self, address=c, type_=type_, on_fwd_strand=on_fwd_strand )
        self.locations.append(loc)

    ## TODO? Replace with abstract strand-based model?
    def add_5prime(self, nt, on_fwd_strand=True):
        self.add_location(nt,"5prime",on_fwd_strand)

    def add_3prime(self, nt, on_fwd_strand=True):
        self.add_location(nt,"3prime",on_fwd_strand)

    def get_3prime_locations(self):
        return self.get_locations("3prime")

    def get_5prime_locations(self):
        ## TODO? ensure that data is consistent before _build_model calls
        return self.get_locations("5prime")

    def iterate_connections_and_locations(self, reverse=False):
        """ Yield [connection, loc_in_self, loc_in_other] for every connection,
        ordered by address in self (descending when ``reverse``) """
        ## connections to other segments (all types)
        cl = self.get_contour_sorted_connections_and_locations()
        if reverse:
            cl = cl[::-1]
        for c in cl:
            yield c

    ## TODO rename
    def get_strand_segment(self, nt_pos, is_fwd):
        """ Walks through locations, checking for crossovers

        Starting from nucleotide ``nt_pos`` on the strand selected by
        ``is_fwd``, visit locations in strand direction and return
        ``(end_nt_pos, next_segment, next_nt_pos, next_is_fwd)``:
          - 3prime location on this strand: (pos, None, None, None)
          - connected location on this strand: continue in the other
            Location's container at its nucleotide position
          - "crossover" on the other strand: stay in this segment one
            nucleotide further along, so basepairs line up

        Raises Exception if no stopping location is found.
        """
        ## Iterate through locations in strand order
        locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
        tol = 0.1
        for l in locations:
            pos = self.contour_to_nt_pos(l.address)

            ## Skip locations encountered before our strand
            if is_fwd:
                if pos-nt_pos <= tol: continue
            elif pos-nt_pos >= -tol: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                return pos, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                Bpos = B.container.contour_to_nt_pos(B.address)
                return pos, B.container, Bpos, B.on_fwd_strand

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                return pos, l.container, pos+(2*is_fwd-1), is_fwd

        raise Exception("Shouldn't be here")

    def get_end_of_strand_old(self, contour_pos, is_fwd):
        """ Walks through locations, checking for crossovers

        Older contour-based variant of get_strand_segment: positions are
        contour addresses, no tolerance is applied when skipping, and the
        position returned at a crossover is not advanced.
        """
        ## Iterate through locations in strand order
        locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
        for l in locations:
            pos = l.address

            ## Skip locations encountered before our strand
            if is_fwd:
                if pos <= contour_pos: continue
            elif pos >= contour_pos: continue

            ## Stop if we found the 3prime end
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
                return pos, None, None, None

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
                return pos, B.container, B.address, B.on_fwd_strand

            ## Stop at other strand crossovers so basepairs line up
            elif c.type_ == "crossover":
                return pos, l.container, pos, is_fwd

        raise Exception("Shouldn't be here")

    def get_nearest_bead(self, contour_position):
        """ Bead in self.beads closest in contour to ``contour_position``, or None """
        if len(self.beads) < 1: return None
        cs = np.array([b.contour_position for b in self.beads]) # TODO: cache
        # TODO: include beads in connections?
        i = np.argmin((cs - contour_position)**2)
        return self.beads[i]

    def get_all_consecutive_beads(self, number):
        """ All runs of ``number`` consecutive beads from self.beads """
        assert(number >= 1)
        ## Assume that consecutive beads in self.beads are bonded
        return [self.beads[i:i+number] for i in range(len(self.beads)-number+1)]

    def _add_bead(self,b,set_contour=False):
        """ Move bead ``b`` (and its orientation bead, if any) into this segment """
        if set_contour:
            b.contour_position = b.get_contour_position(self)

        # assert(b.parent is None)
        if b.parent is not None:
            b.parent.children.remove(b)
        self.add(b)
        self.beads.append(b) # don't add orientation bead
        if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
            o = b.orientation_bead
            o.contour_position = b.contour_position
            if o.parent is not None:
                o.parent.children.remove(o)
            self.add(o)
            self.add_bond(b,o, Segment.orientation_bond, exclude=True)

    def _rebuild_children(self, new_children):
        """ Replace children and beads with ``new_children`` in order, adding
        each bead's orientation bead after it. Duplicates are dropped (first
        occurrence kept); asserts that child and bead counts are unchanged. """
        old_children = self.children
        old_beads = self.beads
        self.children = []
        self.beads = []

        ## Drop duplicates while preserving order
        unique_children = []
        for c in new_children:
            if c not in unique_children:
                unique_children.append(c)

        for b in unique_children:
            self.beads.append(b)
            self.children.append(b)
            if "orientation_bead" in b.__dict__: # TODO: think of a cleaner approach
                self.children.append(b.orientation_bead)

        assert(len(old_children) == len(self.children))
        assert(len(old_beads) == len(self.beads))

    def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions """
        ## TODO: decide whether to remove bead_model argument
        ##       (currently unused)

        ## First find points between-which beads must be generated:
        ## beads already attached to locations, sorted along this segment
        existing_beads = {l.particle for l in self.locations if l.particle is not None}
        existing_beads = sorted( existing_beads, key=lambda b: b.get_contour_position(self) )

        for b in existing_beads:
            assert(b.parent is not None)

        ## Add ends if they don't exist yet
        ## TODOTODO: test 1 nt segments?
        if len(existing_beads) == 0 or existing_beads[0].get_contour_position(self) > 0:
            if len(existing_beads) > 0:
                assert(existing_beads[0].get_nt_position(self) > 1.5)
            b = self._generate_one_bead(0, 0)
            existing_beads = [b] + existing_beads
        if existing_beads[-1].get_contour_position(self) < 1:
            ## Mirror of the start check: the LAST bead must sit more than
            ## 1.5 nt before the end (previously checked existing_beads[0])
            assert(self.num_nts-1-existing_beads[-1].get_nt_position(self) > 1.5)
            b = self._generate_one_bead(1, 0)
            existing_beads.append(b)
        assert(len(existing_beads) > 1)

        ## Walk through existing_beads, add beads between
        tmp_children = []       # build list of children in nice order
        last = None

        for I in range(len(existing_beads)-1):
            eb1,eb2 = [existing_beads[i] for i in (I,I+1)]
            assert( eb1 is not eb2 )

            e_ds = eb2.get_contour_position(self) - eb1.get_contour_position(self)
            num_beads = self._get_num_beads( e_ds, max_basepairs_per_bead, max_nucleotides_per_bead )
            ds = e_ds / (num_beads+1)
            nts = ds*self.num_nts
            eb1.num_nts += 0.5*nts
            eb2.num_nts += 0.5*nts

            ## Add beads
            if eb1.parent == self:
                tmp_children.append(eb1)

            s0 = eb1.get_contour_position(self)
            if last is not None:
                last.intrahelical_neighbors.append(eb1)
                eb1.intrahelical_neighbors.append(last)
                assert(len(last.intrahelical_neighbors) <= 2)
                assert(len(eb1.intrahelical_neighbors) <= 2)
            last = eb1
            for j in range(num_beads):
                s = ds*(j+1) + s0
                b = self._generate_one_bead(s,nts)

                last.intrahelical_neighbors.append(b)
                b.intrahelical_neighbors.append(last)
                assert(len(last.intrahelical_neighbors) <= 2)
                assert(len(b.intrahelical_neighbors) <= 2)
                last = b
                tmp_children.append(b)

        last.intrahelical_neighbors.append(eb2)
        eb2.intrahelical_neighbors.append(last)
        assert(len(last.intrahelical_neighbors) <= 2)
        assert(len(eb2.intrahelical_neighbors) <= 2)

        if eb2.parent == self:
            tmp_children.append(eb2)
        self._rebuild_children(tmp_children)

    def _regenerate_beads(self, max_nts_per_bead=4, ):
        ## Placeholder; not yet implemented
        ...
    

class DoubleStrandedSegment(Segment):

    """ Class that describes a segment of dsDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None,
                 local_twist = False,
                 num_turns = None,
                 start_orientation = None,
                 twist_persistence_length = 90 ):
        """ Create a duplex segment of num_nts base pairs.

        start_position defaults to the origin; a fresh array is created per
        call so segments never share (and accidentally co-mutate) one default
        array. num_turns defaults to num_nts / helical_rise. start_orientation
        defaults to the 3x3 identity.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        ## NOTE: despite its name, helical_rise is used below as the number of
        ## base pairs per helical turn (num_turns = num_nts / helical_rise)
        self.helical_rise = 10.44
        self.distance_per_nt = 3.4
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        self.local_twist = local_twist
        if num_turns is None:
            num_turns = float(num_nts) / self.helical_rise
        self.twist_per_nt = float(360 * num_turns) / num_nts

        if start_orientation is None:
            start_orientation = np.eye(3)
        self.start_orientation = start_orientation
        self.twist_persistence_length = twist_persistence_length

        self.nicks = []

        ## Strand ends; address is the contour position (0 = start, 1 = end).
        ## The forward strand runs 5'->3' from start to end, the reverse strand
        ## runs the other way.
        self.start = self.start5 = Location( self, address=0, type_= "end5" )
        self.start3 = Location( self, address=0, type_ = "end3", on_fwd_strand=False )

        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        self.end5 = Location( self, address=1, type_= "end5", on_fwd_strand=False )
        for l in (self.start5,self.start3,self.end3,self.end5):
            self.locations.append(l)

        ## Linear (k=1) spline through the start and end positions,
        ## parameterized by contour position u in [0,1]
        a = np.array([self.start_position,self.end_position]).T
        tck, _ = interpolate.splprep( a, u=[0,1], s=0, k=1)
        self.position_spline_params = tck

        ## TODO: initialize sensible spline for orientation
        self.quaternion_spline_params = None

    ## Convenience methods
    ## TODO: add errors if unrealistic connections are made
    ## TODO: make connections automatically between unconnected strands
    def connect_start5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect the forward strand's 5' start to a 3' end (Location or ssDNA segment) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.start5, end3, type_, force_connection = force_connection )

    def connect_start3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect the reverse strand's 3' end (at the start) to a 5' end (Location or ssDNA segment) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.start3, end5, type_, force_connection = force_connection )

    def connect_end3(self, end5, type_="intrahelical", force_connection=False):
        """ Connect the forward strand's 3' end to a 5' end (Location or ssDNA segment) """
        if isinstance(end5, SingleStrandedSegment):
            end5 = end5.start5
        self._connect_ends( self.end3, end5, type_, force_connection = force_connection )

    def connect_end5(self, end3, type_="intrahelical", force_connection=False):
        """ Connect the reverse strand's 5' end (at the end) to a 3' end (Location or ssDNA segment) """
        if isinstance(end3, SingleStrandedSegment):
            end3 = end3.end3
        self._connect_ends( self.end5, end3, type_, force_connection = force_connection )

    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False)):
        """ Add a crossover between two helices

        nt / other_nt are nucleotide positions on self / other; each is
        converted with nt_pos_to_contour and must fall within [0,1].
        strands_fwd[0] / strands_fwd[1] select the strand (on_fwd_strand) on
        self / other. A Location already present at the same address and
        strand is reused, otherwise a new "crossover" Location is created.
        """
        ## Validate other, nt, other_nt
        ##   TODO

        ## Create locations, connections and add to segments
        c = self.nt_pos_to_contour(nt)
        assert(c >= 0 and c <= 1)

        def get_loc(seg, address, on_fwd_strand):
            ## Reuse the unique existing Location at this address/strand, if any
            loc = None
            for l in seg.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
            if loc is None:
                loc = Location( seg, address=address, type_="crossover", on_fwd_strand=on_fwd_strand )
            return loc

        loc = get_loc(self, c, strands_fwd[0])

        c = other.nt_pos_to_contour(other_nt)
        assert(c >= 0 and c <= 1)
        other_loc = get_loc( other, c, strands_fwd[1] )
        self._connect(other, Connection( loc, other_loc, type_="crossover" ))

    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection):
        """ Connect two strand-end Locations of opposite polarity (one end3, one end5) """
        ## TODO remove self?
        ## validate the input
        for end in (end1, end2):
            assert( isinstance(end, Location) )
            assert( end.type_ in ("end3","end5") )
        assert( end1.type_ != end2.type_ )
        ## Create and add connection
        end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ) )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Bead count for a span of `contour` (fraction of this segment): one per max_basepairs_per_bead bp """
        return int(contour*self.num_nts // max_basepairs_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create, register and return a dsDNA bead at contour_position holding nts nucleotides.

        With local_twist enabled an orientation bead is also created at
        pos + orientation.dot((orientation_bond.r0, 0, 0)) and attached to the
        DNA bead; a missing orientation falls back to the identity matrix.
        """
        pos = self.contour_to_position(contour_position)
        if self.local_twist:
            orientation = self.contour_to_orientation(contour_position)
            if orientation is None:
                print("WARNING: local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            opos = pos + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )
            ## NOTE(review): `nts` is passed as the third positional argument, which
            ## may be SegmentParticle's `name` parameter (other beads pass name= by
            ## keyword); confirm this is intended
            o = SegmentParticle( Segment.orientation_particle, opos, nts,
                                 num_nts=nts, parent=self )
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    orientation_bead=o,
                                    contour_position=contour_position )
        else:
            bead = SegmentParticle( Segment.dsDNA_particle, pos, name="DNA",
                                    num_nts=nts, parent=self,
                                    contour_position=contour_position )
        self._add_bead(bead)
        return bead


class SingleStrandedSegment(Segment):

    """ Class that describes a segment of ssDNA. When built from
    cadnano models, should not span helices """

    def __init__(self, name, num_nts, start_position = None,
                 end_position = None,
                 segment_model = None):
        """ Create a single strand of num_nts nucleotides running 5' (start) to 3' (end).

        start_position defaults to the origin; a fresh array is created per
        call so segments never share one default array.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        self.distance_per_nt = 5
        Segment.__init__(self, name, num_nts,
                         start_position,
                         end_position,
                         segment_model)

        ## The 5' end is at contour address 0 and the 3' end at address 1
        self.start = self.start5 = Location( self, address=0, type_= "end5" ) # TODO change type_?
        self.end = self.end3 = Location( self, address=1, type_ = "end3" )
        for l in (self.start5,self.end3):
            self.locations.append(l)

    def connect_3end(self, end5, force_connection=False):
        """ Connect this strand's 3' end to the 5' end Location end5 """
        self._connect_end( end5,  _5_to_3 = False, force_connection = force_connection )

    def connect_5end(self, end3, force_connection=False):
        """ Connect this strand's 5' end to the 3' end Location end3 """
        self._connect_end( end3,  _5_to_3 = True, force_connection = force_connection )

    def _connect_end(self, other, _5_to_3, force_connection):
        """ Create an intrahelical Connection between one of our ends and `other`.

        _5_to_3 True joins our 5' end (start5) to other's 3' end; False joins
        our 3' end (end3) to other's 5' end.
        """
        assert( isinstance(other, Location) )
        if _5_to_3:
            ## ssDNA segments only define start5 (there is no end5 attribute)
            my_end = self.start5
            assert( other.type_ == "end3" )
        else:
            my_end = self.end3
            assert( other.type_ == "end5" )

        self._connect( other.container, Connection( my_end, other, type_="intrahelical" ) )

    def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
        """ Bead count for a span of `contour` (fraction of this segment): one per max_nucleotides_per_bead nts """
        return int(contour*self.num_nts // max_nucleotides_per_bead)

    def _generate_one_bead(self, contour_position, nts):
        """ Create, register and return an ssDNA bead at contour_position holding nts nucleotides """
        pos = self.contour_to_position(contour_position)
        b = SegmentParticle( Segment.ssDNA_particle, pos,
                             name="NAS",
                             num_nts=nts, parent=self,
                             contour_position=contour_position )
        self._add_bead(b)
        return b

class StrandInSegment(Group):
    """ Class that holds atomic model, maps to segment """

    def __init__(self, segment, start, end, is_fwd):
        """ Map nucleotides start..end (inclusive nt positions) of `segment`.

        is_fwd selects the strand: True reads segment.sequence upward from
        start, False reads its complement downward from start. |end-start|
        must be (within 1e-5) an integer.
        """
        Group.__init__(self)
        self.segment = segment
        self.start = start
        self.end = end
        self.is_fwd = is_fwd

        ## Both endpoints are inclusive
        nts = np.abs(end-start)+1
        self.num_nts = int(round(nts))
        assert( np.abs(self.num_nts-nts) < 1e-5 )

    def get_sequence(self):
        """ Return the bases of this strand piece as a list, 5'-to-3'.

        Forward strands return segment.sequence[start : start+num_nts];
        reverse strands return the complement of the bases from start down
        to start-num_nts+1.
        """
        # TODO: test
        seg = self.segment
        nt0 = self.start
        assert( np.abs(nt0 - round(nt0)) < 1e-5 )
        nt0 = int(round(nt0))

        assert( (self.end-self.start) >= 0 or not self.is_fwd )
        if self.is_fwd:
            return [seg.sequence[nt] for nt in range(nt0,nt0+self.num_nts)]
        else:
            return [seqComplement[seg.sequence[nt]] for nt in range(nt0,nt0-self.num_nts,-1)]

    def get_contour_points(self):
        """ Return num_nts evenly spaced contour positions from start to end """
        c0,c1 = [self.segment.nt_pos_to_contour(p) for p in (self.start,self.end)]
        return np.linspace(c0,c1,self.num_nts)

class Strand(Group):
    """ Class that holds atomic model, maps to segments """
914
    def __init__(self, segname = None):
        """ Create an empty strand; segname is stored unchanged """
        Group.__init__(self)
        self.num_nts = 0
        ## children and strand_segments alias one list, so StrandInSegments
        ## added through Group.add (presumably appending to children) also
        ## appear in strand_segments — NOTE(review): confirm Group.add behavior
        self.children = self.strand_segments = []
        self.segname = segname

    ## TODO disambiguate names of functions
    def add_dna(self, segment, start, end, is_fwd):
        """ Append a StrandInSegment covering start..end of `segment` to this strand.

        start/end are forwarded to StrandInSegment, which treats them as
        inclusive nucleotide positions; is_fwd selects the strand.
        NOTE(review): the span check below converts |start-end| with
        contour_to_nt_pos as if they were contour positions — confirm which
        coordinate system callers use.
        """
        if not (segment.contour_to_nt_pos(np.abs(start-end)) > 0.9):
            ## Previously a leftover pdb.set_trace(); pdb is not imported in
            ## this module, so that raised NameError (or hung batch runs)
            print("WARNING: Strand.add_dna: short span from {} to {} in segment {}".format(
                start, end, getattr(segment, "name", segment)))
        ## A piece on the same segment and strand must not share an endpoint
        for s in self.strand_segments:
            if s.segment == segment and s.is_fwd == is_fwd:
                assert( s.start not in (start,end) )
                assert( s.end not in (start,end) )
        s = StrandInSegment( segment, start, end, is_fwd )
        self.add( s )
        self.num_nts += s.num_nts

    def set_sequence(self,sequence):
        """ Assign `sequence` to this strand — NOT YET IMPLEMENTED.

        Currently only validates that every element of sequence is one of
        'A', 'T', 'C', 'G' (AssertionError otherwise); no segment is modified
        and None is returned.
        """
        ## validate input
        assert( all( i in ('A','T','C','G') for i in sequence ) )

        ## TODO: set sequence on the segment of each StrandInSegment in self.children

    # def get_sequence(self):
    #     sequence = []
    #     for ss in self.strand_segments:
    #         sequence.extend( ss.get_sequence() )

    #     assert( len(sequence) >= self.num_nts )
    #     ret = ["5"+sequence[0]] +\
    #           sequence[1:-1] +\
    #           [sequence[-1]+"3"]
    #     assert( len(ret) == self.num_nts )
    #     return ret

    def generate_atomic_model(self,scale):
        """ Build atomistic nucleotides for each StrandInSegment, in strand order.

        The first nucleotide's sequence label gets a "5" prefix and the last
        nucleotide of the last StrandInSegment a "3" suffix. Consecutive
        nucleotides are joined with an O3'-P bond plus the associated angles
        and dihedrals, and residue ids are assigned from 1. `scale` is
        forwarded to segment._generate_atomic_nucleotide.
        """
        last = None
        resid = 1
        num_strand_segments = len(self.strand_segments)
        for strand_segment_count, s in enumerate(self.strand_segments, start=1):
            seg = s.segment
            contour = s.get_contour_points()
            assert(s.end != s.start)
            assert(np.linalg.norm( seg.contour_to_position(contour[-1]) - seg.contour_to_position(contour[0]) ) > 0.1)
            is_last_strand_segment = (strand_segment_count == num_strand_segments)
            for nucleotide_count, (c,seq) in enumerate(zip(contour,s.get_sequence()), start=1):
                if last is None:
                    seq = "5"+seq
                ## The 3' end is the final nucleotide of the final piece; using the
                ## index (not c == 1) also works for reverse strands and for
                ## contour values that are not exactly 1
                if is_last_strand_segment and nucleotide_count == s.num_nts:
                    seq = seq+"3"

                nt = seg._generate_atomic_nucleotide( c, s.is_fwd, seq, scale )
                s.add(nt)

                ## Join last basepairs
                if last is not None:
                    o3,c3,c4,c2,h3 = [last.atoms_by_name[n]
                                      for n in ("O3'","C3'","C4'","C2'","H3'")]
                    p,o5,o1,o2,c5 = [nt.atoms_by_name[n]
                                     for n in ("P","O5'","O1P","O2P","C5'")]
                    self.add_bond( o3, p, None )
                    self.add_angle( c3, o3, p, None )
                    for x in (o5,o1,o2):
                        self.add_angle( o3, p, x, None )
                        self.add_dihedral(c3, o3, p, x, None )
                    for x in (c4,c2,h3):
                        self.add_dihedral(x, c3, o3, p, None )
                    self.add_dihedral(o3, p, o5, c5, None)
                nt.__dict__['resid'] = resid
                resid += 1
                last = nt

    def update_atomic_orientations(self,default_orientation):
        last = None
        resid = 1
        for s in self.strand_segments:
For faster browsing, not all history is shown. View entire blame