segmentmodel.py 165 KB
Newer Older
1
import logging
cmaffeo2's avatar
cmaffeo2 committed
2
import pdb
3
import subprocess
4
from pathlib import Path
5
import numpy as np
6
import random
7
8
9
from .arbdmodel import PointParticle, ParticleType, Group, ArbdModel
from .arbdmodel.coords import rotationAboutAxis, quaternion_from_matrix, quaternion_to_matrix
from .arbdmodel.nonbonded import *
10
from copy import copy, deepcopy
11
from .model.nbPot import nbDnaScheme
12

13
from . import get_resource_path
14
from . import logger, devlogger
15

16
17
import re

cmaffeo2's avatar
cmaffeo2 committed
18
19
from scipy.special import erf
import scipy.optimize as opt
20
from scipy import interpolate
cmaffeo2's avatar
cmaffeo2 committed
21

22
23
from .model.CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from .model.CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
cmaffeo2's avatar
cmaffeo2 committed
24

cmaffeo2's avatar
cmaffeo2 committed
25
from .model.spring_from_lp import k_angle as angle_spring_from_lp
26
from .model.spring_from_lp import k_twist as twist_spring_from_lp
cmaffeo2's avatar
cmaffeo2 committed
27

28
29
import csv

30
# import pdb
31
"""
cmaffeo2's avatar
cmaffeo2 committed
32
TODO:
cmaffeo2's avatar
cmaffeo2 committed
33
 + fix handling of crossovers for atomic representation
cmaffeo2's avatar
cmaffeo2 committed
34
 + map to atomic representation
35
    + add nicks
cmaffeo2's avatar
cmaffeo2 committed
36
    + transform ssDNA nucleotides 
cmaffeo2's avatar
cmaffeo2 committed
37
    - shrink ssDNA
cmaffeo2's avatar
cmaffeo2 committed
38
    + shrink dsDNA backbone
39
    + make orientation continuous
cmaffeo2's avatar
cmaffeo2 committed
40
    + sequence
41
    + handle circular dna
42
 + ensure crossover bead potentials aren't applied twice 
43
 + remove performance bottlenecks
44
 - test for large systems
cmaffeo2's avatar
cmaffeo2 committed
45
 + assign sequence
46
 + ENM
47
48
 - rework Location class 
 - remove recursive calls
49
 - document
50
 - develop unit test suite
51
52
 - refactor parts of Segment into an abstract_polymer class
 - make each call to generate_bead_model, generate_atomic_model, and generate_oxdna_model return an object that holds only a reference to the original object
53
"""
54
55
56

# Module-level debug switch (off by default); not read anywhere in this part of the file
_DEBUG_TRACE: bool = False

cmaffeo2's avatar
cmaffeo2 committed
57
58
class CircularDnaError(Exception):
    """ Exception type for problems involving circular DNA (raised outside this part of the file) """
59

cmaffeo2's avatar
cmaffeo2 committed
60
61
62
class ParticleNotConnectedError(Exception):
    """ Exception type signalling a missing particle connection (raised outside this part of the file) """

63
64
class Location():
    """ Site for connection within an object

    container     -- the object (e.g. a Segment) that owns this location
    address       -- fractional position along the container's contour;
                     0 and 1 denote the two ends
    type_         -- label for the kind of site (e.g. "end3", "end5")
    on_fwd_strand -- True if the site is on the forward strand
    """
    def __init__(self, container, address, type_, on_fwd_strand = True):
        ## TODO: remove cyclic references(?)
        self.container = container
        self.address = address  # represents position along contour length in segment
        # assert( type_ in ("end3","end5") ) # TODO remove or make conditional
        self.on_fwd_strand = on_fwd_strand

        self.type_ = type_
        self.particle = None
        self.connection = None
        self.is_3prime_side_of_connection = None

        self.combine = None     # some locations might be combined in bead model

    def get_connected_location(self):
        """ Return the Location at the other end of this location's connection, or None """
        if self.connection is None:
            return None
        return self.connection.other(self)

    def set_connection(self, connection, is_3prime_side_of_connection):
        """ Record the connection this location takes part in and which side it is on """
        self.connection = connection # TODO weakref?
        self.is_3prime_side_of_connection = is_3prime_side_of_connection

    def get_nt_pos(self):
        """ Return the (rounded) nucleotide index of this location in its container

        If the container cannot convert the address (for example because it
        does not fall on a nucleotide centre) but the address is exactly 0 or
        1, the first or last nucleotide index is returned instead; any other
        conversion failure is re-raised.
        """
        try:
            pos = self.container.contour_to_nt_pos(self.address, round_nt=True)
        except Exception:
            # Only the segment ends have a well-defined fallback
            if self.address == 0:
                pos = 0
            elif self.address == 1:
                pos = self.container.num_nt-1
            else:
                raise
        return pos

    def delete(self):
        """ Delete this location's connection (if any) and detach it from its container """
        if self.connection is not None:
            self.connection.delete()
        if self.container is not None:
            self.container.locations.remove(self)

    def __repr__(self):
        if self.on_fwd_strand:
            on_fwd = "on_fwd_strand"
        else:
            on_fwd = "on_rev_strand"

        try:
            return "<Location {}.{}[{:d},{}]>".format( self.container.name, self.type_, self.get_nt_pos(), on_fwd )
        except Exception:
            # Fall back to the raw contour address when no nucleotide index is available
            return "<Location {}.{}[{:.2f},{:d}]>".format( self.container.name, self.type_, self.address, self.on_fwd_strand)
115
        
116
117
118
119
120
121
122
123
124
class Connection():
    """ Abstract base class for connection between two elements

    A, B  -- the two Locations joined by this connection
    type_ -- label for the kind of connection (e.g. "intrahelical", "sscrossover")
    """
    def __init__(self, A, B, type_ = None):
        assert( isinstance(A,Location) )
        assert( isinstance(B,Location) )
        self.A = A
        self.B = B
        self.type_ = type_

    def other(self, location):
        """ Return the Location at the opposite end from ``location`` (compared by identity) """
        if location is self.A:
            return self.B
        elif location is self.B:
            return self.A
        else:
            raise Exception("OutOfBoundsError")

    def delete(self):
        """ Remove this connection from its container(s) and clear both Locations' references """
        self.A.container.connections.remove(self)
        # A connection within a single container is only stored in that container once
        if self.B.container is not self.A.container:
            self.B.container.connections.remove(self)
        self.A.connection = None
        self.B.connection = None

    def __repr__(self):
        return "<Connection {}--{}--{}>".format( self.A, self.type_, self.B )
        

144
        
145
146
147
# class ConnectableElement(Transformable):
class ConnectableElement():
    """ Abstract base class for objects that own Locations and the Connections between them """

    def __init__(self, connection_locations=None, connections=None):
        if connection_locations is None: connection_locations = []
        if connections is None: connections = []

        ## TODO decide on names
        # ``locations`` and ``connection_locations`` are the same list object
        self.locations = self.connection_locations = connection_locations
        self.connections = connections

    def get_locations(self, type_=None, exclude=()):
        """ Return locations whose ``type_`` matches (any type if None) and is not in ``exclude`` """
        locs = [l for l in self.connection_locations
                if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
        # Sanity check: every Location object should be listed only once
        assert( len(set(locs)) == len(locs) )
        return locs

    def get_location_at(self, address, on_fwd_strand=True, new_type=None):
        """ Return the Location at ``address`` on the requested strand

        For a single-nucleotide element the address is ignored and the unique
        unconnected location on that strand is returned. If no location is
        found and ``new_type`` is given, a new Location of that type is
        created (but not added to ``self.locations``); otherwise None is
        returned.
        """
        loc = None
        if (self.num_nt == 1):
            ## Assumes that intrahelical connections have been made before crossovers
            for l in self.locations:
                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
                    assert(loc is None)
                    loc = l
        else:
            for l in self.locations:
                if l.address == address and l.on_fwd_strand == on_fwd_strand:
                    assert(loc is None)
                    loc = l
        if loc is None and new_type is not None:
            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
        return loc

    def get_connections_and_locations(self, connection_type=None, exclude=()):
        """ Returns a list with each entry of the form:
            connection, location_in_self, location_in_other

        Only connections whose ``type_`` matches ``connection_type`` (any if
        None) and is not in ``exclude`` are included.

        Raises Exception if a stored connection refers to neither end in self.
        """
        type_ = connection_type
        ret = []
        for c in self.connections:
            if (type_ is None or c.type_ == type_) and c.type_ not in exclude:
                if   c.A.container is self:
                    ret.append( [c, c.A, c.B] )
                elif c.B.container is self:
                    ret.append( [c, c.B, c.A] )
                else:
                    # Previously dropped into pdb here, which hangs non-interactive runs
                    raise Exception("Object contains connection that fails to refer to object")
        return ret

    def _connect(self, other, connection, in_3prime_direction=None):
        """ Register ``connection`` (between self and ``other``) on both elements

        If ``in_3prime_direction`` is given, B is marked as the 3' side of the
        connection when it is True (and A when it is False). Both Locations
        are appended to their containers' location lists if missing.
        """
        ## TODO fix circular references
        A,B = [connection.A, connection.B]

        if in_3prime_direction is not None:
            A.is_3prime_side_of_connection = not in_3prime_direction
            B.is_3prime_side_of_connection = in_3prime_direction

        A.connection = B.connection = connection
        self.connections.append(connection)
        # Avoid storing a self-connection twice
        if other is not self:
            other.connections.append(connection)

        l = A.container.locations
        if A not in l: l.append(A)
        l = B.container.locations
        if B not in l: l.append(B)
        

222
223
    # def _find_connections(self, loc):
    #     return [c for c in self.connections if c.A == loc or c.B == loc]
224
225

class SegmentParticle(PointParticle):
    """ Bead belonging to a segment, placed at ``contour_position`` along its parent

    Tracks neighbouring beads along the helix (``intrahelical_neighbors``),
    other neighbours (``other_neighbors``) and the Locations it represents.
    """
    def __init__(self, type_, position, name="A", **kwargs):
        self.name = name
        self.contour_position = None
        PointParticle.__init__(self, type_, position, name=name, **kwargs)
        self.intrahelical_neighbors = []
        self.other_neighbors = []
        self.locations = []

    def get_intrahelical_above(self, all_types=True):
        """ Returns bead directly above self (higher contour position), or None

        all_types -- if False, only neighbours of the same Python type as self qualify
        """
        # assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent, self.contour_position) > self.contour_position:
                if all_types or isinstance(b,type(self)):
                    return b
        return None

    def get_intrahelical_below(self, all_types=True):
        """ Returns bead directly below self (lower contour position), or None

        all_types -- if False, only neighbours of the same Python type as self qualify
        """
        # assert( len(self.intrahelical_neighbors) <= 2 )
        for b in self.intrahelical_neighbors:
            if b.get_contour_position(self.parent, self.contour_position) < self.contour_position:
                if all_types or isinstance(b,type(self)):
                    return b
        return None

    def _neighbor_should_be_added(self,b):
        """ Decide whether ``b`` should become an intrahelical neighbour of self

        Beads whose parents have different types are always accepted. Otherwise
        ``b`` is accepted if self has no neighbour on that side, or if ``b`` is
        closer than the existing one -- in which case the existing neighbour is
        unlinked from self as a side effect.
        """
        if type(self.parent) != type(b.parent):
            return True

        c1 = self.contour_position
        c2 = b.get_contour_position(self.parent,c1)

        if c2 < c1:
            b0 = self.get_intrahelical_below()
        else:
            b0 = self.get_intrahelical_above()

        if b0 is not None:
            c0 = b0.get_contour_position(self.parent,c1)
            if np.abs(c2-c1) < np.abs(c0-c1):
                ## remove b0
                self.intrahelical_neighbors.remove(b0)
                b0.intrahelical_neighbors.remove(self)
                return True
            else:
                return False
        return True

    def make_intrahelical_neighbor(self,b):
        """ Link self and ``b`` as intrahelical neighbours if both sides accept the link """
        # Both checks run (and may unlink farther neighbours) before deciding
        add1 = self._neighbor_should_be_added(b)
        add2 = b._neighbor_should_be_added(self)
        if add1 and add2:
            # assert(len(b.intrahelical_neighbors) <= 1)
            # assert(len(self.intrahelical_neighbors) <= 1)
            self.intrahelical_neighbors.append(b)
            b.intrahelical_neighbors.append(self)

    def conceptual_get_position(self, context):
        """ Design sketch only -- not implemented; always returns None

        Open questions:
          - does this function do too much? (danger of mixing return values)
          - does ``context`` describe the system or present an argument?

        ``context`` was intended to specify:
          - kind of output: real space, nt within segment, fraction of segment
          - absolute or relative
          - constraints: e.g. if passing through
        """
        return None

    def get_nt_position(self, seg, near_address=None):
        """ Returns the "address" of the nucleotide relative to seg in
        nucleotides, taking the shortest (intrahelical) contour length route to seg
        """
        if seg == self.parent:
            pos = self.contour_position
        else:
            pos = self.get_contour_position(seg,near_address)
        return seg.contour_to_nt_pos(pos)

    def get_contour_position(self,seg, address = None):
        """ Return this bead's contour position expressed in segment ``seg``

        If ``seg`` is the bead's parent, ``contour_position`` is returned
        directly. Otherwise a depth-first search walks intrahelical
        connections (plus sscrossover connections leading into ``seg`` when
        ``seg`` is a SingleStrandedSegment), accumulating the nucleotide
        distance travelled. Branches longer than a cutoff (initially 90 nt,
        lowered whenever ``seg`` is reached) are abandoned. Of the candidate
        positions found, the one reached over the shortest distance is
        returned; ties are broken by closeness to ``address`` when given.

        Raises Exception if ``seg`` cannot be reached.

        TODO: fix paradigm where a bead maps to exactly one location in a polymer
        - One way: modify get_contour_position to take an optional argument that indicates where in the polymer you are looking from
        """
        if seg == self.parent:
            return self.contour_position

        cutoff = 30*3
        target_seg = seg

        ## depth-first search
        ## TODO cache distances to nearby locations?
        def descend_search_tree(seg, contour_in_seg, distance=0, visited_segs=None):
            nonlocal cutoff
            if visited_segs is None: visited_segs = []

            if seg == target_seg:
                ## Found a segment in our target
                sign = 1 if contour_in_seg == 1 else -1
                if sign == -1: assert( contour_in_seg == 0 )
                if distance < cutoff: # TODO: check if this does anything
                    cutoff = distance
                return [[distance, contour_in_seg+sign*seg.nt_pos_to_contour(distance)]], [(seg, contour_in_seg, distance)]
            if distance > cutoff:
                return None,None

            ret_list = []
            hist_list = []
            ## Find intrahelical locations in seg that we might pass through
            conn_locs = seg.get_connections_and_locations("intrahelical")
            if isinstance(target_seg, SingleStrandedSegment):
                tmp = seg.get_connections_and_locations("sscrossover")
                conn_locs = conn_locs + [x for x in tmp if x[2].container == target_seg]
            for c,A,B in conn_locs:
                if B.container in visited_segs: continue
                dx = seg.contour_to_nt_pos( A.address, round_nt=False ) - seg.contour_to_nt_pos( contour_in_seg, round_nt=False)
                dx = np.abs(dx)
                results,history = descend_search_tree( B.container, B.address,
                                                       distance+dx, visited_segs + [seg] )
                if results is not None:
                    ret_list.extend( results )
                    hist_list.extend( history )
            return ret_list,hist_list

        results,history = descend_search_tree(self.parent, self.contour_position)
        if results is None or len(results) == 0:
            raise Exception("Could not find location in segment") # TODO better error

        # min() keeps the first of tied entries, exactly like sorted(...)[0]
        if address is not None:
            return min(results, key=lambda x: (x[0], (x[1]-address)**2))[1]
        return min(results, key=lambda x: x[0])[1]

    def combine(self, other):
        """ Merge intrahelical neighbour ``other`` into self and remove it from its parent

        Other's remaining intrahelical neighbours and its locations are
        transferred to self; other (and its orientation bead, if any) is
        removed from the parent's children/beads, along with any parent bond
        involving other.
        """
        assert(other in self.intrahelical_neighbors)
        assert(self in other.intrahelical_neighbors)

        self.intrahelical_neighbors.remove(other)
        other.intrahelical_neighbors.remove(self)
        for b in other.intrahelical_neighbors:
            b.intrahelical_neighbors.remove(other)
            b.intrahelical_neighbors.append(self)
            self.intrahelical_neighbors.append(b)
        for l in other.locations:
            self.locations.append(l)
            l.particle = self

        ## Remove bead
        other.parent.children.remove(other)
        if other in other.parent.beads:
            other.parent.beads.remove(other)
        if 'orientation_bead' in other.__dict__:
            other.parent.children.remove(other.orientation_bead)

        for b in list(other.parent.bonds):
            if other in b[:2]: other.parent.bonds.remove(b)

    def update_position(self, contour_position):
        """ Move the bead (and its orientation bead, if any) to ``contour_position`` in its parent """
        self.contour_position = contour_position
        self.position = self.parent.contour_to_position(contour_position)
        if 'orientation_bead' in self.__dict__:
            o = self.orientation_bead
            o.contour_position = contour_position
            orientation = self.parent.contour_to_orientation(contour_position)
            if orientation is None:
                logger.warning("local_twist is True, but orientation is None; using identity")
                orientation = np.eye(3)
            # Orientation bead sits one bond length along the local x axis
            o.position = self.position + orientation.dot( np.array((Segment.orientation_bond.r0,0,0)) )

    def __repr__(self):
        return "<SegmentParticle {} on {}[{:.2f}]>".format( self.name, self.parent, self.contour_position)

411
412

## TODO break this class into smaller, better encapsulated pieces
413
414
415
416
417
418
419
420
421
422
423
class Segment(ConnectableElement, Group):

    """ Base class that describes a segment of DNA. When built from
    cadnano models, should not span helices """

    """Define basic particle types"""
    dsDNA_particle = ParticleType("D",
                                  diffusivity = 43.5,
                                  mass = 300,
                                  radius = 3,                 
                              )
cmaffeo2's avatar
cmaffeo2 committed
424
425
426
427
428
    orientation_particle = ParticleType("O",
                                        diffusivity = 100,
                                        mass = 300,
                                        radius = 1,
                                    )
429

cmaffeo2's avatar
cmaffeo2 committed
430
    # orientation_bond = HarmonicBond(10,2)
431
    orientation_bond = HarmonicBond(30,1.5, rRange = (0,500) )
432
433
434
435
436
437
438

    ssDNA_particle = ParticleType("S",
                                  diffusivity = 43.5,
                                  mass = 150,
                                  radius = 3,                 
                              )

439
    def __init__(self, name, num_nt,
                 start_position = None,
                 end_position = None,
                 segment_model = None,
                 **kwargs):
        """ Create a straight segment of ``num_nt`` nucleotides

        start_position defaults to the origin. When end_position is omitted,
        the segment extends along +z by ``distance_per_nt * num_nt``.
        Remaining keyword arguments are forwarded to Group.
        """
        if start_position is None:
            start_position = np.array((0,0,0))

        Group.__init__(self, name, children=[], **kwargs)
        ConnectableElement.__init__(self, connection_locations=[], connections=[])

        if 'segname' not in kwargs:
            self.segname = name

        self.start_orientation = None
        self.twist_per_nt = 0

        # Copy of the children list; orientation beads are not tracked here
        self.beads = list(self.children)

        self._bead_model_generation = 0    # TODO: remove?
        self.segment_model = segment_model # TODO: remove?

        self.strand_pieces = {'fwd': [], 'rev': []}

        self.num_nt = int(num_nt)
        if end_position is None:
            end_position = np.array((0,0,self.distance_per_nt*num_nt)) + start_position
        self.start_position = start_position
        self.end_position = end_position

        # Callbacks used to assign cadnano names to beads
        self._generate_bead_callbacks = []
        self._generate_nucleotide_callbacks = []

        # Interpolate positions linearly between the two ends
        self._set_splines_from_ends()

        self.sequence = None

480
    def __repr__(self):
        """ Return e.g. ``<DoubleStrandedSegment helix1[42]>`` using the concrete class name """
        # str(type(self)).split('.')[-1] left the "'>" of "<class '...'>" in the output
        return "<{} {}[{:d}]>".format( type(self).__name__, self.name, self.num_nt )
482

cmaffeo2's avatar
cmaffeo2 committed
483
    def set_splines(self, contours, coords):
        """ Fit a linear (k=1) interpolating position spline through ``coords`` at ``contours`` """
        knots_coeffs, params = interpolate.splprep( coords.T, u=contours, s=0, k=1)
        self.position_spline_params = (knots_coeffs, params)
486

487
488
    def set_orientation_splines(self, contours, quaternions):
        """ Fit a linear (k=1) interpolating spline through ``quaternions`` at ``contours`` """
        knots_coeffs, params = interpolate.splprep( quaternions.T, u=contours, s=0, k=1)
        self.quaternion_spline_params = (knots_coeffs, params)

    def get_center(self):
        """ Return the mean of the positions at the spline's sample contours """
        sample_contours = self.position_spline_params[1]
        sample_positions = self.contour_to_position(sample_contours)
        return np.mean(sample_positions, axis=0)

495
496
497
498
    def get_bounding_box( self, num_points=3 ):
        """ Return (min_corner, max_corner) of ``num_points`` evenly spaced positions along the segment """
        samples = np.zeros( (num_points, 3) )
        for idx, contour in enumerate(np.linspace(0,1,num_points)):
            samples[idx] = self.contour_to_position(contour)
        return samples.min(axis=0), samples.max(axis=0)

505
506
507
    def _get_location_positions(self):
        """ Return the (unrounded) nucleotide position of every location, in list order """
        to_nt = self.contour_to_nt_pos
        return [to_nt(loc.address) for loc in self.locations]

cmaffeo2's avatar
cmaffeo2 committed
508
    def insert_dna(self, at_nt: int, num_nt: int, seq=tuple()):
        """ Insert ``num_nt`` nucleotides after nucleotide ``at_nt``

        Locations at nucleotide positions <= at_nt keep their nucleotide
        index; those beyond are shifted by ``num_nt``. The position and
        orientation splines are left unchanged. ``seq`` is currently ignored.

        Raises ValueError if at_nt is negative or past the last nucleotide,
        or if num_nt is negative.
        """
        assert(np.isclose(np.around(num_nt),num_nt))
        if at_nt < 0:
            raise ValueError("Attempted to insert DNA into {} at a negative location".format(self))
        if at_nt > self.num_nt-1:
            raise ValueError("Attempted to insert DNA into {} at beyond the end of the Segment".format(self))
        if num_nt < 0:
            raise ValueError("Attempted to insert a negative amount of DNA into {}".format(self))

        # Keep num_nt a Python int (np.around returns a float, which broke "{:d}" formatting)
        num_nt = int(np.around(num_nt))
        nt_positions = self._get_location_positions()
        new_nt_positions = [p if p <= at_nt else p+num_nt for p in nt_positions]

        ## TODO: handle sequence

        # Update the length first so nt_pos_to_contour maps onto the longer segment
        self.num_nt = self.num_nt+num_nt

        for l,p in zip(self.locations, new_nt_positions):
            l.address = self.nt_pos_to_contour(p)

528
    def remove_dna(self, first_nt:int, last_nt:int, remove_locations:bool = False):
        """ Removes nucleotides between first_nt and last_nt, inclusive

        The bounds may be given in either order. Nothing happens when they
        are equal. Locations inside the removed range cause an Exception
        unless ``remove_locations`` is True, in which case they are deleted
        (together with their connections). Removing from the first or last
        nucleotide also drops locations lying beyond that end. Surviving
        location addresses, the sequence (when its length equals num_nt) and
        the position/orientation splines are remapped onto the shorter
        segment.

        Raises ValueError for out-of-range bounds.
        """
        assert(np.isclose(np.around(first_nt),first_nt))
        assert(np.isclose(np.around(last_nt),last_nt))
        # Accept the bounds in either order (previously the minimum was assigned
        # to a misspelled "fist_nt", so reversed bounds silently removed nothing)
        first_nt, last_nt = min((first_nt,last_nt)), max((first_nt,last_nt))

        if first_nt < 0 or first_nt > self.num_nt-2:
            raise ValueError("Attempted to remove DNA from {} starting at an invalid location {}".format(self, first_nt))
        if last_nt < 1 or last_nt > self.num_nt-1:
            raise ValueError("Attempted to remove DNA from {} ending at an invalid location {}".format(self, last_nt))
        if first_nt == last_nt:
            return

        first_nt = int(np.around(first_nt))
        last_nt = int(np.around(last_nt))
        removed_nt = last_nt-first_nt+1
        num_nt = self.num_nt-removed_nt  # new length; self.num_nt stays old until the end

        def _transform_contour_positions(contours):
            """ Map old contour values to new ones; removed positions become NaN """
            nt = self.contour_to_nt_pos(contours)
            # Open-ended range when removing from either end of the segment
            first = first_nt if first_nt > 0 else -np.inf
            last = last_nt if last_nt < self.num_nt-1 else np.inf
            bad_ids = (nt >= first) * (nt <= last)
            nt[nt>last] = nt[nt>last]-removed_nt
            nt[bad_ids] = np.nan
            # (nt+0.5)/old_num_nt * old_num_nt/new_num_nt == new contour
            return self.nt_pos_to_contour(nt)* self.num_nt/num_nt

        c = np.array([l.address for l in self.locations])
        new_c = _transform_contour_positions(c)

        if np.any(np.isnan(new_c)) and not remove_locations:
            raise Exception("Attempted to remove DNA containing locations from {} between {} and {}".format(self,first_nt,last_nt))

        for l,v in zip(list(self.locations),new_c):
            if np.isnan(v):
                l.delete()
            else:
                l.address = v

        assert( len(self.locations) == np.logical_not(np.isnan(new_c)).sum() )
        tmp_nt = np.array([l.address*num_nt+0.5 for l in self.locations])
        assert(np.all( (tmp_nt < 1) + (tmp_nt > num_nt-1) + np.isclose((np.round(tmp_nt) - tmp_nt), 0) ))

        if self.sequence is not None and len(self.sequence) == self.num_nt:
            self.sequence = [s for i,s in enumerate(self.sequence)
                                if i < first_nt or i > last_nt]
            assert( len(self.sequence) == num_nt )

        ## Rescale splines
        def _rescale_splines(u, get_coord_function, set_spline_function):
            """ Refit a spline on the surviving samples at their remapped contours """
            new_u = _transform_contour_positions(u)
            ids = np.logical_not(np.isnan(new_u))
            new_u = new_u[ids]
            new_p = get_coord_function(u[ids])
            set_spline_function( new_u, new_p )

        _rescale_splines(self.position_spline_params[1],
                         self.contour_to_position,
                         self.set_splines)

        if self.quaternion_spline_params is not None:
            def _contour_to_quaternion(contours):
                os = [self.contour_to_orientation(c) for c in contours]
                qs = [quaternion_from_matrix( o ) for o in os]
                return np.array(qs)
            _rescale_splines(self.quaternion_spline_params[1],
                             _contour_to_quaternion,
                             self.set_orientation_splines)

        self.num_nt = num_nt
601

602
603
604
605
    def delete(self):
        """ Delete every connection of this segment, then remove it from its parent's segment list """
        for connection, _, _ in self.get_connections_and_locations():
            connection.delete()
        self.parent.segments.remove(self)
606

607
608
609
610
611
612
613
614
615
616
617
618
    @staticmethod
    def __filter_contours(contours, positions, position_filter, contour_filter):
        """ Return indices of samples passing both optional filters

        contours        -- sequence of contour values
        positions       -- array indexed as positions[i,:] for sample i
        contour_filter  -- callable(contour) -> bool, or None to accept all
        position_filter -- callable(position row) -> bool, or None to accept all

        Declared static: it has no ``self`` and is called as Segment.__filter_contours(...).
        """
        return [i for i in range(len(contours))
                if (contour_filter is None or contour_filter(contours[i]))
                and (position_filter is None or position_filter(positions[i,:]))]

619
    def translate(self, translation_vector, position_filter=None, contour_filter=None):
        """ Shift the position-spline samples selected by the filters by ``translation_vector``

        Does nothing if no sample passes the filters. Orientation splines are not touched.
        """
        shift = np.array(translation_vector)
        _, contours = self.position_spline_params
        coords = self.contour_to_position(contours)

        selected = Segment.__filter_contours(contours, coords, position_filter, contour_filter)
        if not selected:
            return

        coords[selected,:] = coords[selected,:] + shift[np.newaxis,:]
        self.set_splines(contours, coords)
630
631
632
633
634
635

    def rotate(self, rotation_matrix, about=None, position_filter=None, contour_filter=None):
        """ Rotate the parts of the segment selected by the filters

        rotation_matrix -- 3x3 array applied to positions (about ``about`` if
                           given, otherwise about the origin) and to orientations
        position_filter, contour_filter -- optional predicates choosing samples

        Does nothing if no position sample passes the filters. Orientation
        samples are selected with the same filters, evaluated at the
        quaternion spline's own contours against the un-rotated positions
        (previously the position-sample indices were reused, which is wrong
        whenever the two splines are sampled differently).
        """
        _, u = self.position_spline_params
        r = self.contour_to_position(u)

        ids = Segment.__filter_contours(u, r, position_filter, contour_filter)
        if len(ids) == 0: return

        # Select orientation samples before the positions change
        q_ids = []
        if self.quaternion_spline_params is not None:
            _, qu = self.quaternion_spline_params
            q_ids = Segment.__filter_contours(qu, self.contour_to_position(qu),
                                              position_filter, contour_filter)

        # Row vectors: R.dot(x) for every row x equals rows.dot(R.T)
        if about is None:
            r[ids,:] = r[ids,:].dot(rotation_matrix.T)
        else:
            dr = np.array(about)
            r[ids,:] = (r[ids,:] - dr).dot(rotation_matrix.T) + dr

        self.set_splines(u,r)

        if len(q_ids) > 0:
            ## TODO: performance: don't shift between quaternion and matrix representations so much
            _, qu = self.quaternion_spline_params
            orientations = np.array([self.contour_to_orientation(v) for v in qu])
            for i in q_ids:
                orientations[i] = rotation_matrix.dot(orientations[i])
            quats = np.array([quaternion_from_matrix(o) for o in orientations])
            self.set_orientation_splines(qu, quats)
656

657
    def _set_splines_from_ends(self, resolution=4):
        """ Reset splines to a straight line from start_position to end_position

        Uses one sample per ``resolution`` nucleotides (at least 3) and clears
        any orientation spline.
        """
        self.quaternion_spline_params = None

        start = np.array(self.start_position)[np.newaxis,:]
        end = np.array(self.end_position)[np.newaxis,:]
        fractions = np.linspace(0,1, max(3,self.num_nt//int(resolution)))
        weights = fractions[:,np.newaxis]
        self.set_splines(fractions, (1-weights)*start + weights*end)
665

666
667
668
    def clear_all(self):
        """ Clear the Group contents and bead list, and detach particles from all locations """
        Group.clear_all(self)  # TODO: use super?
        self.beads = []
        for loc in self.locations:
            loc.particle = None
674

675
    def contour_to_nt_pos(self, contour_pos, round_nt=False):
        """ Convert contour fraction(s) to nucleotide position(s): c*num_nt - 0.5

        With ``round_nt`` the result must already lie (numerically) on whole
        nucleotides; it is then rounded. Raises AssertionError otherwise.
        Works for scalars and arrays.
        """
        nt = contour_pos*(self.num_nt) - 0.5
        if round_nt:
            # np.all: a bare assert on a multi-element array raised ValueError
            assert( np.all(np.isclose(np.around(nt),nt)) )
            nt = np.around(nt)
        return nt

682
    def nt_pos_to_contour(self,nt_pos):
        """ Inverse of contour_to_nt_pos: nucleotide position -> (nt+0.5)/num_nt """
        shifted = nt_pos + 0.5
        return shifted/(self.num_nt)
684

685
    def contour_to_position(self,s):
        """ Evaluate the position spline at contour value(s) ``s``

        Multi-component results are returned as an array with one row per contour value.
        """
        coords = interpolate.splev( s, self.position_spline_params[0] )
        if len(coords) <= 1:
            return coords
        return np.array(coords).T

    def contour_to_tangent(self, s):
        """Return the unit tangent of the position spline at contour position(s) s.

        Shape (dim,) for scalar s, (len(s), dim) for an array of positions.
        """
        deriv = interpolate.splev( s, self.position_spline_params[0], der=1 )
        unit = deriv / np.linalg.norm(deriv, axis=0)
        return unit.T
694
695
696
        

    def contour_to_orientation(self,s):
        """Return the 3x3 orientation matrix at contour position s.

        Without a quaternion spline, the frame is derived from the position
        spline.  A base rotation is determined by the angle between the unit
        tangent and +z, and is composed with self.start_orientation when that
        is set.  A twist about the tangent of twist_per_nt per nucleotide of
        contour_to_nt_pos(s) is then applied; angles are passed to
        rotationAboutAxis in degrees.  With a quaternion spline, the quaternion
        is interpolated at s and converted to a matrix.

        s must be a scalar (or length-1 sequence); there is no vectorized form.
        """
        assert( isinstance(s,float) or isinstance(s,int) or len(s) == 1 )   # TODO make vectorized version

        if self.quaternion_spline_params is None:
            axis = self.contour_to_tangent(s)
            axis = axis / np.linalg.norm(axis)
            rotAxis = np.cross(axis,np.array((0,0,1)))
            rotAxisL = np.linalg.norm(rotAxis)
            zAxis = np.array((0,0,1))

            if rotAxisL > 0.001:
                ## |axis x z| <= 1 mathematically, but rounding in the normalization
                ## above can push it a few ulp past 1 (e.g. for tangents lying in
                ## the xy-plane), where arcsin returns NaN; clamp before arcsin.
                theta = np.arcsin(min(rotAxisL, 1.0)) * 180/np.pi
                if axis.dot(zAxis) < 0: theta = 180-theta
                orientation0 = rotationAboutAxis( rotAxis/rotAxisL, theta, normalizeAxis=False ).T
            else:
                ## tangent (anti)parallel to z: identity, or a 180-degree flip about x
                orientation0 = np.eye(3) if axis.dot(zAxis) > 0 else \
                               rotationAboutAxis( np.array((1,0,0)), 180, normalizeAxis=False )
            if self.start_orientation is not None:
                orientation0 = orientation0.dot(self.start_orientation)

            orientation = rotationAboutAxis( axis, self.twist_per_nt*self.contour_to_nt_pos(s), normalizeAxis=False )
            orientation = orientation.dot(orientation0)
        else:
            q = interpolate.splev( s, self.quaternion_spline_params[0] )
            if len(q) > 1: q = np.array(q).T # TODO: is this needed?
            orientation = quaternion_to_matrix(q)

        return orientation
724

725
726
727
728
729
730
731
    def _ntpos_to_seg_and_ntpos(self, nt_pos, is_fwd=True, visited_segs=tuple()):
        """Resolve nt_pos to a (segment, nt_pos, is_fwd) tuple, crossing intrahelical connections.

        If nt_pos lies within this segment, (self, nt_pos, is_fwd) is returned.
        Otherwise the first (nt_pos < 0) or last (nt_pos >= num_nt)
        intrahelical connection, sorted by contour address, is crossed.  The
        lookup then recurses into the connected segment, with nt_pos converted
        to that segment's numbering and is_fwd flipped when both connection
        locations share the same address.

        Returns None when there is no intrahelical connection, when the
        connection is not at the relevant end of this segment, or when the
        connected segment was already visited.

        TODO: This function could perhaps replace parts of SegmentParticle.get_contour_position
        """
        if nt_pos >= 0 and nt_pos < self.num_nt:
            return (self,nt_pos,is_fwd)

        try:
            c,A,B = self.get_contour_sorted_connections_and_locations(type_='intrahelical')[0 if nt_pos < 0 else -1]
        except IndexError:
            ## No intrahelical connections at all: nowhere to go
            return None

        ## Special logic for circular segments
        assert(A.container == self)
        if A.container == B.container:
            if (nt_pos < 0 and A.address > B.address) or (nt_pos >= self.num_nt and A.address < B.address):
                A,B = (B,A)

        other_seg = B.container
        if A.address != (0 if nt_pos < 0 else 1) or other_seg in visited_segs:
            ## Could not find a path to nt_pos
            return None

        ## Cross the crossover
        other_pos = (nt_pos-self.num_nt) if nt_pos >= self.num_nt else -nt_pos-1
        if B.address > 0.5:
            other_pos = (other_seg.num_nt-1) - other_pos
        other_fwd = not is_fwd if A.address == B.address else is_fwd
        return other_seg._ntpos_to_seg_and_ntpos(other_pos, other_fwd,
                                                 list(visited_segs)+[self])


cmaffeo2's avatar
cmaffeo2 committed
759
    def get_contour_sorted_connections_and_locations(self, type_):
        """Return the type_ connections of this segment ordered by contour address.

        Entries come from get_connections_and_locations(type_); the sort key is
        the address of each entry's second element (the location on this segment).
        """
        connections = self.get_connections_and_locations(type_)
        return sorted(connections, key=lambda entry: entry[1].address)
763
764
765
    
    def randomize_unset_sequence(self):
        """Assign random bases wherever the sequence is undefined.

        Bases are drawn from the keys of seqComplement.  If no sequence is set,
        a fully random one of length num_nt is created.  Otherwise every None
        entry is replaced in place.
        """
        alphabet = list(seqComplement.keys())
        if self.sequence is None:
            self.sequence = [random.choice(alphabet) for _ in range(self.num_nt)]
            return

        assert len(self.sequence) == self.num_nt  # TODO move
        for idx, base in enumerate(self.sequence):
            if base is None:
                self.sequence[idx] = random.choice(alphabet)
774

cmaffeo2's avatar
cmaffeo2 committed
775
776
777
    def _get_num_beads(self, max_basepairs_per_bead, max_nucleotides_per_bead ):
        """Return how many beads this segment needs; subclasses must implement this."""
        raise NotImplementedError

778
    def _generate_one_bead(self, contour_position, nts):
        """Create a single bead at contour_position; subclasses must implement this."""
        raise NotImplementedError

cmaffeo2's avatar
cmaffeo2 committed
781
    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale, strand_segment):
        """Build one all-atom nucleotide at contour_position and add it to strand_segment.

        seq is the key into the canonical nucleotide templates and should
        include terminal modifications (e.g. 5T, T3, Tsinglet).  is_fwd selects
        the forward or reverse template set.  The template's orientation is
        pre-multiplied by this segment's orientation at contour_position.

        When scale is given and differs from 1:
          * on single-stranded segments every atom position is scaled, and the
            group is positioned at the contour point offset by its C1' atom's
            collapsedPosition();
          * otherwise atoms whose names end in "'", "P" or "T" are scaled about
            the N9 (A/G) or N1 (other bases) atom, and all remaining atoms get
            fixed = 1.
        Every callback in self._generate_nucleotide_callbacks is then invoked
        on the result.

        Returns:
            the generated nucleotide atom group.
        """
        pos = self.contour_to_position(contour_position)
        orientation = self.contour_to_orientation(contour_position)

        templates = canonicalNtFwd if is_fwd else canonicalNtRev
        atoms = templates[seq].generate()  # TODO: clone?
        atoms.orientation = orientation.dot(atoms.orientation)

        if isinstance(self, SingleStrandedSegment):
            if scale is not None and scale != 1:
                for atom in atoms:
                    atom.position = scale*atom.position
            atoms.position = pos - atoms.atoms_by_name["C1'"].collapsedPosition()
        else:
            if scale is not None and scale != 1:
                anchor_name = "N9" if atoms.sequence in ("A","G") else "N1"
                r0 = atoms.atoms_by_name[anchor_name].position
                for atom in atoms:
                    if atom.name[-1] not in ("'","P","T"):
                        atom.fixed = 1
                        continue
                    atom.position = scale*(atom.position-r0) + r0
            atoms.position = pos

        atoms.contour_position = contour_position
        strand_segment.add(atoms)

        for cb in self._generate_nucleotide_callbacks:
            cb(atoms)

        return atoms
837

838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
    def _generate_oxdna_nucleotide(self, contour_position, is_fwd, seq):
        """Create an "oxdna_nt" PointParticle for the nucleotide at contour_position.

        The particle's orientation is the segment frame times a fixed default
        rotation: 90 degrees about z, additionally flipped 180 degrees about x
        for the forward strand.  On single-stranded segments the particle sits
        on the contour.  Otherwise it is displaced by -5 along the first column
        of that orientation.  seq is accepted but not used here.
        """
        center = self.contour_to_position(contour_position)
        frame = self.contour_to_orientation(contour_position)

        base_frame = rotationAboutAxis([0,0,1], 90)
        if is_fwd:
            base_frame = rotationAboutAxis([1,0,0], 180).dot(base_frame)
        o = frame.dot(base_frame)

        pos = center if isinstance(self, SingleStrandedSegment) \
            else center - 5*o.dot(np.array((1,0,0)))

        nt = PointParticle("oxdna_nt", position=pos, orientation=o)
        nt.contour_position = contour_position
        return nt


860
861
    def add_location(self, nt, type_, on_fwd_strand=True):
        """Create a Location of kind type_ at nucleotide nt and register it on this segment.

        The nucleotide index is converted to a contour address, which must lie
        in [0, 1].
        """
        address = self.nt_pos_to_contour(nt)
        assert 0 <= address <= 1
        # TODO? loc = self.Location( address=c, type_=type_, on_fwd_strand=is_fwd )
        self.locations.append(
            Location( self, address=address, type_=type_, on_fwd_strand=on_fwd_strand )
        )

    ## TODO? Replace with abstract strand-based model?
869
870

    def add_nick(self, nt, on_fwd_strand=True):
        """Break the strand between nucleotides nt and nt+1.

        On the forward strand nt becomes a 3' end and nt+1 a 5' end; on the
        reverse strand the roles are swapped.  The end at nt is added first.
        """
        if on_fwd_strand:
            first, second = self.add_3prime, self.add_5prime
        else:
            first, second = self.add_5prime, self.add_3prime
        first(nt, on_fwd_strand)
        second(nt + 1, on_fwd_strand)

878

879
    def add_5prime(self, nt, on_fwd_strand=True):
        """Register a "5prime" location at nucleotide nt (single-stranded segments always use the forward strand)."""
        strand = True if isinstance(self, SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt, "5prime", strand)
883
884

    def add_3prime(self, nt, on_fwd_strand=True):
        """Register a "3prime" location at nucleotide nt (single-stranded segments always use the forward strand)."""
        strand = True if isinstance(self, SingleStrandedSegment) else on_fwd_strand
        self.add_location(nt, "3prime", strand)
888

889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
    ## Real work
    def _connect_ends(self, end1, end2, type_, force_connection=False):
        """Join an "end3" location and an "end5" location with a new Connection of kind type_.

        end1 and end2 must be Locations of opposite end types.  Any existing
        connection on either end is deleted first.  The connection is built as
        Connection(end3_loc, end5_loc) and attached through the "end3"
        location's container with in_3prime_direction=True.
        force_connection is accepted but not used here.
        """
        debug = False
        ## TODO remove self?
        ## validate the input
        for end in (end1, end2):
            assert isinstance(end, Location)
            assert end.type_ in ("end3", "end5")
        assert end1.type_ != end2.type_

        ## Remove other connections involving these points
        for end in (end1, end2):
            if end.connection is not None:
                if debug:
                    print("WARNING: reconnecting {}".format(end))
                end.connection.delete()

        ## Create and add connection, always oriented from the 3' end to the 5' end
        if end2.type_ == "end5":
            three, five = end1, end2
        else:
            three, five = end2, end1
        three.container._connect( five.container, Connection( three, five, type_=type_ ), in_3prime_direction=True )

913
    def get_3prime_locations(self):
        """Return this segment's "3prime" locations ordered by increasing contour address."""
        locs = list(self.get_locations("3prime"))
        locs.sort(key=lambda loc: loc.address)
        return locs
915
    
cmaffeo2's avatar
cmaffeo2 committed
916
    def get_5prime_locations(self):
        """Return this segment's "5prime" locations ordered by increasing contour address."""
        ## TODO? ensure that data is consistent before _build_model calls
        locs = list(self.get_locations("5prime"))
        locs.sort(key=lambda loc: loc.address)
        return locs
cmaffeo2's avatar
cmaffeo2 committed
919

920
921
922
923
924
925
926
927
928
929
930
    def get_blunt_DNA_ends(self):
        """Return the helical ends of this segment that have no intrahelical connection.

        An end is reported when no intrahelical connection sits at contour
        address 0 (start) or 1 (end).  Start ends are returned as
        (self.start5, self.start3, -1) and end ends as
        (self.end5, self.end3, 1).  Single-stranded segments always return [].
        """
        if isinstance(self, SingleStrandedSegment):
            return []
        addresses = [c[1].address for c in self.get_connections_and_locations("intrahelical")]
        blunt = []
        if not any(a == 0 for a in addresses):
            blunt.append((self.start5, self.start3, -1))
        if not any(a == 1 for a in addresses):
            blunt.append((self.end5, self.end3, 1))
        return blunt

931
    def iterate_connections_and_locations(self, reverse=False):
        """Yield this segment's connection entries ordered by contour address.

        Entries come from get_connections_and_locations() and are sorted by
        the address of their second element (the location on this segment);
        reverse=True yields them in the opposite order.

        The previous implementation called
        get_contour_sorted_connections_and_locations() without its required
        type_ argument, so iteration always raised TypeError; the sort is now
        done here directly.
        """
        ## connections to other segments, of every type
        ## NOTE(review): assumes get_connections_and_locations() accepts no
        ## argument and then returns all connection types -- confirm.
        cl = sorted(self.get_connections_and_locations(), key=lambda c: c[1].address)
        if reverse:
            cl = cl[::-1]
        for c in cl:
            yield c
cmaffeo2's avatar
cmaffeo2 committed
939

940
941
942
943
944
945
946
947
948
949
950
    ## TODO rename
    def _add_strand_piece(self, strand_piece):
        """Register a strand piece on this segment, keeping pieces sorted by start.

        Pieces are stored per direction in self.strand_pieces['fwd'] or
        self.strand_pieces['rev'] according to strand_piece.is_fwd.  The
        range [min(start, end), max(start, end)] of the new piece must not
        overlap (endpoints inclusive) that of any piece already registered for
        the same direction.

        Raises:
            Exception: if the new piece overlaps an existing one.
        """
        ## TODO use weakref
        d = 'fwd' if strand_piece.is_fwd else 'rev'

        ## Validate strand_piece (ensure no clashes).  Use a full interval-overlap
        ## test: checking only whether the new endpoints fall inside an existing
        ## piece misses a new piece that spans an existing one entirely.
        new_lo, new_hi = sorted((strand_piece.start, strand_piece.end))
        for s in self.strand_pieces[d]:
            l, h = sorted((s.start, s.end))
            if new_lo <= h and new_hi >= l:
                raise Exception("Strand piece already exists! DNA may be circular.")

        ## Add strand_piece in correct order
        self.strand_pieces[d].append(strand_piece)
        self.strand_pieces[d] = sorted(self.strand_pieces[d],
                                       key = lambda x: x.start)

959
    ## TODO rename
960
    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
961
        """ Walks through locations, checking for crossovers """
962
963
964
965
        # if self.name in ("6-1","1-1"):
        #     import pdb
        #     pdb.set_trace()
        move_at_least = 0
966
967

        ## Iterate through locations
cmaffeo2's avatar
cmaffeo2 committed
968
        # locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
969
970
971
972
973
974
        def loc_rank(l):
            nt = l.get_nt_pos()
            ## optionally add logic about type of connection
            return (nt, not l.on_fwd_strand)
        # locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
        locations = sorted(self.locations, key=loc_rank, reverse=(not is_fwd))
975
976
        # print(locations)

977
        for l in locations:
cmaffeo2's avatar
cmaffeo2 committed
978
979
980
981
            # TODOTODO probably okay
            if l.address == 0:
                pos = 0.0
            elif l.address == 1:
982
                pos = self.num_nt-1
cmaffeo2's avatar
cmaffeo2 committed
983
984
            else:
                pos = self.contour_to_nt_pos(l.address, round_nt=True)
985
986
987

            ## DEBUG

cmaffeo2's avatar
cmaffeo2 committed
988

989
            ## Skip locations encountered before our strand
990
991
992
993
994
995
996
997
            # tol = 0.1
            # if is_fwd:
            #     if pos-nt_pos <= tol: continue 
            # elif   nt_pos-pos <= tol: continue
            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
            ## TODO: remove move_at_least
            if np.isclose(pos,nt_pos):
                if l.is_3prime_side_of_connection: continue
998
999

            ## Stop if we found the 3prime end
1000
            if l.on_fwd_strand == is_fwd and l.type_ == "3prime" and l.connection is None:
1001
1002
                if _DEBUG_TRACE:
                    print("  found end at",l)
1003
                return pos, None, None, None, None
1004
1005
1006
1007
1008
1009
1010
1011

            ## Check location connections
            c = l.connection
            if c is None: continue
            B = c.other(l)            

            ## Found a location on the same strand?
            if l.on_fwd_strand == is_fwd:
1012
1013
1014
                if _DEBUG_TRACE:
                    print("  passing through",l)
                    print("from {}, connection {} to {}".format(nt_pos,l,B))