atomicModel.py 51.9 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
from datetime import datetime
from cadnano.cnenum import PointType
from cadnano.strand import Strand
from math import pi,sqrt,exp,floor
import numpy as np
import random
8
import os, sys, subprocess
cmaffeo2's avatar
cmaffeo2 committed
9
import pdb
10

cmaffeo2's avatar
cmaffeo2 committed
11
from coords import minimizeRmsd, rotationAboutAxis
12
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
13
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
14
15
16
17
18
19
20

class abstractNucleotide():
    """A single nucleotide with a position, orientation and backbone links.

    The orientation is ``rotationAboutAxis([0,0,1], angle)``.  Any extra
    keyword arguments are kept in ``self.prop``; ``atoms()`` requires
    ``self.prop['isFwd']``.
    """

    def __init__(self, sequence, position, angle, ntAt5prime=None, **prop):
        self.sequence = sequence
        self.prop = prop
        self.position = np.array(position)
        self.orientation = rotationAboutAxis([0,0,1], angle)

        # Backbone neighbours; the 3' side is filled in later via set3primeNt
        self.ntAt5prime = ntAt5prime
        self.ntAt3prime = None

        # Pairing / stacking partners, assigned after construction
        self.basepair = None
        self.stack5prime = None
        self.stack3prime = None

        # -1 means no atomic index has been assigned yet
        self.firstAtomicIndex = -1

    def set3primeNt(self, n):
        """Store ``n`` as the nucleotide bonded on this one's 3' side."""
        self.ntAt3prime = n

    def atoms(self, transform=True, scale=1.0):
        """Return the canonical atom template for this nucleotide.

        The template key is the sequence letter, prefixed with "5" for a 5'
        terminus, suffixed with "3" for a 3' terminus, or suffixed with
        "singlet" when the nucleotide has no backbone neighbours.  The
        forward or reverse template table is chosen from ``prop['isFwd']``.
        When ``transform`` is true, the template's ``transformed`` method is
        called with this nucleotide's orientation, position, ``scale`` and
        whether it lacks a basepair partner.
        """
        hasFivePrimeNeighbor = self.ntAt5prime is not None
        hasThreePrimeNeighbor = self.ntAt3prime is not None
        if hasFivePrimeNeighbor and hasThreePrimeNeighbor:
            key = self.sequence
        elif hasThreePrimeNeighbor:
            key = "5" + self.sequence
        elif hasFivePrimeNeighbor:
            key = self.sequence + "3"
        else:
            key = self.sequence + "singlet"

        templates = canonicalNtFwd if self.prop['isFwd'] else canonicalNtRev
        template = templates[key]
        if not transform:
            return template
        return template.transformed(self.orientation, self.position, scale,
                                    self.basepair is None)
            
53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
class RigidBodyNode():
    """Coarse-grained node on a helix, linked to the nodes above and below.

    On construction the node takes ``helix.model.numNodes`` as its ``idx``
    and increments that counter.  Vertical links live in ``nodeAbove`` /
    ``nodeBelow``; their separations (in bp) are stored in ``nodeAboveSep``
    / ``nodeBelowSep`` by ``addNodeAbove`` / ``addNodeBelow``.
    """

    def __init__(self, helix, pos, angle, type="dsDNA"):
        # NOTE: ``angle`` is accepted but not currently stored.
        self.helix    = helix
        self.position = pos
        self.type     = type
        self.nodeAbove = None
        self.nodeBelow = None
        self.xovers = []
        self.idx = helix.model.numNodes
        helix.model.numNodes = helix.model.numNodes + 1

    def addNodeAbove(self, node, separation):
        """Link ``node`` directly above this one, ``separation`` bp away."""
        self.nodeAbove = node
        self.nodeAboveSep = separation # bp

    def addNodeBelow(self, node, separation):
        """Link ``node`` directly below this one, ``separation`` bp away."""
        self.nodeBelow = node
        self.nodeBelowSep = separation # bp

    def addXover(self, node, polarity, double=False):
        """Record a crossover as a ``(node, polarity, double)`` tuple."""
        ## TODO: what is meant by polarity?
        self.xovers.append( (node,polarity,double) )

    def _walk(self, numNodes, inclusive, linkAttr, sepAttr):
        """Follow ``linkAttr`` links for at most ``numNodes`` steps.

        Each step's separation is read from ``sepAttr`` on the node being
        stepped *from*: ``addNodeAbove``/``addNodeBelow`` set that attribute
        together with the link that was followed, so it always exists.

        Returns:
            ``(nodeList, sepList)``; ``nodeList`` starts with ``self`` when
            ``inclusive`` is true.
        """
        assert( type(numNodes) is int and numNodes > 0 )
        nodeList = [self] if inclusive else []
        sepList = []
        current = self
        for _ in range(numNodes):
            nxt = getattr(current, linkAttr)
            if nxt is None:
                break
            sepList.append(getattr(current, sepAttr))
            nodeList.append(nxt)
            current = nxt
        return nodeList, sepList

    def getNodesAbove(self,numNodes,inclusive=False):
        """Return up to ``numNodes`` nodes above this one and their spacings.

        ``sepList[k]`` is the separation (bp) between the k-th returned node
        and the node visited just before it.
        """
        return self._walk(numNodes, inclusive, 'nodeAbove', 'nodeAboveSep')

    def getNodesBelow(self,numNodes,inclusive=False):
        """Return up to ``numNodes`` nodes below this one and their spacings.

        ``sepList[k]`` is the separation (bp) between the k-th returned node
        and the node visited just before it.  (Previously the separation to
        the node *after* it was recorded, raising AttributeError at the
        bottom of a helix.)
        """
        return self._walk(numNodes, inclusive, 'nodeBelow', 'nodeBelowSep')

class Segment():
    """Nucleotides belonging to one oligo, iterated in 5'->3' order.

    Nucleotides are stored as ``self.nodes[helixId][strandIdx] -> list``.
    Iteration starts from every nucleotide without a 5' neighbour and
    follows ``ntAt3prime`` links.
    """

    def __init__(self, segname, ntScale=1.0):
        self.ntScale = ntScale
        self.props = dict()
        self.props['segname'] = segname
        self.nodes = dict()
        self.firstNode = None

    def addNucleotide(self, helixId, strandIdx, zIdx,
                      sequence, isFwd, position, angle, ntAt5prime=None):
        """Create a nucleotide, store it, and link it after ``ntAt5prime``.

        Returns:
            The new ``abstractNucleotide``.
        Raises:
            Exception: if a nucleotide with the same ``zIdx`` and direction
                already exists at ``(helixId, strandIdx)``.
        """
        ntsAtIdx = self.nodes.setdefault(helixId, dict())
        ntsHere = ntsAtIdx.setdefault(strandIdx, [])
        for n in ntsHere:
            if n.prop['zIdx'] == zIdx and n.prop['isFwd'] == isFwd:
                raise Exception("Attempted to add a node in the same location (%d:%.1f:%s) twice!" % (helixId,zIdx,isFwd))

        n = abstractNucleotide(sequence, position, angle, ntAt5prime,
                               isFwd=isFwd, helixId=helixId, zIdx=zIdx, idx=strandIdx)
        if ntAt5prime is not None:
            ntAt5prime.set3primeNt(n)

        ntsHere.append( n ) # TODO: use strandIdx?
        return n

    def __len__(self):
        return sum(len(nts)
                   for ntsAtIdx in self.nodes.values()
                   for nts in ntsAtIdx.values())

    def __iter__(self):
        """Yield nucleotides chain by chain, following 3' links.

        A chain stops when it returns to its starting nucleotide, so a
        circular segment is visited once instead of looping forever.
        """
        for start in self._get5primeNts():
            n = start
            yield n
            while n.ntAt3prime is not None and n.ntAt3prime is not start:
                n = n.ntAt3prime
                yield n

    def atoms(self, transform=True):
        """Yield each nucleotide's atoms, scaled by ``self.ntScale``."""
        for nt in self:
            yield nt.atoms(transform, self.ntScale)

    def sequence(self):
        """Yield each nucleotide's sequence letter in iteration order."""
        for nt in self:
            yield nt.sequence

    def _get5primeNts(self):
        """Return the nucleotides that have no 5' neighbour.

        For a circular segment (no such nucleotide) a warning is printed to
        stderr and a single arbitrary nucleotide is returned as the start.
        An empty segment yields an empty list.
        """
        firstNodes = [nt
                      for ntsAtIdx in self.nodes.values()
                      for nts in ntsAtIdx.values()
                      for nt in nts
                      if nt.ntAt5prime is None]
        if firstNodes:
            return firstNodes

        for ntsAtIdx in self.nodes.values():
            for nts in ntsAtIdx.values():
                if nts:
                    print("WARNING: found circular segment; untested", file=sys.stderr)
                    return [nts[0]]
        return []

    ## TODO deprecate
    def _findFirstNode(self):
        """Return the single 5' starting nucleotide of this segment.

        Raises:
            Exception: if the segment has more than one 5' end.
        """
        firstNodes = self._get5primeNts()
        if len(firstNodes) > 1:
            raise Exception("Segment object contains two 5' ends!")
        return firstNodes[0]
    
cmaffeo2's avatar
cmaffeo2 committed
202
class atomicModel():
203

cmaffeo2's avatar
cmaffeo2 committed
204
    def __init__(self, part, ntScale=1.0, randomSeed=1):
        """Build an atomic model from the cadnano ``part``.

        Args:
            part: cadnano part to convert.
            ntScale: scale factor handed to every Segment.
            randomSeed: passed to ``random.seed`` (the module-level RNG)
                before the lattice type is detected and the model is built.
        """
        # Per-helix bookkeeping filled in by _buildModel
        self.chunks = {}
        self.helixNtsFwd = {}
        self.helixNtsRev = {}

        self.numNodes = 0
        self.segments = []
        self.nodeTypeCounts = None
        self.ntScale = ntScale
        self.part = part

        # Topology containers
        self.bonds, self.angles, self.dihedrals = set(), set(), set()
        self._nbParamFiles = []

        random.seed(randomSeed)
        self.latticeType = atomicModel._determineLatticeType(part)
        self._buildModel(part)
233
234
235
236
237
238
239
240
241
242
243
244
245
246

    def _strandOccupancies(self, helixId):
        try:
            ends1,ends2 = self._getStrandEnds(helixId)
        except:
            ## hid not in strand_list
            return []

        strandOccupancies = [ [x for i in range(0,len(e),2)
                               for x in range(e[i],e[i+1]+1)]
                              for e in (ends1,ends2) ]
        return sorted(list(set( strandOccupancies[0] + strandOccupancies[1] )))

    def _getStrandEnds(self, helixId):
247
248
249
250

        """Utility method to convert cadnano strand lists into list of
        indices of terminal points"""

251
252
        helixStrands = self.strand_list[helixId]

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
        endLists = [[],[]]
        for endList, strandList in zip(endLists,helixStrands):
            lastStrand = None
            for s in strandList:
                if lastStrand is None:
                    ## first strand
                    endList.append(s[0])
                elif lastStrand[1] != s[0]-1: 
                    assert( s[0] > lastStrand[1] )
                    endList.extend( [lastStrand[1], s[0]] )
                lastStrand = s
            if lastStrand is not None:
                endList.append(lastStrand[1])
        return endLists

    # -------------------------- #
    # Methods for building model #
    # -------------------------- #
    def _buildModel(self, part):
        """Create segments, nucleotides, basepairs and stacks from ``part``.

        Strategy:
        1) Create all oligos/segments, describing 5'/3' bonding & location
        2) Add basepairing information
        3) For all basepairs, find 5'/3' stacks

        Fills ``self.strand_list``, ``self.chunks``, ``self.helixNtsFwd`` and
        ``self.helixNtsRev`` (values are lists of nucleotides per cadnano
        index), and appends the new Segments to ``self.segments``.

        Raises:
            Exception: if the part uses ARBITRARY point configurations.
        """
        props = part.getModelProperties().copy()
        if props.get('point_type') == PointType.ARBITRARY:
            # TODO add code to encode Parts with ARBITRARY point configurations
            raise Exception("Not implemented")
        vh_props, origins = part.helixPropertiesAndOrigins()

        ## Build per-helix lists of strand indices; None marks an unused id
        self.strand_list = []
        xover_list = []  # filled by StrandSet.dump; not used further here
        numHID = part.getIdNumMax() + 1
        for id_num in range(numHID):
            if part.getOffsetAndSize(id_num) is None:
                # add a placeholder
                self.strand_list.append(None)
                continue
            fwd_ss, rev_ss = part.getStrandSets(id_num)
            fwd_idxs, _ = fwd_ss.dump(xover_list)
            rev_idxs, _ = rev_ss.dump(xover_list)
            self.strand_list.append((fwd_idxs, rev_idxs))

        for hid in range(numHID):
            self.helixNtsFwd[hid] = dict()
            self.helixNtsRev[hid] = dict()
            self.chunks[hid] = []

        ## 1) Loop over oligos/segments
        segments = []
        segid = 1
        keyFn = lambda o: (o.length(), o._strand5p.idNum(), o._strand5p.idx5Prime())
        for oligo in reversed(sorted(part.oligos(), key=keyFn)):
            seg = Segment("D%03d" % segid, self.ntScale)
            segments.append(seg)
            segid += 1

            lastNt = None
            ## loop over strands in oligo
            for strand in oligo.strand5p().generator3pStrand():
                hid = strand.idNum() # virtual helix ID
                x,y = origins[hid]

                keys = ['bases_per_repeat',
                        'turns_per_repeat',
                        'eulerZ','z']
                bpr,tpr,eulerZ,z = [vh_props[k][hid] for k in keys]

                ## override twist_per_base if square lattice
                if self.latticeType == "square":
                    twist_per_base = 360.0*3/32
                else:
                    twist_per_base = tpr*360./bpr
                    assert(np.isclose(twist_per_base, 360.0/10.5))

                isFwd = strand.isForward()
                lo,hi = strand.idxs()
                seqs = Strand.sequence(strand, for_export=True)
                numNt = strand.totalLength()

                ## We place nts as pairs; spacing is not always uniform along strand.
                ## Here we find the locations between which spacing will be uniform.
                compIdxs = [comp.idxs() for comp in strand.getComplementStrands()]
                if compIdxs:
                    chunks = self.combineCompactRegionLists(
                        [[lo,hi]],
                        compIdxs,
                        intersect=True
                    )
                else:
                    ## Unpaired strand: space its nts uniformly end to end
                    chunks = [(lo,hi)]
                self.chunks[hid].extend( chunks ) # Add running list of chunks for push bonds

                ## Find zIdx for each nucleotide
                zIdxs = []
                for l,h in chunks:
                    ## Find the number of nts between l,h
                    nts = h-l+1
                    for ins in strand.insertionsOnStrand(l,h):
                        nts += ins.length()
                    if nts == 1:
                        ## A lone nt would divide by zero below; centre it
                        zIdxs.append( 0.5*(l+h) )
                    else:
                        for i in range(nts):
                            zIdxs.append( l + (h-l) * float(i)/(nts-1) )
                assert(len(zIdxs) == numNt)

                if isFwd:
                    helixNts = self.helixNtsFwd[hid]
                else:
                    helixNts = self.helixNtsRev[hid]
                    zIdxs = zIdxs[::-1]

                ## Find strandOccupancy index for each nucleotide
                strandOccIdx = []
                for l,h in chunks:
                    for i in range(l,h+1):
                        nts = 1
                        for ins in strand.insertionsOnStrand(i,i):
                            nts += ins.length()
                        helixNts[i] = []
                        strandOccIdx.extend( [i]*nts )
                if not isFwd:
                    strandOccIdx = strandOccIdx[::-1]

                ## Add nucleotides
                for s,zIdx,idx in zip(seqs,zIdxs,strandOccIdx):
                    p = (x*10, y*10, z-3.4*zIdx)
                    a = zIdx*twist_per_base + eulerZ
                    lastNt = seg.addNucleotide(hid, idx, zIdx, s, isFwd,
                                               p, a, lastNt)
                    helixNts[idx].append(lastNt)

        ## Correct the order of nts in helixNtsRev.  Reverse in place so the
        ## values stay lists; a reversed() iterator can only be consumed once.
        for ntsAtIdx in self.helixNtsRev.values():
            for nts in ntsAtIdx.values():
                nts.reverse()

        ## 2) Find basepairs and stacks
        for hid in range(numHID):
            fwdNts = self.helixNtsFwd[hid]
            revNts = self.helixNtsRev[hid]
            prevBasepair = None
            for i in self._strandOccupancies(hid):
                ## Check that strand index is dsDNA
                if not (i in fwdNts and i in revNts):
                    continue

                for nt1,nt2 in zip(fwdNts[i], revNts[i]):
                    self._pairBases(nt1,nt2)

                    if prevBasepair is not None and i - prevBasepair[0] <= 3: # TODO: find better way of locating stacks
                        atomicModel._stackBases( prevBasepair[1], nt1)
                        atomicModel._stackBases( nt2, prevBasepair[2])
                    prevBasepair = (i,nt1,nt2)

        self.segments.extend(segments)

    def _connectNodes(self, below, above, sep):
        below.addNodeAbove(above, sep)
        above.addNodeBelow(below, sep)
        
458
    def _getNeighborHelixDict(part):
cmaffeo2's avatar
cmaffeo2 committed
459
        props, origins = part.helixPropertiesAndOrigins()
460
461
462
463
        neighborHelices = dict()
        for i in range(len(origins)):
            neighborHelices[i] = [int(j.strip('[,]')) for j in props['neighbors'][i].split() if j != '[]']
        return neighborHelices
cmaffeo2's avatar
cmaffeo2 committed
464

465
466
467
    def _determineLatticeType(part):
        props, origins = part.helixPropertiesAndOrigins()
        origins = [np.array(o) for o in origins]
cmaffeo2's avatar
cmaffeo2 committed
468
        neighborVecs = []
469
        for i,js in atomicModel._getNeighborHelixDict(part).items():
cmaffeo2's avatar
cmaffeo2 committed
470
471
472
473
474
475
            for j in js:
                neighborVecs.append( origins[j]-origins[i] )

        dots = [nv2.dot(nv1)/(np.linalg.norm(nv1)*np.linalg.norm(nv2)) for nv1 in neighborVecs for nv2 in neighborVecs]

        nDots = len(dots)
476
        nAngled = np.sum( np.abs(np.abs(dots)-0.5) < 0.2 )
cmaffeo2's avatar
cmaffeo2 committed
477
478
479
480
481
482
483
484
485
        if float(nAngled)/float(nDots) == 0:
            return 'square'
        elif float(nAngled)/float(nDots) < 0.05:
            print('WARNING: unusual neighbors in square lattice structure')
            return 'square'
        elif float(nAngled)/float(nDots) > 0.4:
            return 'honeycomb'
        else:
            raise Exception("Could not identify lattice type")
cmaffeo2's avatar
cmaffeo2 committed
486
           
487
488
489
    def backmap(self, simplerModel, simplerModelCoords,
                dsDnaHelixNeighborDist=50, dsDnaAllNeighborDist=30,
                ssDnaHelixNeighborDist=20, ssDnaAllNeighborDist=15):
        """Move nucleotides to follow a deformed coarse-grained model.

        Each nucleotide is assigned to a bead of ``simplerModel`` on the same
        helix: the first bead whose zIdx is >= the nucleotide's zIdx, or the
        last bead if none is.  A rotation ``R`` is estimated per bead:

        * for ``beadModelTwist`` beads with an ``orientationNode``, by
          aligning frames built from the orientation bead and the beads
          above/below;
        * otherwise with ``minimizeRmsd`` over the bead neighbourhood from
          ``simplerModel._getNeighborhoodIds`` (ds/ss cutoffs below), with an
          extra rotation for ss beads so the rotated local axis matches the
          new one.

        Each nucleotide position ``p`` then becomes
        ``(p - r0).dot(R) + simplerModelCoords[idx]``, where ``idx`` is the
        bead index and ``r0`` is
        ``simplerModel.particles[idx][0].initialPosition``; orientations are
        right-multiplied by ``R``.

        Raises:
            Exception: if a nucleotide's helix is missing from
                ``simplerModel.helices``.
        """
        nts = [nt for seg in self.segments for nt in seg]

        ## Assign each nucleotide to a bead in simplerModel
        mapToSimplerModel = dict()
        for i, nt in enumerate(nts):
            hid = nt.prop['helixId']
            if hid not in simplerModel.helices:
                raise Exception("Could not find helix %d in simple model; using just a single cadnano structure" % hid)
            cgH = simplerModel.helices[hid]
            zIdxs = np.array( sorted([z for z, _ in cgH]) )

            pos = np.searchsorted(zIdxs, nt.prop['zIdx'], side='left')
            cgZ = zIdxs[pos] if pos < len(zIdxs) else zIdxs[-1]
            mapToSimplerModel[i] = [cgH.nodes[cgZ]]

        ## Find transformation for each bead of simplerModel used in mapping
        trans = dict()
        for b in set(b for bs in mapToSimplerModel.values() for b in bs):
            isDs = b.type[0] == 'd'
            helixCutoff = dsDnaHelixNeighborDist if isDs else ssDnaHelixNeighborDist
            allCutoff = dsDnaAllNeighborDist if isDs else ssDnaAllNeighborDist

            r0 = b.initialPosition
            r1 = simplerModelCoords[b.idx]

            if type(simplerModel).__name__ == "beadModelTwist" and \
                    b.orientationNode is not None:
                ## Align frames (x: toward orientation bead, z: local axis)
                above = b.nodeAbove if b.nodeAbove is not None else b
                below = b.nodeBelow if b.nodeBelow is not None else b
                o = b.orientationNode
                assert(above != below)

                x0 = o.initialPosition - r0
                z0 = above.initialPosition - below.initialPosition
                x1 = simplerModelCoords[o.idx] - r1
                z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]
                x0,z0,x1,z1 = [x/np.linalg.norm(x) for x in (x0,z0,x1,z1)]
                y0,y1 = [np.cross(z,x) for x,z in zip((x0,x1),(z0,z1))]
                R,_,_ = minimizeRmsd( (x0,y0,z0), (x1,y1,z1) )

            else:
                ids = simplerModel._getNeighborhoodIds(b, simplerModelCoords, helixCutoff, allCutoff)
                posOld = np.array( [simplerModel.particles[i][0].initialPosition for i in ids] )
                posNew = np.array( [simplerModelCoords[i] for i in ids] )
                R,_,_ = minimizeRmsd( posOld, posNew )

                if b.type[0] == 's':
                    ## Rotate so the rotated local axis lines up with the new one
                    above = b.nodeAbove if b.nodeAbove is not None else b
                    below = b.nodeBelow if b.nodeBelow is not None else b
                    assert(above != below)
                    z0 = above.initialPosition - below.initialPosition
                    z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]

                    zR = z0.dot(R)
                    norms = np.linalg.norm(z1)*np.linalg.norm(zR)
                    cross = np.cross(z1,zR) / norms
                    cosine = np.dot(z1,zR) / norms
                    ## arctan2 recovers angles beyond 90 degrees and cannot
                    ## return NaN when rounding pushes |cross| above 1
                    angle = np.arctan2(np.linalg.norm(cross), cosine)*180/np.pi
                    R2 = rotationAboutAxis(cross,angle)
                    R = R.dot(R2)

            trans[b.idx] = (R,r0,r1)

        ## TODO: optionally smooth orientations

        ## Apply transformation to each nucleotide
        for i, nt in enumerate(nts):
            cgb, = mapToSimplerModel[i]
            cgi = cgb.idx
            r0 = simplerModel.particles[cgi][0].initialPosition
            R = trans[cgi][0]
            nt.position = (nt.position - r0).dot(R) + simplerModelCoords[cgi]
            nt.orientation = nt.orientation.dot( R )

        ## Sanity check: basepaired nucleotides must share an orientation
        for nt in nts:
            if nt.basepair is not None:
                assert( np.all( nt.orientation == nt.basepair.orientation ) )
                
578
579
580
581
582
583
584
585
586
587
    def _pairBases(self, n1,n2):
        assert(n1.basepair is None and n2.basepair is None)
        n1.basepair = n2
        n2.basepair = n1

    def _stackBases(below, above):
        assert(below.stack3prime is None and above.stack5prime is None)
        below.stack3prime = above
        above.stack5prime = below        

588
    def _removeIntrahelicalConnectionsBeyond(self, cutoff):
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
        ## lazy progarmming: leave nts in one segment after bond break
        nts = [nt for seg in self.segments for nt in seg]
        for n1 in nts:
            n2 = n1.ntAt3prime
            if n2 is not None:
                r2 = np.sum( (n1.position - n2.position)**2 )
                if r2 > cutoff**2:
                    if n1.ntAt3prime == n2:
                        assert(n2.ntAt5prime == n1)
                        n1.ntAt3prime = None
                        n2.ntAt5prime = None
                    elif n2.ntAt3prime == n1:
                        assert(n1.ntAt5prime == n2)
                        n1.ntAt5prime = None
                        n2.ntAt3prime = None
                    else:
                        raise

cmaffeo2's avatar
cmaffeo2 committed
607
608
609
610
    def scaffoldSeq(seqFile):
        return

    def randomizeUnassignedNtSeqs(self):
        """Give every unassigned ('?') nucleotide a base.

        If the basepair partner already has a base known to
        ``seqComplement``, the complement of that base is used so an
        assigned partner is never overwritten.  Otherwise a base is drawn
        with ``random.choice`` from the keys of ``seqComplement`` and the
        partner (if any) is given its complement.
        """
        choices = tuple(seqComplement.keys())
        for seg in self.segments:
            for nt in seg:
                if nt.sequence != '?':
                    continue
                partner = nt.basepair
                if partner is not None and partner.sequence != '?' \
                        and partner.sequence in seqComplement:
                    nt.sequence = seqComplement[partner.sequence]
                    continue
                nt.sequence = random.choice(choices)
                if partner is not None:
                    partner.sequence = seqComplement[nt.sequence]
617
618
619
620
621
622
623
624
625
626
627
628

    def setUnassignedNtSeqs(self,base):
        assert(base in ("A","T","C","G"))
        for seg in self.segments:
            for nt in seg:
                if nt.sequence == '?':
                    nt.sequence = base
                    if nt.basepair is not None:
                        nt.basepair.sequence = seqComplement[nt.sequence]

    def setScaffoldSeq(self,sequence):
        ## get longest segment
629
        segs = sorted( [(len(s),s) for s in self.segments], key=lambda x: x[0] )
630
631
        seg = segs[-1][1]
        assert( len(seg) > len(segs[-2][1]) )
632
        if len(seg) > len(sequence):
633
634
            raise Exception("Sequence (%d) too short for scaffold (%d)!" % (len(sequence),len(seg)))

635
636
637
638
639
        for nt,seq in zip(seg,sequence):
            assert(seq in ("A","T","C","G"))
            nt.sequence = seq
            if nt.basepair is not None:
                nt.basepair.sequence = seqComplement[nt.sequence]        
640
641
642
643
644
645

    def _assignAtomicIndexToNucleotides(self):
        natoms=0
        for seg in self.segments:
            for nt in seg:
                nt.firstAtomicIndex = natoms
646
                natoms += len(nt.atoms(transform=False))
647

648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
    def combineRegionLists(self,loHi1,loHi2):

        """Combines two lists of (lo,hi) pairs specifying integer
        regions a single list of regions.  """

        ## Validate input
        for l in (loHi1,loHi2):
            ## Assert each region in lists is sorted
            for pair in l:
                assert(len(pair) == 2)
                assert(pair[0] <= pair[1])

        ## Break input into lists of compact regions
        compactRegions1,compactRegions2 = [[],[]]
        for compactRegions,loHi in zip(
                [compactRegions1,compactRegions2],
                [loHi1,loHi2]):
            tmp = []
            lastHi = loHi[0][0]-1
            for lo,hi in loHi:
                if lo-1 != lastHi:
                    compactRegions.append(tmp)
                    tmp = []
                tmp.append((lo,hi))
                lastHi = hi
            if len(tmp) > 0:
                compactRegions.append(tmp)

        ## Build result
        result = []
        region = []
        i,j = [0,0]
        compactRegions1.append([[1e10]])
        compactRegions2.append([[1e10]])
        while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
            cr1 = compactRegions1[i]
            cr2 = compactRegions2[j]

            ## initialize region
            if len(region) == 0:
                if cr1[0][0] <= cr2[0][0]:
                    region = cr1
                    i += 1
                    continue
                else:
                    region = cr2
                    j += 1
                    continue

            if region[-1][-1] >= cr1[0][0]:
                region = self.combineCompactRegionLists(region, cr1, intersect=False)
                i+=1
            elif region[-1][-1] >= cr2[0][0]:
                region = self.combineCompactRegionLists(region, cr2, intersect=False)
                j+=1
            else:
                result.extend(region)
                region = []

        assert( len(region) > 0 )
        result.extend(region)
        result = sorted(result)

        # print("loHi1:",loHi1)
        # print("loHi2:",loHi2)
        # print(result,"\n")

        return result

    def combineCompactRegionLists(self,loHi1,loHi2,intersect=False):

        """Combines two lists of (lo,hi) pairs specifying regions within a
        compact integer set into a single list of regions.

        Each input list is a sorted sequence of inclusive ``[lo, hi]`` pairs
        with no gaps: every pair has ``lo <= hi`` and each pair starts one past
        the end of the previous pair.  The output spans from the smallest
        ``lo`` to the largest ``hi`` of the two inputs and is split wherever a
        region boundary occurs in either input.  Every input region is
        therefore a union of consecutive output regions.

        Parameters
        ----------
        loHi1, loHi2 : sequence of (lo, hi) pairs
            Compact region lists as described above (validated with asserts).
        intersect : bool
            If True, keep only the output regions lying inside the overlap of
            the two inputs' spans.

        Returns
        -------
        list of (lo, hi) tuples, sorted and contiguous.

        examples:
        loHi1 = [[0,4],[5,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 4), (5, 7), (8, 9)]

        loHi1 = [[0,3],[4,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
        """

        ## Validate input
        for regions in (loHi1,loHi2):
            ## Assert each region is a well-formed (lo,hi) pair
            for pair in regions:
                assert(len(pair) == 2)
                assert(pair[0] <= pair[1])
            ## Assert lists are compact; check every *consecutive* pair of
            ## regions (zipping l[::2] with l[1::2] skipped every other gap)
            for pair1,pair2 in zip(regions[:-1],regions[1:]):
                assert(pair1[1]+1 == pair2[0])

        ## Find the ends of the combined span
        lo = min( loHi1[0][0], loHi2[0][0] )
        hi = max( loHi1[-1][1], loHi2[-1][1] )

        ## Collect the indices after which the span must be split: just
        ## before each region start and at each region end, except at the
        ## outer ends of the span itself
        splitAfter = set()
        for regions in (loHi1,loHi2):
            for l,h in regions:
                if l != lo:
                    splitAfter.add(l-1)
                if h != hi:
                    splitAfter.add(h)
        split = sorted(splitAfter)

        ## Consecutive split points delimit the output regions.  Because the
        ## inputs are validated as sorted and compact, every split lies in
        ## [lo, hi-1], so each region is (previous split + 1, next split).
        returnList = [(i+1,j) for i,j in zip([lo-1]+split, split+[hi])]

        if intersect:
            lo = max( loHi1[0][0], loHi2[0][0] )
            hi = min( loHi1[-1][1], loHi2[-1][1] )
            returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]

        return returnList

    def _getChunksForHelixPair(self,hid1,hid2):
        chunks = []
        chunks1,chunks2 = [sorted(list(set(self.chunks[hid]))) for hid in [hid1,hid2]]
        return self.combineRegionLists(chunks1,chunks2)

        i,j = [0,0]
        lastHi = -100000
        # pdb.set_trace()
        while i < len(chunks1) or j < len(chunks2):
            if i >= len(chunks1):
                tmp
            tmp = [x for x in sorted(list(set(chunks1[i]+chunks2[j]))) if x > lastHi]

            # if len(tmp) > 1:
            #     chunks.append(tmp[:2])
            # if len(tmp) > 0:
            #     lastHi = tmp[min(1,len(tmp))]
            chunks.append(tmp[:2])
            lastHi = tmp[min(1,len(tmp))]

            if chunks1[i][1] <= lastHi: i+=1
            if chunks2[j][1] <= lastHi: j+=1
        return chunks

    def getNtsBetweenCadnanoIdxInclusive(self,hid,lo,hi,isFwd=True):
        """Gather the nucleotides of helix ``hid`` at cadnano indices lo..hi.

        Both ends are inclusive.  ``isFwd`` selects the ``helixNtsFwd`` or
        ``helixNtsRev`` table; indices missing from the chosen table are
        skipped, and each present index contributes all of its entries in
        order.
        """
        table = (self.helixNtsFwd if isFwd else self.helixNtsRev)[hid]
        return [nt
                for idx in range(lo, hi+1) if idx in table
                for nt in table[idx]]

819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
    # -------------------------- #
    # Methods for querying model #
    # -------------------------- #
    def _getIntrahelicalNodeSeries(self,seriesLen):
        nodeSeries = set()
        for hid,hlx in self.helices.items():
            for zid,n in hlx:
                nodeList,sepList = n.getNodesAbove(seriesLen-1, inclusive = True)
                
                # nodeList = [n]
                # sepList = []
                # for i in range(seriesLen-1):
                #     if n.nodeAbove is None: break
                #     n = n.nodeAbove
                #     nodeList.append(n)
                #     sepList.append(n.nodeBelowSep)
                
                if len(nodeList) == seriesLen:
                    nodeList = tuple(nodeList)
                    sepList = tuple(sepList)
                    nodeSeries.add( tuple((nodeList,sepList)) )
        return nodeSeries

    def _getIntrahelicalBonds(self):
        return self._getIntrahelicalNodeSeries(2)
        # bonds = set()
        # for hid,hlx in self.helices.items():
        #     bonds.update( ((n, n.nodeAbove, n.nodeAboveSep) for zid,n in hlx if n.nodeAbove is not None) )
        #     bonds.update( ((n.nodeBelow, n, n.nodeBelowSep) for zid,n in hlx if n.nodeBelow is not None) )
        # return bonds

    def _getIntrahelicalAngles(self):
        return self._getIntrahelicalNodeSeries(3)

    def _getCrossoverBonds(self):
        return { ((n, xo[0]), (1)) 
                 for hid,hlx in self.helices.items()
                 for zid,n in hlx for xo in n.xovers if n.idx < xo[0].idx }
        # bonds = set()
        # for hid,hlx in self.helices.items():
        #     bonds.update( (((n, xo[0]), (1)) for zid,n in hlx for xo in n.xovers if n.idx < xo[0].idx) )
        # return bonds

    def _getCrossoverAnglesAndDihedrals(self):
        """Find angles and dihedrals spanning pairs of crossovers on a helix.

        The nodes of each helix are walked in the order ``hlx`` yields them.
        The walk tracks the most recent node that has crossovers and the sum
        of ``nodeBelowSep`` since that node.  When the next node with
        crossovers is reached, the following are added for every combination
        of a crossover ``xo1`` on the earlier node with a crossover ``xo2`` on
        the later node:
          - the angle ``(xo1[0], earlier, later)``;
          - the angle ``(earlier, later, xo2[0])``;
          - the dihedral ``(xo1[0], earlier, later, xo2[0])``.
        The search restarts at any node that has no ``nodeBelow`` or whose
        ``type`` is not ``"dsDNA"``.

        Returns:
            (angles, dihedrals): two sets.  Angle entries are
            ``(node triple, bpsBetween)``.  Dihedral entries are
            ``(node quadruple, bpsBetween, xo1[1], xo2[1])``.  The meaning of
            ``xo[1]`` is not visible here; presumably it describes the
            crossover (TODO confirm).
        """
        angles,dihedrals = [set(),set()]
        for hid,hlx in self.helices.items():
            lastXoverNode = None 
            bpsBetween = 0
            for zid,n in hlx:
                ## Search for a pair of crossovers
                if n.nodeBelow is None or n.type != "dsDNA":
                    ## Found ssDNA or a gap; reset search
                    ## (no `continue`: this same node may start a new search below)
                    lastXoverNode = None
                    bpsBetween = 0

                if lastXoverNode is None:
                    if len(n.xovers) > 0:
                        ## First node with a crossover
                        lastXoverNode = n
                else:
                    ## Always true here: a missing nodeBelow resets the search above
                    if n.nodeBelow is not None:
                        bpsBetween += n.nodeBelowSep
                    if len(n.xovers) > 0:
                        ## Second node with a crossover, Add dihedral(s)
                        for xo1 in lastXoverNode.xovers:
                            for xo2 in n.xovers:
                                assert( bpsBetween != 0 )
                                # if bpsBetween != 0
                                angles.add( ((xo1[0], lastXoverNode, n), bpsBetween) )
                                angles.add( ((lastXoverNode, n, xo2[0]), bpsBetween) )
                                dihedrals.add( ((xo1[0], lastXoverNode, n, xo2[0]), bpsBetween, xo1[1], xo2[1]) )

                        ## This node becomes the start of the next crossover pair
                        lastXoverNode = n
                        bpsBetween = 0

        return angles, dihedrals
        
    def _getBonds(self):
        bonds = self._getIntrahelicalBonds()
        bonds.update( self._getCrossoverBonds() )
        return bonds


    # -------------------------- #
    # Methods for printing model #
    # -------------------------- #
cmaffeo2's avatar
cmaffeo2 committed
905
    def writeNamdFiles(self,prefix,numSteps=48000):
        """Write every file produced for a NAMD run of this model.

        Atomic indices are first assigned to the nucleotides.  The following
        are then written, all named from ``prefix``:
          - the ``<prefix>.pdb`` / ``<prefix>.psf`` structure (``writePdbPsf``);
          - the output of ``writeENM``;
          - the output of ``writeNamdFile``.

        Args:
            prefix: Path prefix shared by the generated files.
            numSteps: Step count forwarded to ``writeNamdFile``.
        """
        ## Indices must exist before any writer references atoms by index
        self._assignAtomicIndexToNucleotides()

        self.writePdbPsf(prefix)
        self.writeENM(prefix)
        self.writeNamdFile(prefix,numSteps)
cmaffeo2's avatar
cmaffeo2 committed
910

911
912
913
914
915
916
917
918
    def writePdbPsf(self, prefix):
        bonds,angles,dihedrals,impropers = [[],[],[],[]]
        with open("%s.pdb" % prefix,'w') as pdb, \
             open("%s.psf" % prefix,'w') as psf:

            ## Write headers
            pdb.write("CRYST1    1.000    1.000    1.000  90.00  90.00  90.00 P 1           1\n")
            psf.write("PSF NAMD\n\n") # create NAMD formatted psf
919
            psf.write("{:>8d} !NTITLE\n\n".format(0))
920
921
922

            ## Format strings
            ## http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
923
924
            # pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s}   {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f}          {element:>2s}{charge:>2.2f}\n"
            pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s}   {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f}      {segname:4s}{element:>2s}\n"
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
            ## From vmd/plugins/molfile_plugin/src/psfplugin.c
            ## "%d %7s %10s %7s %7s %7s %f %f"
            psfFormat = "{idx:>8d} {segname:7s} {resid:<10d} {resname:7s} " + \
                        "{name:7s} {type:7s} {charge: f} {mass:f}\n"

            ## PSF ATOMS section
            natoms=0
            for seg in self.segments:
                for atoms in seg.atoms(transform=False):
                    natoms += len(atoms)
            psf.write("{:>8d} !NATOM\n".format(natoms))

            atomProps = dict(idx=0,altLoc='',chain='A', iCode='',
                             occupancy=0, beta=0, charge='', resid=1)

            resnameDict = dict(A='ADE',T='THY',G='GUA',C='CYT')

            ## Loop over segments, setting common properties
            for seg in self.segments:
                atomProps['segname'] = seg.props['segname']
                prevAtoms = None
                
                ## Loop over nucleotides, setting common properties
                for atoms, seq in zip(seg.atoms(),seg.sequence()):
                    atomProps['resname'] = resnameDict[seq]
                    if atomProps['resid'] < 9999:
                        atomProps['residString'] = "%4d" % atomProps['resid']
                    else:
                        digit = (atomProps['resid'] // 1000) - 10
                        char = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[digit]
                        rem = atomProps['resid'] - (digit+10)*1000
                        atomProps['residString'] = "{:1s}{:<3d}".format(char,rem)

                    ## Add bonds, angles, dihedrals associated with LKNA
                    if prevAtoms is not None:
960
961
962
963
                        c4idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C4'")
                        c2idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C2'")
                        h3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("H3'")
                        c3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C3'")
964
965
                        o3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("O3'")
                        Pidx = atomProps['idx'] + atoms.index("P")
966
967
968
969
                        o5idx = atomProps['idx'] + atoms.index("O5'")
                        o1pidx = atomProps['idx'] + atoms.index("O1P")
                        o2pidx = atomProps['idx'] + atoms.index("O2P")
                        c5idx = atomProps['idx'] + atoms.index("C5'")
970

971
972
973
974
975
976
977
978
979
980
981
982
                        bonds.append([o3idx, Pidx])
                        angles.append([c3idx, o3idx, Pidx])
                        for x in (o5idx,o1pidx,o2pidx):
                            angles.append([o3idx, Pidx, x])

                        for x in (c4idx,c2idx,h3idx):
                            dihedrals.append([x,c3idx, o3idx, Pidx])
                        for x in (o5idx,o1pidx,o2pidx):
                            dihedrals.append([c3idx, o3idx, Pidx,x])
                        dihedrals.append([o3idx, Pidx, o5idx, c5idx])
                        
                            
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
                    ## Find bonds,angles,dihedrals, and impropers
                    bonds.extend( atoms.bonds+atomProps['idx'] )
                    angles.extend( atoms.angles+atomProps['idx'] )
                    dihedrals.extend( atoms.dihedrals+atomProps['idx'] )
                    impropers.extend( atoms.impropers+atomProps['idx'] )

                    ## Loop over atoms
                    for props in atoms.atomicProps():
                        for k,v in props.items():
                            atomProps[k] = v
                        # atomProps['name'] = n
                        atomProps['element'] = atomProps['name'][0]

                        ## Write lines
                        pdb.write(pdbFormat.format( **atomProps ))
                        ## psfResid   = "%d%c%c" % (idx," "," "), # TODO: work with large indeces
                        ## Increment counter
                        atomProps['idx'] += 1
For faster browsing, not all history is shown. View entire blame