atomicModel.py 54.3 KB
Newer Older
1
2
3
4
5
6
7
# -*- coding: utf-8 -*-
from datetime import datetime
from cadnano.cnenum import PointType
from cadnano.strand import Strand
from math import pi,sqrt,exp,floor
import numpy as np
import random
8
import os, sys, subprocess
cmaffeo2's avatar
cmaffeo2 committed
9
from pdb import set_trace
10

cmaffeo2's avatar
cmaffeo2 committed
11
from coords import minimizeRmsd, rotationAboutAxis
12
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
13
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
14
15
16
17
18
19
20

class abstractNucleotide():
    """One nucleotide of a strand, with links to its neighbours.

    Attributes set by the constructor:
      sequence          -- base identifier, used as the template lookup key
      prop              -- dict of the extra keyword arguments (``atoms``
                           reads ``prop['isFwd']``)
      position          -- ``np.array`` built from the ``position`` argument
      orientation       -- result of ``rotationAboutAxis([0,0,1], angle)``
      ntAt5prime / ntAt3prime   -- neighbouring nucleotides along the strand
      basepair, stack5prime, stack3prime -- pairing/stacking partners (None
                           until assigned elsewhere)
      firstAtomicIndex  -- -1 until an index is assigned
    """

    def __init__(self, sequence, position, angle, ntAt5prime=None, **prop):
        self.sequence = sequence
        self.prop = prop

        self.position = np.array(position)
        # NOTE(review): units of `angle` are whatever rotationAboutAxis
        # expects -- confirm against coords.rotationAboutAxis
        self.orientation = rotationAboutAxis([0,0,1], angle)

        ## Connectivity along the strand
        self.ntAt5prime = ntAt5prime
        self.ntAt3prime = None

        ## Pairing and stacking partners
        self.basepair = None
        self.stack5prime = None
        self.stack3prime = None

        self.firstAtomicIndex = -1

    def set3primeNt(self,n):
        """Set the nucleotide on the 3' side of this one."""
        self.ntAt3prime = n

    def atoms(self, transform=True, scale=1.0):
        """Return the atom template for this nucleotide.

        The template key is the sequence, decorated according to which
        strand neighbours exist: prefix "5" when only the 3' neighbour is
        present, suffix "3" when only the 5' neighbour is present, suffix
        "singlet" when neither is present.  The template is taken from
        ``canonicalNtFwd`` or ``canonicalNtRev`` depending on
        ``prop['isFwd']``.

        When ``transform`` is true the template's ``transformed`` method is
        called with this nucleotide's orientation, position, ``scale`` and a
        flag that is True when the nucleotide has no basepair.
        """
        has5prime = self.ntAt5prime is not None
        has3prime = self.ntAt3prime is not None

        if has5prime and has3prime:
            key = self.sequence
        elif has3prime:
            key = "5" + self.sequence
        elif has5prime:
            key = self.sequence + "3"
        else:
            key = self.sequence + "singlet"

        templates = canonicalNtFwd if self.prop['isFwd'] else canonicalNtRev
        atoms = templates[ key ]

        if not transform:
            return atoms

        singleStranded = (self.basepair is None)
        return atoms.transformed(self.orientation, self.position, scale, singleStranded)
            
53

54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
class RigidBodyNode():
    """Node in a vertically linked chain belonging to a helix.

    Each node records its neighbours above and below together with the
    separation passed when the link was made, plus a list of crossovers.
    On construction the node takes ``helix.model.numNodes`` as its ``idx``
    and increments that counter.
    """

    def __init__(self, helix, pos, angle, type="dsDNA"):
        self.helix    = helix
        self.position = pos
        # NOTE(review): `angle` is accepted but not stored
        self.type     = type
        self.nodeAbove = None
        self.nodeBelow = None
        ## Separations are defined up front so that reading them on an
        ## unlinked node gives None instead of raising AttributeError
        self.nodeAboveSep = None
        self.nodeBelowSep = None
        self.xovers = []
        self.idx = helix.model.numNodes
        helix.model.numNodes = helix.model.numNodes + 1

    def addNodeAbove(self, node, separation):
        """Record `node` as the neighbour above, `separation` away."""
        self.nodeAbove = node
        self.nodeAboveSep = separation # bp

    def addNodeBelow(self, node, separation):
        """Record `node` as the neighbour below, `separation` away."""
        self.nodeBelow = node
        self.nodeBelowSep = separation # bp

    def addXover(self, node, polarity, double=False):
        """Append a (node, polarity, double) crossover record."""
        ## TODO: what is meant by polarity?
        self.xovers.append( (node,polarity,double) )

    def getNodesAbove(self,numNodes,inclusive=False):
        """Walk up to `numNodes` links upward.

        Returns (nodeList, sepList): the nodes visited (preceded by this
        node when `inclusive`) and the separation of every link crossed.
        The walk stops early at the top of the chain.
        """
        assert( type(numNodes) is int and numNodes > 0 )

        nodeList,sepList = [[],[]]
        n = self
        if inclusive:
            nodeList.append(n)

        for i in range(numNodes):
            if n.nodeAbove is None: break
            n = n.nodeAbove
            nodeList.append(n)
            ## separation between n and the node we just came from (below n)
            sepList.append(n.nodeBelowSep)

        return nodeList,sepList

    def getNodesBelow(self,numNodes,inclusive=False):
        """Walk up to `numNodes` links downward.

        Returns (nodeList, sepList): the nodes visited (preceded by this
        node when `inclusive`) and the separation of every link crossed.
        The walk stops early at the bottom of the chain.
        """
        assert( type(numNodes) is int and numNodes > 0 )

        nodeList,sepList = [[],[]]
        n = self
        if inclusive:
            nodeList.append(n)

        for i in range(numNodes):
            if n.nodeBelow is None: break
            n = n.nodeBelow
            nodeList.append(n)
            ## The link just crossed joins n to the node above it, so its
            ## separation is n.nodeAboveSep (nodeBelowSep would describe the
            ## *next* link down and is unset on the bottom node)
            sepList.append(n.nodeAboveSep)

        return nodeList,sepList

class Segment():
    """A collection of linked nucleotides forming one or more strands.

    Nucleotides are stored in ``self.nodes[helixId][strandIdx]`` (a list)
    and are additionally chained through their ``ntAt5prime`` /
    ``ntAt3prime`` attributes.  Iterating a Segment yields nucleotides in
    5'-to-3' order, one strand after another.
    """

    def __init__(self, segname, ntScale=1.0):
        self.ntScale = ntScale
        self.props = dict()
        self.props['segname'] = segname
        self.nodes = dict()
        self.firstNode = None

    def addNucleotide(self, helixId, strandIdx, zIdx,
                      sequence, isFwd, position, angle, ntAt5prime=None):
        """Create a nucleotide, register it, and link it to its 5' neighbour.

        Raises Exception if a nucleotide with the same zIdx and direction is
        already stored under (helixId, strandIdx).  Returns the new
        nucleotide so that it can be passed as ``ntAt5prime`` of the next.
        """
        ntsAtStrandIdx = self.nodes.setdefault(helixId, dict())
        for n in ntsAtStrandIdx.get(strandIdx, []):
            if n.prop['zIdx'] == zIdx and n.prop['isFwd'] == isFwd:
                raise Exception("Attempted to add a node in the same location (%d:%.1f:%s) twice!" % (helixId,zIdx,isFwd))

        n = abstractNucleotide(sequence, position, angle, ntAt5prime,
                               isFwd=isFwd, helixId=helixId, zIdx=zIdx, idx=strandIdx)

        if ntAt5prime is not None:
            ntAt5prime.set3primeNt(n)

        ntsAtStrandIdx.setdefault(strandIdx, []).append( n )
        return n

    def __len__(self):
        """Number of nucleotides stored in the segment."""
        return sum(1 for nt in self._allNts())

    def __iter__(self):
        """Yield nucleotides 5'-to-3', starting from every 5' end.

        For a circular strand the walk stops when it returns to its
        starting nucleotide instead of looping forever.
        """
        for first in self._get5primeNts():
            n = first
            yield n
            while n.ntAt3prime is not None and n.ntAt3prime is not first:
                n = n.ntAt3prime
                yield n

    def atoms(self, transform=True):
        """Yield the atoms of every nucleotide, scaled by ``self.ntScale``."""
        for nt in self:
            yield nt.atoms(transform, self.ntScale)

    def sequence(self):
        """Yield the sequence of every nucleotide in iteration order."""
        for nt in self:
            yield nt.sequence

    def _allNts(self):
        """Yield every stored nucleotide in storage (not strand) order."""
        for ntsAtStrandIdx in self.nodes.values():
            for nts in ntsAtStrandIdx.values():
                for nt in nts:
                    yield nt

    def _get5primeNts(self):
        """Return the nucleotides that start a strand.

        These are the nucleotides without a 5' neighbour.  If the segment
        holds nucleotides but none of them is a 5' end, the segment is
        treated as circular: a warning is printed and the first stored
        nucleotide is used as the starting point.  An empty segment gives
        an empty list.
        """
        allNts = list(self._allNts())
        firstNodes = [nt for nt in allNts if nt.ntAt5prime is None]

        if len(firstNodes) == 0 and len(allNts) > 0:
            print("WARNING: found circular segment; untested", file=sys.stderr)
            firstNodes = [allNts[0]]
        return firstNodes

    ## TODO deprecate
    def _findFirstNode(self):
        """Return the single starting nucleotide of the segment.

        Raises Exception when the segment has more than one 5' end or holds
        no nucleotides at all.  A circular segment yields its first stored
        nucleotide (with a warning).
        """
        allNts = list(self._allNts())
        firstNodes = [nt for nt in allNts if nt.ntAt5prime is None]

        if len(firstNodes) > 1:
            raise Exception("Segment object contains two 5' ends!")
        if len(allNts) == 0:
            raise Exception("Segment object contains no nucleotides!")
        if len(firstNodes) == 0:
            print("WARNING: found circular segment; untested", file=sys.stderr)
            firstNodes = [allNts[0]]

        return firstNodes[0]
    
cmaffeo2's avatar
cmaffeo2 committed
202
class atomicModel():
203

cmaffeo2's avatar
cmaffeo2 committed
204
    def __init__(self, part, ntScale=1.0, randomSeed=1):
        """Set up empty bookkeeping containers and build the model from `part`.

        part       -- cadnano part object; stored and handed to
                      ``_determineLatticeType`` and ``_buildModel``
        ntScale    -- stored and passed on to every Segment that is created
        randomSeed -- seeds the global ``random`` module
        """
        ## Inputs
        self.part = part
        self.ntScale = ntScale

        ## Nucleotide bookkeeping, filled by _buildModel
        self.segments = []
        self.numNodes = 0
        self.helixNtsFwd = dict()
        self.helixNtsRev = dict()
        self.dsChunks = dict()
        self.nodeTypeCounts = None

        ## Bonded terms and parameter files
        self.bonds = set()
        self.angles = set()
        self.dihedrals = set()
        self._nbParamFiles = []

        random.seed(randomSeed)
        self.useTclForces = False

        ## The lattice type must be known before the model is built
        self.latticeType = atomicModel._determineLatticeType(part)
        self._buildModel(part)
cmaffeo2's avatar
cmaffeo2 committed
234
        
235
236
237
238
239
240
241
242
243
244
245
246
247
248

    def _strandOccupancies(self, helixId):
        try:
            ends1,ends2 = self._getStrandEnds(helixId)
        except:
            ## hid not in strand_list
            return []

        strandOccupancies = [ [x for i in range(0,len(e),2)
                               for x in range(e[i],e[i+1]+1)]
                              for e in (ends1,ends2) ]
        return sorted(list(set( strandOccupancies[0] + strandOccupancies[1] )))

    def _getStrandEnds(self, helixId):
249
250
251
252

        """Utility method to convert cadnano strand lists into list of
        indices of terminal points"""

253
254
        helixStrands = self.strand_list[helixId]

255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
        endLists = [[],[]]
        for endList, strandList in zip(endLists,helixStrands):
            lastStrand = None
            for s in strandList:
                if lastStrand is None:
                    ## first strand
                    endList.append(s[0])
                elif lastStrand[1] != s[0]-1: 
                    assert( s[0] > lastStrand[1] )
                    endList.extend( [lastStrand[1], s[0]] )
                lastStrand = s
            if lastStrand is not None:
                endList.append(lastStrand[1])
        return endLists

    # -------------------------- #
    # Methods for building model #
    # -------------------------- #
    def _buildModel(self, part):
        """Populate the model with segments and nucleotides read from `part`.

        Strategy:
        1) Create all oligos/segments, describing 5'/3' bonding & location
        2) Add basepairing information
        3) For all basepairs, find 5'/3' stacks

        Side effects: sets ``self.strand_list`` and fills
        ``self.helixNtsFwd``, ``self.helixNtsRev``, ``self.dsChunks`` and
        ``self.segments``.

        Raises Exception("Not implemented") for parts whose point type is
        ``PointType.ARBITRARY``.
        """
        props = part.getModelProperties().copy()
        if props.get('point_type') == PointType.ARBITRARY:
            # TODO add code to encode Parts with ARBITRARY point configurations
            raise Exception("Not implemented")
        vh_props, origins = part.helixPropertiesAndOrigins()

        ## TODO: compartmentalize following
        ## Loop over virtual helices and build lists of strands
        self.strand_list = []
        xover_list = []

        numHID = part.getIdNumMax() + 1
        for id_num in range(numHID):
            if part.getOffsetAndSize(id_num) is None:
                # add a placeholder
                self.strand_list.append(None)
            else:
                fwd_ss, rev_ss = part.getStrandSets(id_num)
                fwd_idxs, fwd_colors  = fwd_ss.dump(xover_list)
                rev_idxs, rev_colors  = rev_ss.dump(xover_list)
                self.strand_list.append((fwd_idxs, rev_idxs))

        for hid in range(numHID):
            self.helixNtsFwd[hid] = dict()
            self.helixNtsRev[hid] = dict()
            self.dsChunks[hid] = []

        ## 1) Loop over oligos/segments, longest first
        segments = []
        segid = 1
        keyFn = lambda o: (o.length(), o._strand5p.idNum(), o._strand5p.idx5Prime())
        for oligo in reversed(sorted(part.oligos(), key=keyFn)):
            seg = Segment("D%03d" % segid, self.ntScale)
            segments.append(seg)
            segid += 1

            lastNt = None

            ## loop over strands in oligo
            for strand in oligo.strand5p().generator3pStrand():
                hid = strand.idNum() # virtual helix ID
                x,y = origins[hid]

                keys = ['bases_per_repeat',
                        'turns_per_repeat',
                        'eulerZ','z']
                bpr,tpr,eulerZ,z = [vh_props[k][hid] for k in keys]

                twist_per_base = tpr*360./bpr
                ## override twist_per_base if square lattice
                if self.latticeType == "square":
                    ## This was a bare `==` comparison whose result was
                    ## discarded; the override is now actually applied
                    twist_per_base = 360.0*3/32
                else:
                    assert(twist_per_base == 360.0/10.5)

                isFwd = strand.isForward()

                lo,hi = strand.idxs()
                seqs = Strand.sequence(strand, for_export=True)
                numNt = strand.totalLength()

                ## We place nts as pairs; spacing is not always uniform along strand.
                ## Here we find the locations between which spacing will be uniform.
                compIdxs = [comp.idxs() for comp in strand.getComplementStrands()]
                chunks = self.combineRegionLists(
                    [[lo,hi]],
                    compIdxs,
                    intersect=True
                )
                self.dsChunks[hid].extend( chunks ) # Add running list of chunks for push bonds

                chunks = self.combineRegionLists(
                    [[lo,hi]],
                    chunks,
                    intersect=False
                )

                ## Find zIdx for each nucleotide
                zIdxs = []
                for l,h in chunks:
                    ## Find the number of nts between l,h
                    nts = h-l+1
                    for ins in strand.insertionsOnStrand(l,h):
                        nts += ins.length()

                    if nts <= 0:
                        print("WARNING: there is something strange in the structure at helix %d (%d,%d)" % (hid,l,h))
                        continue

                    ## Spread the nts evenly over [l,h]
                    for i in range(nts):
                        if nts == 1:
                            zIdxs.append( l )
                        else:
                            zIdxs.append( l + (h-l) * float(i)/(nts-1) )

                assert( len(zIdxs) == numNt )

                if isFwd:
                    helixNts = self.helixNtsFwd[hid]
                else:
                    helixNts = self.helixNtsRev[hid]
                    zIdxs = zIdxs[::-1]

                ## Find strandOccupancy index for each nucleotide
                strandOccIdx = []
                for l,h in chunks:
                    for i in range(l,h+1):
                        nts = 1
                        for ins in strand.insertionsOnStrand(i,i):
                            nts += ins.length()
                        helixNts[i] = []
                        for j in range(nts):
                            strandOccIdx.append(i)
                if not isFwd:
                    strandOccIdx = strandOccIdx[::-1]

                ## Add nucleotides
                for s,zIdx,idx in zip(seqs,zIdxs,strandOccIdx):
                    ## NOTE(review): the factors 10 and 3.4 presumably convert
                    ## lattice coordinates / base index to Angstroms -- confirm
                    p = (x*10,y*10,z-3.4*zIdx)
                    a = zIdx*twist_per_base + eulerZ
                    lastNt = seg.addNucleotide(hid, idx, zIdx, s, isFwd,
                                               p, a, lastNt)
                    helixNts[idx].append(lastNt)

        ## Correct the order of nts in helixNtsRev.  Real lists are stored
        ## (not one-shot `reversed` iterators) so that they can be read more
        ## than once.
        for hid,ntsAtIdx in self.helixNtsRev.items():
            for i,nts in ntsAtIdx.items():
                ntsAtIdx[i] = list(reversed(nts))

        ## 2) Find basepairs and stacks
        for hid in range(numHID):
            prevBasepair = None
            for i in self._strandOccupancies(hid):
                ## Check that strand index is dsDNA
                if not (i in self.helixNtsFwd[hid] and i in self.helixNtsRev[hid]):
                    continue

                for nt1,nt2 in zip(self.helixNtsFwd[hid][i],
                                   self.helixNtsRev[hid][i]):
                    self._pairBases(nt1,nt2)

                    if prevBasepair is not None and i - prevBasepair[0] <= 3: # TODO: find better way of locating stacks
                        atomicModel._stackBases( prevBasepair[1], nt1)
                        atomicModel._stackBases( nt2, prevBasepair[2])
                    prevBasepair = (i,nt1,nt2)

        self.segments.extend(segments)
        return

    def _connectNodes(self, below, above, sep):
        """Link two nodes as vertical neighbours.

        Registers `above` as the node above `below` and `below` as the node
        below `above`; `sep` is passed to both calls as the separation.
        """
        below.addNodeAbove(above, sep)
        above.addNodeBelow(below, sep)
        
474
    def _getNeighborHelixDict(part):
cmaffeo2's avatar
cmaffeo2 committed
475
        props, origins = part.helixPropertiesAndOrigins()
476
477
478
479
        neighborHelices = dict()
        for i in range(len(origins)):
            neighborHelices[i] = [int(j.strip('[,]')) for j in props['neighbors'][i].split() if j != '[]']
        return neighborHelices
cmaffeo2's avatar
cmaffeo2 committed
480

481
482
483
    def _determineLatticeType(part):
        props, origins = part.helixPropertiesAndOrigins()
        origins = [np.array(o) for o in origins]
cmaffeo2's avatar
cmaffeo2 committed
484
        neighborVecs = []
485
        for i,js in atomicModel._getNeighborHelixDict(part).items():
cmaffeo2's avatar
cmaffeo2 committed
486
487
488
489
490
491
            for j in js:
                neighborVecs.append( origins[j]-origins[i] )

        dots = [nv2.dot(nv1)/(np.linalg.norm(nv1)*np.linalg.norm(nv2)) for nv1 in neighborVecs for nv2 in neighborVecs]

        nDots = len(dots)
492
        nAngled = np.sum( np.abs(np.abs(dots)-0.5) < 0.2 )
cmaffeo2's avatar
cmaffeo2 committed
493
494
495
496
497
498
499
500
501
        if float(nAngled)/float(nDots) == 0:
            return 'square'
        elif float(nAngled)/float(nDots) < 0.05:
            print('WARNING: unusual neighbors in square lattice structure')
            return 'square'
        elif float(nAngled)/float(nDots) > 0.4:
            return 'honeycomb'
        else:
            raise Exception("Could not identify lattice type")
cmaffeo2's avatar
cmaffeo2 committed
502
           
503
504
505
    def backmap(self, simplerModel, simplerModelCoords, 
                dsDnaHelixNeighborDist=50, dsDnaAllNeighborDist=30,
                ssDnaHelixNeighborDist=20, ssDnaAllNeighborDist=15):
cmaffeo2's avatar
cmaffeo2 committed
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524

        ## Assign each nucleotide to a bead in simplerModel
        mapToSimplerModel = dict()
        cgWeight = dict()

        nts = [nt for seg in self.segments for nt in seg]
        for i in range(len(nts)):
            nt = nts[i]
            hid = nt.prop['helixId']
            if hid not in simplerModel.helices:
                raise Exception("Could not find helix %d in simple model; using just a single cadnano structure")
            cgH = simplerModel.helices[hid]
            zIdxs = np.array( sorted([i for i,b in cgH]) )

            zi = nt.prop['zIdx'] #
            cgi = np.searchsorted(zIdxs,zi,side='left',sorter=None)
            cgi, = [zIdxs[x] if x < len(zIdxs) else zIdxs[-1] for x in (cgi,)]
            mapToSimplerModel[i] = [cgH.nodes[x] for x in (cgi,)]

cmaffeo2's avatar
cmaffeo2 committed
525
        ## Find transformation for each bead of simplerModel used in mapping
cmaffeo2's avatar
cmaffeo2 committed
526
527
        trans = dict()
        for b in list(set([b for i,bs in mapToSimplerModel.items() for b in bs])):
528
529
530
            helixCutoff = dsDnaHelixNeighborDist if b.type[0] == 'd' else ssDnaHelixNeighborDist
            allCutoff = dsDnaAllNeighborDist if b.type[0] == 'd' else ssDnaAllNeighborDist

531
532
533
            r0 = b.initialPosition
            r1 = simplerModelCoords[b.idx]

534
535
536
537
538
            above = b.nodeAbove if b.nodeAbove is not None and \
                    np.sum((b.nodeAbove.initialPosition-b.initialPosition)**2) < 10**2 else b
            below = b.nodeBelow if b.nodeBelow is not None and \
                    np.sum((b.nodeBelow.initialPosition-b.initialPosition)**2) < 10**2 else b

cmaffeo2's avatar
cmaffeo2 committed
539
540
541
542
543
544
            if type(simplerModel).__name__ == "beadModelTwist" and \
                    b.orientationNode is not None:
                o = b.orientationNode
                assert(above != below)

                z0 = above.initialPosition - below.initialPosition
545
546
                z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]
                z0,z1 = [z/np.linalg.norm(z) for z in (z0,z1)]
cmaffeo2's avatar
cmaffeo2 committed
547

548
                x0 = o.initialPosition - r0
cmaffeo2's avatar
cmaffeo2 committed
549
                x1 = simplerModelCoords[o.idx] - r1
550
551
552
                x0,x1 = [x-x.dot(z)*z for x,z in zip((x0,x1),(z0,z1))] # make x orthogonal to z
                x0,x1 = [x/np.linalg.norm(x) for x in (x0,x1)]

cmaffeo2's avatar
cmaffeo2 committed
553
554
555
556
557
558
559
560
561
                y0,y1 = [np.cross(z,x) for x,z in zip((x0,x1),(z0,z1))]
                R,tmp1,tmp2 = minimizeRmsd( (x0,y0,z0), (x1,y1,z1) )

            else:
                ids = simplerModel._getNeighborhoodIds(b, simplerModelCoords, helixCutoff, allCutoff)
                posOld = np.array( [simplerModel.particles[i][0].initialPosition for i in ids] )
                posNew = np.array( [simplerModelCoords[i] for i in ids] )
                R,cOld,cNew = minimizeRmsd( posOld, posNew )

562
563
564
565
566
567
568
569
570
                if b.type[0] == 's':
                    assert(above != below)
                    z0 = above.initialPosition - below.initialPosition
                    z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]

                    zR = z0.dot(R)
                    cross = np.cross(z1,zR) / (np.linalg.norm(z1)*np.linalg.norm(zR))
                    angle = np.arcsin(np.linalg.norm(cross))*180/np.pi
                    R2 = rotationAboutAxis(cross,angle)
571

572
573
574
                    R = R.dot(R2)

            cOld,cNew = (r0,r1)
cmaffeo2's avatar
cmaffeo2 committed
575
            trans[b.idx] = (R,cOld,cNew)
cmaffeo2's avatar
cmaffeo2 committed
576
577
578
579
580
581
582
583

        ## Optionally smooth orientations
        
        ## Apply transformation to each bead of self
        for i in range(len(nts)):
            b = nts[i]
            cgb, = mapToSimplerModel[i]
            cgi = cgb.idx
584
            r0 = simplerModel.particles[cgi][0].initialPosition
cmaffeo2's avatar
cmaffeo2 committed
585
586
587
588
            R,c0,c1 = trans[cgi]
            # assert( np.linalg.norm((b.position-r0).dot(R)) < 100 )
            b.position = (b.position - r0).dot(R) + simplerModelCoords[cgi]
            b.orientation = b.orientation.dot( R )
589
590
591
592
593
594
595
596
597

        for i in range(len(nts)):
            b = nts[i]
            if b.basepair is not None:
                assert( np.all( b.orientation == b.basepair.orientation ) )
                # bpVec = (b.position - b.basepair.position)
                # zVec = np.array((0,0,1)).dot(b.orientation)
                # assert( np.dot( bpVec, zVec ) < 0.05 * np.linalg.norm(bpVec) )
                
598
599
600
601
602
603
604
605
606
607
    def _pairBases(self, n1,n2):
        assert(n1.basepair is None and n2.basepair is None)
        n1.basepair = n2
        n2.basepair = n1

    def _stackBases(below, above):
        assert(below.stack3prime is None and above.stack5prime is None)
        below.stack3prime = above
        above.stack5prime = below        

608
    def _removeIntrahelicalConnectionsBeyond(self, cutoff):
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
        ## lazy progarmming: leave nts in one segment after bond break
        nts = [nt for seg in self.segments for nt in seg]
        for n1 in nts:
            n2 = n1.ntAt3prime
            if n2 is not None:
                r2 = np.sum( (n1.position - n2.position)**2 )
                if r2 > cutoff**2:
                    if n1.ntAt3prime == n2:
                        assert(n2.ntAt5prime == n1)
                        n1.ntAt3prime = None
                        n2.ntAt5prime = None
                    elif n2.ntAt3prime == n1:
                        assert(n1.ntAt5prime == n2)
                        n1.ntAt5prime = None
                        n2.ntAt3prime = None
                    else:
                        raise

cmaffeo2's avatar
cmaffeo2 committed
627
628
629
630
    def scaffoldSeq(seqFile):
        # Placeholder: performs no work and returns None.
        # NOTE(review): declared without `self`, so when invoked on an
        # instance the instance itself is bound to `seqFile` -- confirm
        # whether this was meant to be a static helper or an instance method.
        return

    def randomizeUnassignedNtSeqs(self):
631
632
633
634
635
636
        for seg in self.segments:
            for nt in seg:
                if nt.sequence == '?':
                    nt.sequence = random.choice(tuple(seqComplement.keys()))
                    if nt.basepair is not None:
                        nt.basepair.sequence = seqComplement[nt.sequence]
637
638
639
640
641
642
643
644
645
646
647
648

    def setUnassignedNtSeqs(self,base):
        assert(base in ("A","T","C","G"))
        for seg in self.segments:
            for nt in seg:
                if nt.sequence == '?':
                    nt.sequence = base
                    if nt.basepair is not None:
                        nt.basepair.sequence = seqComplement[nt.sequence]

    def setScaffoldSeq(self,sequence):
        ## get longest segment
649
        segs = sorted( [(len(s),s) for s in self.segments], key=lambda x: x[0] )
650
651
        seg = segs[-1][1]
        assert( len(seg) > len(segs[-2][1]) )
652
        if len(seg) > len(sequence):
653
654
            raise Exception("Sequence (%d) too short for scaffold (%d)!" % (len(sequence),len(seg)))

655
656
657
658
659
        for nt,seq in zip(seg,sequence):
            assert(seq in ("A","T","C","G"))
            nt.sequence = seq
            if nt.basepair is not None:
                nt.basepair.sequence = seqComplement[nt.sequence]        
660
661
662
663
664
665

    def _assignAtomicIndexToNucleotides(self):
        natoms=0
        for seg in self.segments:
            for nt in seg:
                nt.firstAtomicIndex = natoms
666
                natoms += len(nt.atoms(transform=False))
667

cmaffeo2's avatar
cmaffeo2 committed
668
    def combineRegionLists(self,loHi1,loHi2,intersect=False):
        """Combine two lists of (lo,hi) pairs specifying integer regions
        into a single sorted list of regions.

        Each input is first grouped into runs of directly adjacent regions.
        Runs from the two inputs are then consumed in order of their start:
        a run that begins at or before the end of the region being built is
        merged into it with ``combineCompactRegionLists``; otherwise the
        region being built is emitted and a new one is started.

        If either input is empty the other input is returned unchanged, or
        an empty list when `intersect` is true.  With `intersect` the result
        is restricted to regions lying between the larger of the two first
        lows and the smaller of the two last highs.
        """

        ## Validate input: each region is a (lo,hi) pair with lo <= hi
        for regionList in (loHi1,loHi2):
            for pair in regionList:
                assert(len(pair) == 2)
                assert(pair[0] <= pair[1])

        ## Trivial cases: one of the lists is empty
        if len(loHi1) == 0:
            return [] if intersect else loHi2
        if len(loHi2) == 0:
            return [] if intersect else loHi1

        ## Break input into lists of compact regions (runs of regions in
        ## which each lo directly follows the previous hi)
        grouped = []
        for loHi in (loHi1,loHi2):
            runs = []
            run = []
            prevHi = loHi[0][0]-1
            for lo,hi in loHi:
                if lo-1 != prevHi:
                    runs.append(run)
                    run = []
                run.append((lo,hi))
                prevHi = hi
            if len(run) > 0:
                runs.append(run)
            grouped.append(runs)
        runs1,runs2 = grouped

        ## Sentinel runs with a huge start, so the head of either list can
        ## always be read while the other one is still being consumed
        runs1.append([[1e10]])
        runs2.append([[1e10]])

        ## Build result
        merged = []
        current = []
        i = j = 0
        while i < len(runs1)-1 or j < len(runs2)-1:
            head1 = runs1[i]
            head2 = runs2[j]

            ## start a new region from whichever run begins first
            if len(current) == 0:
                if head1[0][0] <= head2[0][0]:
                    current = head1
                    i += 1
                else:
                    current = head2
                    j += 1
                continue

            currentHi = current[-1][-1]
            if currentHi >= head1[0][0]:
                current = self.combineCompactRegionLists(current, head1, intersect=False)
                i += 1
            elif currentHi >= head2[0][0]:
                current = self.combineCompactRegionLists(current, head2, intersect=False)
                j += 1
            else:
                ## neither head reaches the current region: emit it
                merged.extend(current)
                current = []

        assert( len(current) > 0 )
        merged.extend(current)
        merged.sort()

        if intersect:
            lo = max( [loHi1[0][0], loHi2[0][0]] )
            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
            merged = [r for r in merged if r[0] >= lo and r[1] <= hi]

        return merged

    def combineCompactRegionLists(self,loHi1,loHi2,intersect=False):
        """Combines two lists of (lo,hi) pairs specifying regions within a
        compact integer set into a single list of regions.

        A list is "compact" when every region starts immediately after
        the previous one ends (previous hi + 1 == next lo).  The result
        is split at every region boundary that occurs in either input.

        Arguments:
        loHi1, loHi2 -- compact, sorted sequences of (lo,hi) pairs
        intersect    -- if True, keep only the regions lying between the
                        larger of the two first lo values and the smaller
                        of the two last hi values

        Returns a list of (lo,hi) tuples.  If one input is empty the
        other input is returned unchanged (or [] when intersect is True).

        Raises AssertionError if a pair does not have two elements, has
        lo > hi, or if a list is not compact.

        examples:
        loHi1 = [[0,4],[5,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 4), (5, 7), (8, 9)]

        loHi1 = [[0,3],[4,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
        """

        ## Validate input
        for regions in (loHi1,loHi2):
            ## Assert each region in lists is sorted
            for pair in regions:
                assert len(pair) == 2, "each region must be a (lo,hi) pair"
                assert pair[0] <= pair[1], "region has lo > hi"
            ## Assert lists are compact; every adjacent pair is compared
            ## (stepping by two would skip the 2nd/3rd, 4th/5th, ... pairs)
            for prev,nxt in zip(regions,regions[1:]):
                assert prev[1]+1 == nxt[0], "region list is not compact"

        ## Nothing to combine when one of the lists is empty
        if len(loHi1) == 0:
            return [] if intersect else loHi2
        if len(loHi2) == 0:
            return [] if intersect else loHi1

        ## Find the ends of the region
        lo = min( [loHi1[0][0], loHi2[0][0]] )
        hi = max( [loHi1[-1][1], loHi2[-1][1]] )

        ## Collect the indices after which the combined region is split;
        ## the outer ends (lo and hi) never produce a split
        splitAfter = set()
        for regions in (loHi1,loHi2):
            for l,h in regions:
                if l != lo:
                    splitAfter.add(l-1)
                if h != hi:
                    splitAfter.add(h)
        splitAfter = sorted(splitAfter)

        ## Each region runs from just after one split up to the next split
        returnList = [(start+1,end) for start,end in
                      zip([lo-1]+splitAfter, splitAfter+[hi])]

        if intersect:
            lo = max( [loHi1[0][0], loHi2[0][0]] )
            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
            returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]

        return returnList

cmaffeo2's avatar
cmaffeo2 committed
829
    def _getDsChunksForHelixPair(self,hid1,hid2):
830
        chunks = []
cmaffeo2's avatar
cmaffeo2 committed
831
        chunks1,chunks2 = [sorted(list(set(self.dsChunks[hid]))) for hid in [hid1,hid2]]
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
        return self.combineRegionLists(chunks1,chunks2)

        i,j = [0,0]
        lastHi = -100000
        # pdb.set_trace()
        while i < len(chunks1) or j < len(chunks2):
            if i >= len(chunks1):
                tmp
            tmp = [x for x in sorted(list(set(chunks1[i]+chunks2[j]))) if x > lastHi]

            # if len(tmp) > 1:
            #     chunks.append(tmp[:2])
            # if len(tmp) > 0:
            #     lastHi = tmp[min(1,len(tmp))]
            chunks.append(tmp[:2])
            lastHi = tmp[min(1,len(tmp))]

            if chunks1[i][1] <= lastHi: i+=1
            if chunks2[j][1] <= lastHi: j+=1
        return chunks

    def getNtsBetweenCadnanoIdxInclusive(self,hid,lo,hi,isFwd=True):
        """Gather the entries stored for helix hid at indices lo..hi (inclusive).

        Arguments:
        hid    -- key into self.helixNtsFwd / self.helixNtsRev
        lo, hi -- first and last index to look up (both included)
        isFwd  -- read from self.helixNtsFwd when true, else self.helixNtsRev

        Returns a flat list, ordered by index; indices without an entry
        are skipped.
        """
        ntsByIdx = (self.helixNtsFwd if isFwd else self.helixNtsRev)[hid]
        return [nt
                for idx in range(lo,hi+1) if idx in ntsByIdx
                for nt in ntsByIdx[idx]]

865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
    # -------------------------- #
    # Methods for querying model #
    # -------------------------- #
    def _getIntrahelicalNodeSeries(self,seriesLen):
        nodeSeries = set()
        for hid,hlx in self.helices.items():
            for zid,n in hlx:
                nodeList,sepList = n.getNodesAbove(seriesLen-1, inclusive = True)
                
                # nodeList = [n]
                # sepList = []
                # for i in range(seriesLen-1):
                #     if n.nodeAbove is None: break
                #     n = n.nodeAbove
                #     nodeList.append(n)
                #     sepList.append(n.nodeBelowSep)
                
                if len(nodeList) == seriesLen:
                    nodeList = tuple(nodeList)
                    sepList = tuple(sepList)
                    nodeSeries.add( tuple((nodeList,sepList)) )
        return nodeSeries

    def _getIntrahelicalBonds(self):
        return self._getIntrahelicalNodeSeries(2)
        # bonds = set()
        # for hid,hlx in self.helices.items():
        #     bonds.update( ((n, n.nodeAbove, n.nodeAboveSep) for zid,n in hlx if n.nodeAbove is not None) )
        #     bonds.update( ((n.nodeBelow, n, n.nodeBelowSep) for zid,n in hlx if n.nodeBelow is not None) )
        # return bonds

    def _getIntrahelicalAngles(self):
        return self._getIntrahelicalNodeSeries(3)

    def _getCrossoverBonds(self):
        return { ((n, xo[0]), (1)) 
                 for hid,hlx in self.helices.items()
                 for zid,n in hlx for xo in n.xovers if n.idx < xo[0].idx }
        # bonds = set()
        # for hid,hlx in self.helices.items():
        #     bonds.update( (((n, xo[0]), (1)) for zid,n in hlx for xo in n.xovers if n.idx < xo[0].idx) )
        # return bonds

    def _getCrossoverAnglesAndDihedrals(self):
        angles,dihedrals = [set(),set()]
        for hid,hlx in self.helices.items():
            lastXoverNode = None 
            bpsBetween = 0
            for zid,n in hlx:
                ## Search for a pair of crossovers
                if n.nodeBelow is None or n.type != "dsDNA":
                    ## Found ssDNA or a gap; reset search
                    lastXoverNode = None
                    bpsBetween = 0

                if lastXoverNode is None:
                    if len(n.xovers) > 0:
                        ## First node with a crossover
                        lastXoverNode = n
                else:
                    if n.nodeBelow is not None:
                        bpsBetween += n.nodeBelowSep
                    if len(n.xovers) > 0:
                        ## Second node with a crossover, Add dihedral(s)
                        for xo1 in lastXoverNode.xovers:
                            for xo2 in n.xovers:
                                assert( bpsBetween != 0 )
                                # if bpsBetween != 0
                                angles.add( ((xo1[0], lastXoverNode, n), bpsBetween) )
                                angles.add( ((lastXoverNode, n, xo2[0]), bpsBetween) )
                                dihedrals.add( ((xo1[0], lastXoverNode, n, xo2[0]), bpsBetween, xo1[1], xo2[1]) )

                        lastXoverNode = n
                        bpsBetween = 0

        return angles, dihedrals
        
    def _getBonds(self):
        bonds = self._getIntrahelicalBonds()
        bonds.update( self._getCrossoverBonds() )
        return bonds


    # -------------------------- #
    # Methods for printing model #
    # -------------------------- #
cmaffeo2's avatar
cmaffeo2 committed
951
    def writeNamdFiles(self,prefix,numSteps=48000):
        """Write the files for a NAMD run, all named after prefix.

        Atomic indices are assigned to the nucleotides first; then
        writePdbPsf, writeENM and writeNamdFile are called in that order.

        Arguments:
        prefix   -- passed to each of the writer methods
        numSteps -- passed on to writeNamdFile
        """
        self._assignAtomicIndexToNucleotides()
        for writeStructure in (self.writePdbPsf, self.writeENM):
            writeStructure(prefix)
        self.writeNamdFile(prefix,numSteps)
cmaffeo2's avatar
cmaffeo2 committed
956

957
958
959
960
961
962
    def writePdbPsf(self, prefix):
        bonds,angles,dihedrals,impropers = [[],[],[],[]]
        with open("%s.pdb" % prefix,'w') as pdb, \
             open("%s.psf" % prefix,'w') as psf:

            ## Write headers
963
            pdb.write("CRYST1    1000.    1000.    1000.  90.00  90.00  90.00 P 1           1\n")
964
            psf.write("PSF NAMD\n\n") # create NAMD formatted psf
965
            psf.write("{:>8d} !NTITLE\n\n".format(0))
966
967
968

            ## Format strings
            ## http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
969
970
            # pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s}   {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f}          {element:>2s}{charge:>2.2f}\n"
            pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s}   {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f}      {segname:4s}{element:>2s}\n"
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
            ## From vmd/plugins/molfile_plugin/src/psfplugin.c
            ## "%d %7s %10s %7s %7s %7s %f %f"
            psfFormat = "{idx:>8d} {segname:7s} {resid:<10d} {resname:7s} " + \
                        "{name:7s} {type:7s} {charge: f} {mass:f}\n"

            ## PSF ATOMS section
            natoms=0
            for seg in self.segments:
                for atoms in seg.atoms(transform=False):
                    natoms += len(atoms)
            psf.write("{:>8d} !NATOM\n".format(natoms))

            atomProps = dict(idx=0,altLoc='',chain='A', iCode='',
                             occupancy=0, beta=0, charge='', resid=1)

            resnameDict = dict(A='ADE',T='THY',G='GUA',C='CYT')

            ## Loop over segments, setting common properties
            for seg in self.segments:
                atomProps['segname'] = seg.props['segname']
                prevAtoms = None
                
                ## Loop over nucleotides, setting common properties
                for atoms, seq in zip(seg.atoms(),seg.sequence()):
                    atomProps['resname'] = resnameDict[seq]
                    if atomProps['resid'] < 9999:
                        atomProps['residString'] = "%4d" % atomProps['resid']
                    else:
                        digit = (atomProps['resid'] // 1000) - 10
                        char = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[digit]
For faster browsing, not all history is shown. View entire blame