# atomicModel.py 54.84 KiB
# -*- coding: utf-8 -*-
from datetime import datetime
from cadnano.cnenum import PointType
from cadnano.strand import Strand
from math import pi,sqrt,exp,floor
import numpy as np
import random
import os, sys, subprocess
from pdb import set_trace
from coords import minimizeRmsd, rotationAboutAxis
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC
class abstractNucleotide():
    """A single nucleotide: base identity, placement and strand connectivity.

    ``position`` is stored as a numpy array and ``orientation`` is whatever
    ``rotationAboutAxis([0,0,1], angle)`` returns.  Extra keyword arguments
    are kept in ``self.prop``; ``atoms`` reads ``self.prop['isFwd']``.
    """

    def __init__(self, sequence, position, angle, ntAt5prime=None, **prop):
        self.sequence = sequence
        self.prop = prop

        self.position = np.array(position)
        self.orientation = rotationAboutAxis([0,0,1], angle)

        ## Connectivity along the strand (None marks a strand end)
        self.ntAt5prime = ntAt5prime
        self.ntAt3prime = None

        ## Pairing/stacking partners start out unset
        self.basepair = None
        self.stack5prime = None
        self.stack3prime = None

        ## -1 until an atomic index has been assigned
        self.firstAtomicIndex = -1

    def set3primeNt(self,n):
        """Record ``n`` as the nucleotide on the 3' side of this one."""
        self.ntAt3prime = n

    def atoms(self, transform=True, scale=1.0):
        """Return the canonical atom template for this nucleotide.

        The template key is the sequence, decorated with a "5" prefix for a
        5' end, a "3" suffix for a 3' end, or a "singlet" suffix when the
        nucleotide has no neighbour on either side.  The forward or reverse
        template table is chosen from ``self.prop['isFwd']``.  When
        ``transform`` is true the template's ``transformed`` method is
        applied with this nucleotide's orientation, position, ``scale`` and
        a flag telling whether it is unpaired.
        """
        has5 = self.ntAt5prime is not None
        has3 = self.ntAt3prime is not None

        if has5 and has3:
            key = self.sequence
        elif has3:
            key = "5" + self.sequence
        elif has5:
            key = self.sequence + "3"
        else:
            key = self.sequence + "singlet"

        table = canonicalNtFwd if self.prop['isFwd'] else canonicalNtRev
        template = table[ key ]

        if not transform:
            return template
        return template.transformed(self.orientation, self.position, scale,
                                    self.basepair is None)
class RigidBodyNode():
    """Node of a coarse helix description, linked to its neighbours.

    Each node registers itself with ``helix.model`` by taking the current
    ``numNodes`` value as its ``idx`` and incrementing the counter.
    """

    def __init__(self, helix, pos, angle, type="dsDNA"):
        self.helix = helix
        self.position = pos
        self.type = type

        self.nodeAbove = None
        self.nodeBelow = None
        ## Separations (bp) to the neighbours; None until connected
        self.nodeAboveSep = None
        self.nodeBelowSep = None
        self.xovers = []

        self.idx = helix.model.numNodes
        helix.model.numNodes = helix.model.numNodes + 1

    def addNodeAbove(self, node, separation):
        """Link ``node`` as the upper neighbour, ``separation`` in bp."""
        self.nodeAbove = node
        self.nodeAboveSep = separation # bp

    def addNodeBelow(self, node, separation):
        """Link ``node`` as the lower neighbour, ``separation`` in bp."""
        self.nodeBelow = node
        self.nodeBelowSep = separation # bp

    def addXover(self, node, polarity, double=False):
        """Record a crossover to ``node`` as a (node, polarity, double) tuple."""
        ## TODO: what is meant by polarity?
        self.xovers.append( (node,polarity,double) )

    def getNodesAbove(self,numNodes,inclusive=False):
        """Walk up to ``numNodes`` links upward.

        Returns (nodeList, sepList): the nodes visited (starting with this
        node when ``inclusive``) and the separation of every link crossed.
        The walk stops early at the top of the chain.
        """
        assert( type(numNodes) is int and numNodes > 0 )
        nodeList,sepList = [[],[]]
        n = self
        if inclusive:
            nodeList.append(n)
        for i in range(numNodes):
            if n.nodeAbove is None: break
            n = n.nodeAbove
            nodeList.append(n)
            ## link just crossed is the one below the node we arrived at
            sepList.append(n.nodeBelowSep)
        return nodeList,sepList

    def getNodesBelow(self,numNodes,inclusive=False):
        """Walk up to ``numNodes`` links downward.

        Mirror image of ``getNodesAbove``.  The separation recorded for each
        step is that of the link just crossed, i.e. the ``nodeAboveSep`` of
        the node we arrived at (previously ``nodeBelowSep`` was read, which
        belongs to the *next* link and does not exist on the bottom node).
        """
        assert( type(numNodes) is int and numNodes > 0 )
        nodeList,sepList = [[],[]]
        n = self
        if inclusive:
            nodeList.append(n)
        for i in range(numNodes):
            if n.nodeBelow is None: break
            n = n.nodeBelow
            nodeList.append(n)
            sepList.append(n.nodeAboveSep)
        return nodeList,sepList
class Segment():
    """A strand (oligo) stored as nucleotides keyed by helix id and strand index.

    ``self.nodes[helixId][strandIdx]`` is a list of nucleotides.  Iterating
    a Segment yields nucleotides 5'->3' by following ``ntAt3prime`` links.
    """

    def __init__(self, segname, ntScale=1.0):
        self.ntScale = ntScale
        self.props = dict()
        self.props['segname'] = segname
        self.nodes = dict()
        self.firstNode = None

    def addNucleotide(self, helixId, strandIdx, zIdx,
                      sequence, isFwd, position, angle, ntAt5prime=None):
        """Create a nucleotide, link it after ``ntAt5prime`` and return it.

        Raises:
            Exception: if a nucleotide with the same ``zIdx`` and ``isFwd``
                already exists at ``(helixId, strandIdx)``.
        """
        if helixId in self.nodes:
            if strandIdx in self.nodes[helixId]:
                for n in self.nodes[helixId][strandIdx]:
                    if n.prop['zIdx'] == zIdx and n.prop['isFwd'] == isFwd:
                        raise Exception("Attempted to add a node in the same location (%d:%.1f:%s) twice!" % (helixId,zIdx,isFwd))
        else:
            self.nodes[helixId] = dict()

        if strandIdx not in self.nodes[helixId]:
            self.nodes[helixId][strandIdx] = []

        n = abstractNucleotide(sequence, position, angle, ntAt5prime,
                               isFwd=isFwd, helixId=helixId, zIdx=zIdx, idx=strandIdx)
        if ntAt5prime is not None:
            ntAt5prime.set3primeNt(n)

        self.nodes[helixId][strandIdx].append( n ) # TODO: use strandIdx?
        return n

    def __len__(self):
        """Total number of nucleotides stored in the segment."""
        l = 0
        for hid,ntsAtIdx in self.nodes.items():
            for strandIdx,nts in ntsAtIdx.items():
                l += len(nts)
        return l

    def __iter__(self):
        """Yield nucleotides 5'->3' starting from every 5' end.

        For a circular strand the walk stops when it comes back to its
        starting nucleotide instead of looping forever.
        """
        for first in self._get5primeNts():
            n = first
            yield n
            while n.ntAt3prime is not None and n.ntAt3prime is not first:
                n = n.ntAt3prime
                yield n

    def atoms(self, transform=True):
        """Yield the atoms of every nucleotide, in iteration order."""
        for nt in self:
            yield nt.atoms(transform, self.ntScale)

    def sequence(self):
        """Yield the sequence of every nucleotide, in iteration order."""
        for nt in self:
            yield nt.sequence

    def _allNucleotides(self):
        """Yield every stored nucleotide in dictionary order."""
        for a,x in self.nodes.items():
            for b,y in x.items():
                for nt in y:
                    yield nt

    def _get5primeNts(self):
        """Return the list of nucleotides that start a strand.

        These are the nucleotides without a 5' neighbour.  If there is none
        but the segment is not empty, the strand is circular: a warning is
        printed and the first stored nucleotide is used as starting point
        (the old fallback indexed ``dict.items()``, which fails on Python 3).
        """
        firstNodes = [nt for nt in self._allNucleotides() if nt.ntAt5prime is None]
        if len(firstNodes) == 0:
            anyNt = next(self._allNucleotides(), None)
            if anyNt is not None:
                print("WARNING: found circular segment; untested", file=sys.stderr)
                firstNodes = [anyNt]
        return firstNodes

    ## TODO deprecate
    def _findFirstNode(self):
        """Return the unique starting nucleotide of the segment.

        Raises:
            Exception: if the segment has more than one 5' end, or holds no
                nucleotides at all.
        """
        firstNodes = [nt for nt in self._allNucleotides() if nt.ntAt5prime is None]
        if len(firstNodes) > 1:
            raise Exception("Segment object contains two 5' ends!")
        elif len(firstNodes) == 0:
            anyNt = next(self._allNucleotides(), None)
            if anyNt is None:
                raise Exception("Segment object contains no nucleotides!")
            print("WARNING: found circular segment; untested", file=sys.stderr)
            firstNodes = [anyNt]
        return firstNodes[0]
class atomicModel():
def __init__(self, part, ntScale=1.0, randomSeed=1):
    """Initialise empty containers, seed ``random`` and build the model.

    ``part`` is handed to ``_determineLatticeType`` and ``_buildModel``;
    ``ntScale`` is stored and later passed to every ``Segment``.
    """
    self.part = part
    self.ntScale = ntScale

    ## Model content, filled in by _buildModel
    self.segments = []
    self.numNodes = 0
    self.dsChunks = dict()
    self.helixNtsFwd = dict()
    self.helixNtsRev = dict()
    self.nodeTypeCounts = None

    ## Bonded terms
    self.bonds = set()
    self.angles = set()
    self.dihedrals = set()
    self._nbParamFiles = []

    random.seed(randomSeed)

    self.useTclForces = False

    self.latticeType = atomicModel._determineLatticeType(part)
    self._buildModel(part)
def _strandOccupancies(self, helixId):
    """Return the sorted cadnano indices covered by any strand of a helix.

    The forward and reverse end lists from ``_getStrandEnds`` are read as
    consecutive (lo, hi) pairs and every index lo..hi (inclusive) is
    collected once.  An empty list is returned when the helix has no entry
    in ``strand_list`` (missing id or ``None`` placeholder); any other error
    now propagates instead of being silently swallowed by a bare ``except``.
    """
    try:
        ends1,ends2 = self._getStrandEnds(helixId)
    except (IndexError, KeyError, TypeError):
        ## hid not in strand_list
        return []

    occupied = set()
    for ends in (ends1,ends2):
        for lo,hi in zip(ends[::2], ends[1::2]):
            occupied.update(range(lo, hi+1))
    return sorted(occupied)
def _getStrandEnds(self, helixId):
    """Convert the cadnano strand lists of a helix into lists of end indices.

    Returns two lists (one per entry of ``self.strand_list[helixId]``).
    Strands that abut (next low index == previous high index + 1) are merged;
    each list therefore holds alternating low/high indices of the maximal
    contiguous runs.
    """
    endLists = [[],[]]
    for ends, strands in zip(endLists, self.strand_list[helixId]):
        previous = None
        for current in strands:
            if previous is None:
                ## very first strand opens the first run
                ends.append(current[0])
            elif current[0]-1 != previous[1]:
                ## gap: close the running region and open a new one
                assert( current[0] > previous[1] )
                ends += [previous[1], current[0]]
            previous = current
        if previous is not None:
            ends.append(previous[1])
    return endLists
# -------------------------- #
# Methods for building model #
# -------------------------- #
def _buildModel(self, part):
    """Build segments, per-helix nucleotide tables, basepairs and stacks.

    Strategy:
      1) Create all oligos/segments, describing 5'/3' bonding & location
      2) Add basepairing information
      3) For all basepairs, find 5'/3' stacks

    Side effects: fills ``self.strand_list``, ``self.helixNtsFwd``,
    ``self.helixNtsRev``, ``self.dsChunks`` and extends ``self.segments``.

    Fixes relative to the previous revision:
      * the square-lattice twist override was written as a comparison
        (``==``) and therefore never took effect; it is now an assignment;
      * ``helixNtsRev`` entries were replaced by one-shot ``reversed``
        iterators, which are empty after their first traversal; they are now
        stored as reversed lists.

    Raises:
        Exception: if the part uses ARBITRARY point types (not implemented).
    """
    props = part.getModelProperties().copy()

    vh_props = []
    origins = []
    if props.get('point_type') == PointType.ARBITRARY:
        # TODO add code to encode Parts with ARBITRARY point configurations
        raise Exception("Not implemented")
    else:
        vh_props, origins = part.helixPropertiesAndOrigins()

    ## TODO: compartmentalize following
    ## Loop over virtual helices and build lists of strands
    vh_list = []
    self.strand_list = []
    xover_list = []
    numHID = part.getIdNumMax() + 1
    for id_num in range(numHID):
        offset_and_size = part.getOffsetAndSize(id_num)
        if offset_and_size is None:
            # add a placeholder
            self.strand_list.append(None)
        else:
            offset, size = offset_and_size
            vh_list.append((id_num, size))
            fwd_ss, rev_ss = part.getStrandSets(id_num)
            fwd_idxs, fwd_colors = fwd_ss.dump(xover_list)
            rev_idxs, rev_colors = rev_ss.dump(xover_list)
            self.strand_list.append((fwd_idxs, rev_idxs))

    segid = 1

    for hid in range(numHID):
        self.helixNtsFwd[hid] = dict()
        self.helixNtsRev[hid] = dict()
        self.dsChunks[hid] = []

    ## Loop over oligos/segments
    segments = []
    keyFn = lambda o: (o.length(), o._strand5p.idNum(), o._strand5p.idx5Prime())
    for oligo in reversed(sorted(part.oligos(), key=keyFn)):
        seg = Segment("D%03d" % segid, self.ntScale)
        segments.append(seg)
        segid += 1

        lastNt = None

        ## loop over strands in oligo
        for strand in oligo.strand5p().generator3pStrand():
            hid = strand.idNum() # virtual helix ID
            x,y = origins[hid]
            keys = ['bases_per_repeat',
                    'turns_per_repeat',
                    'eulerZ','z']
            bpr,tpr,eulerZ,z = [vh_props[k][hid] for k in keys]
            twist_per_base = tpr*360./bpr

            ## override twist_per_base if square lattice
            if self.latticeType == "square":
                twist_per_base = 360.0*3/32
            else:
                assert(twist_per_base == 360.0/10.5)

            ## The lambdas are only called within this loop iteration, so
            ## they see the x, y, z, twist values of the current strand
            zIdxToPos = lambda idx: (x*10,y*10,z-3.4*idx)
            zIdxToAngle = lambda idx: idx*twist_per_base + eulerZ

            isFwd = strand.isForward()
            lo,hi = strand.idxs()
            seqs = Strand.sequence(strand, for_export=True)
            numNt = strand.totalLength()

            ## We place nts as pairs; spacing is not always uniform along strand.
            ## Here we find the locations between which spacing will be uniform.
            compIdxs = [comp.idxs() for comp in strand.getComplementStrands()]
            chunks = self.combineRegionLists(
                [[lo,hi]],
                compIdxs,
                intersect=True
            )
            self.dsChunks[hid].extend( chunks ) # Add running list of chunks for push bonds
            chunks = self.combineRegionLists(
                [[lo,hi]],
                chunks,
                intersect=False
            )

            ## Find zIdx for each nucleotide
            zIdxs = []
            for l,h in chunks:
                ## Find the number of nts between l,h
                nts = h-l+1
                for ins in strand.insertionsOnStrand(l,h):
                    nts += ins.length()
                if nts <= 0:
                    print("WARNING: there is something strange in the structure at helix %d (%d,%d)" % (hid,l,h))
                    continue
                assert( nts > 0 )
                for i in range(nts):
                    if nts == 1:
                        zIdxs.append( l )
                    else:
                        ## spread the nts uniformly over [l,h]
                        zIdxs.append( l + (h-l) * float(i)/(nts-1) )
            assert( len(zIdxs) == numNt )

            if isFwd:
                helixNts = self.helixNtsFwd[hid]
            else:
                helixNts = self.helixNtsRev[hid]
                zIdxs = reversed(zIdxs)

            ## Find strandOccupancy index for each nucleotide
            strandOccIdx = []
            for l,h in chunks:
                for i in range(l,h+1):
                    nts = 1
                    for ins in strand.insertionsOnStrand(i,i):
                        nts += ins.length()
                    helixNts[i] = []
                    for j in range(nts):
                        strandOccIdx.append(i)
            if not isFwd:
                strandOccIdx = reversed(strandOccIdx)

            ## Add nucleotides
            for s,zIdx,idx in zip(seqs,zIdxs,strandOccIdx):
                p = zIdxToPos(zIdx)
                a = zIdxToAngle(zIdx)
                lastNt = seg.addNucleotide(hid, idx, zIdx, s, isFwd,
                                           p, a, lastNt)
                helixNts[idx].append(lastNt)

    ## Correct the order of nts in helixNtsRev
    ## (stored as lists so the entries can be traversed more than once)
    for hid,ntsAtIdx in self.helixNtsRev.items():
        for i,nts in ntsAtIdx.items():
            self.helixNtsRev[hid][i] = list(reversed(nts))

    ## 2) Find basepairs and stacks
    for hid in range(numHID):
        prevBasepair = None
        for i in self._strandOccupancies(hid):
            ## Check that strand index is dsDNA
            if not (i in self.helixNtsFwd[hid] and i in self.helixNtsRev[hid]):
                continue

            for nt1,nt2 in zip(self.helixNtsFwd[hid][i],
                               self.helixNtsRev[hid][i]):
                self._pairBases(nt1,nt2)

                if prevBasepair is not None and i - prevBasepair[0] <= 3: # TODO: find better way of locating stacks
                    atomicModel._stackBases( prevBasepair[1], nt1)
                    atomicModel._stackBases( nt2, prevBasepair[2])
                prevBasepair = (i,nt1,nt2)

    self.segments.extend(segments)
    return
def _connectNodes(self, below, above, sep):
    """Link two nodes to each other with separation ``sep``.

    ``above`` is registered on ``below`` via ``addNodeAbove`` and ``below``
    on ``above`` via ``addNodeBelow``.
    """
    lower, upper = below, above
    lower.addNodeAbove(upper, sep)
    upper.addNodeBelow(lower, sep)
def _getNeighborHelixDict(part):
    """Map each helix index to the list of its neighbouring helix ids.

    Takes no ``self``; it is invoked through the class.  The 'neighbors'
    property of every helix is a string which is split on whitespace; each
    token has surrounding '[', ',' and ']' characters stripped before being
    converted to int, and a bare '[]' token (no neighbours) is skipped.
    """
    props, origins = part.helixPropertiesAndOrigins()
    return {
        i: [int(token.strip('[,]'))
            for token in props['neighbors'][i].split()
            if token != '[]']
        for i in range(len(origins))
    }
def _determineLatticeType(part):
    """Classify the part as 'square' or 'honeycomb' from neighbour geometry.

    Takes no ``self``; it is invoked through the class.  Vectors between
    neighbouring helix origins are compared pairwise; a pair counts as
    "angled" when the absolute cosine between the two vectors is within 0.2
    of 0.5.  The fraction of angled pairs decides the lattice type.

    Raises:
        Exception: when the fraction falls between the two accepted ranges.
    """
    props, origins = part.helixPropertiesAndOrigins()
    centers = [np.array(o) for o in origins]

    neighborVecs = []
    for i,js in atomicModel._getNeighborHelixDict(part).items():
        neighborVecs.extend( centers[j]-centers[i] for j in js )

    dots = [nv2.dot(nv1)/(np.linalg.norm(nv1)*np.linalg.norm(nv2))
            for nv1 in neighborVecs for nv2 in neighborVecs]
    nDots = len(dots)
    nAngled = np.sum( np.abs(np.abs(dots)-0.5) < 0.2 )

    angledFraction = float(nAngled)/float(nDots)
    if angledFraction == 0:
        return 'square'
    if angledFraction < 0.05:
        print('WARNING: unusual neighbors in square lattice structure')
        return 'square'
    if angledFraction > 0.4:
        return 'honeycomb'
    raise Exception("Could not identify lattice type")
def backmap(self, simplerModel, simplerModelCoords,
            dsDnaHelixNeighborDist=50, dsDnaAllNeighborDist=30,
            ssDnaHelixNeighborDist=20, ssDnaAllNeighborDist=15):
    """Move every nucleotide according to the coordinates of a simpler model.

    Each nucleotide is assigned to one bead of ``simplerModel`` (looked up by
    helix id and ``zIdx``).  For every bead used, a rotation is derived that
    maps its initial neighbourhood onto ``simplerModelCoords``; nucleotide
    positions and orientations are then updated in place.

    The four ``*NeighborDist`` arguments are passed as cutoffs to
    ``simplerModel._getNeighborhoodIds``; the ds or ss pair is chosen from
    the first character of the bead's ``type``.

    Raises:
        Exception: if a nucleotide's helix is missing from ``simplerModel``
            (the helix id is now actually interpolated into the message).
    """
    ## Assign each nucleotide to a bead in simplerModel
    mapToSimplerModel = dict()
    cgWeight = dict()

    nts = [nt for seg in self.segments for nt in seg]

    for i in range(len(nts)):
        nt = nts[i]
        hid = nt.prop['helixId']
        if hid not in simplerModel.helices:
            raise Exception("Could not find helix %d in simple model; using just a single cadnano structure" % hid)
        cgH = simplerModel.helices[hid]
        zIdxs = np.array( sorted([i for i,b in cgH]) )

        zi = nt.prop['zIdx']
        cgi = np.searchsorted(zIdxs,zi,side='left',sorter=None)
        ## clamp to the last bead when zi lies beyond every bead index
        cgi, = [zIdxs[x] if x < len(zIdxs) else zIdxs[-1] for x in (cgi,)]
        mapToSimplerModel[i] = [cgH.nodes[x] for x in (cgi,)]

    ## Find transformation for each bead of simplerModel used in mapping
    trans = dict()
    for b in list(set([b for i,bs in mapToSimplerModel.items() for b in bs])):
        helixCutoff = dsDnaHelixNeighborDist if b.type[0] == 'd' else ssDnaHelixNeighborDist
        allCutoff = dsDnaAllNeighborDist if b.type[0] == 'd' else ssDnaAllNeighborDist

        r0 = b.initialPosition
        r1 = simplerModelCoords[b.idx]

        ## Neighbours further than 10 away (initially) are ignored
        above = b.nodeAbove if b.nodeAbove is not None and \
                np.sum((b.nodeAbove.initialPosition-b.initialPosition)**2) < 10**2 else b
        below = b.nodeBelow if b.nodeBelow is not None and \
                np.sum((b.nodeBelow.initialPosition-b.initialPosition)**2) < 10**2 else b

        if type(simplerModel).__name__ == "beadModelTwist" and \
           b.orientationNode is not None:
            ## Build a local frame before and after, and fit the rotation
            o = b.orientationNode
            assert(above != below)

            z0 = above.initialPosition - below.initialPosition
            z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]
            z0,z1 = [z/np.linalg.norm(z) for z in (z0,z1)]

            x0 = o.initialPosition - r0
            x1 = simplerModelCoords[o.idx] - r1
            x0,x1 = [x-x.dot(z)*z for x,z in zip((x0,x1),(z0,z1))] # make x orthogonal to z
            x0,x1 = [x/np.linalg.norm(x) for x in (x0,x1)]

            y0,y1 = [np.cross(z,x) for x,z in zip((x0,x1),(z0,z1))]
            R,tmp1,tmp2 = minimizeRmsd( (x0,y0,z0), (x1,y1,z1) )
        else:
            ids = simplerModel._getNeighborhoodIds(b, simplerModelCoords, helixCutoff, allCutoff)
            posOld = np.array( [simplerModel.particles[i][0].initialPosition for i in ids] )
            posNew = np.array( [simplerModelCoords[i] for i in ids] )
            R,cOld,cNew = minimizeRmsd( posOld, posNew )

            ## NOTE(review): nesting of this ssDNA correction under the
            ## ``else`` was inferred (source indentation was lost) -- confirm
            if b.type[0] == 's':
                assert(above != below)
                z0 = above.initialPosition - below.initialPosition
                z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]
                zR = z0.dot(R)
                cross = np.cross(z1,zR) / (np.linalg.norm(z1)*np.linalg.norm(zR))
                angle = np.arcsin(np.linalg.norm(cross))*180/np.pi
                R2 = rotationAboutAxis(cross,angle)
                R = R.dot(R2)

        cOld,cNew = (r0,r1)
        trans[b.idx] = (R,cOld,cNew)

    ## Optionally smooth orientations

    ## Apply transformation to each bead of self
    for i in range(len(nts)):
        b = nts[i]
        cgb, = mapToSimplerModel[i]
        cgi = cgb.idx
        r0 = simplerModel.particles[cgi][0].initialPosition
        R,c0,c1 = trans[cgi]
        b.position = (b.position - r0).dot(R) + simplerModelCoords[cgi]
        b.orientation = b.orientation.dot( R )

    ## Paired nucleotides must end up with identical orientations
    for i in range(len(nts)):
        b = nts[i]
        if b.basepair is not None:
            assert( np.all( b.orientation == b.basepair.orientation ) )
def _pairBases(self, n1,n2):
    """Mark ``n1`` and ``n2`` as each other's basepair partner.

    Both nucleotides must be unpaired beforehand (asserted).
    """
    assert(not (n1.basepair is not None or n2.basepair is not None))
    n1.basepair, n2.basepair = n2, n1
def _stackBases(below, above):
    """Record that ``above`` stacks on the 3' side of ``below``.

    Takes no ``self``; it is invoked through the class.  Neither stacking
    slot may already be occupied (asserted).
    """
    assert(not (below.stack3prime is not None or above.stack5prime is not None))
    below.stack3prime, above.stack5prime = above, below
def _removeIntrahelicalConnectionsBeyond(self, cutoff):
    """Break 5'-3' links between nucleotides further apart than ``cutoff``.

    Every nucleotide is compared with its 3' neighbour; when their distance
    exceeds ``cutoff`` the link is removed on both sides.  The nucleotides
    stay in their original segment after the break.

    The former ``elif``/``else`` branches could never run (the 3' neighbour
    was compared with itself) and ended in a bare ``raise`` with no active
    exception, so they have been removed.
    """
    ## lazy programming: leave nts in one segment after bond break
    cutoffSq = cutoff**2
    nts = [nt for seg in self.segments for nt in seg]
    for n1 in nts:
        n2 = n1.ntAt3prime
        if n2 is None:
            continue
        r2 = np.sum( (n1.position - n2.position)**2 )
        if r2 > cutoffSq:
            ## links are expected to be reciprocal
            assert(n2.ntAt5prime is n1)
            n1.ntAt3prime = None
            n2.ntAt5prime = None
def scaffoldSeq(seqFile):
    """Placeholder: does nothing with ``seqFile`` and returns None."""
    return None
def randomizeUnassignedNtSeqs(self):
    """Give every nucleotide whose sequence is '?' a base.

    If the basepair partner already carries a base known to
    ``seqComplement``, the complement of that base is used so the existing
    assignment is not overwritten.  Otherwise a base is drawn with
    ``random.choice`` from the keys of ``seqComplement`` and the partner (if
    any) receives its complement.
    """
    for seg in self.segments:
        for nt in seg:
            if nt.sequence != '?':
                continue
            partner = nt.basepair
            partnerAssigned = (partner is not None
                               and partner.sequence != '?'
                               and partner.sequence in seqComplement)
            if partnerAssigned:
                ## keep the partner's base; just complete the pair
                nt.sequence = seqComplement[partner.sequence]
            else:
                nt.sequence = random.choice(tuple(seqComplement.keys()))
                if partner is not None:
                    partner.sequence = seqComplement[nt.sequence]
def setUnassignedNtSeqs(self,base):
    """Assign ``base`` to every nucleotide whose sequence is '?'.

    ``base`` must be one of "A", "T", "C", "G" (asserted).  If the basepair
    partner already carries a base known to ``seqComplement``, the
    complement of that base is used instead so the existing assignment is
    not overwritten; otherwise the partner (if any) receives the complement
    of ``base``.
    """
    assert(base in ("A","T","C","G"))
    for seg in self.segments:
        for nt in seg:
            if nt.sequence != '?':
                continue
            partner = nt.basepair
            partnerAssigned = (partner is not None
                               and partner.sequence != '?'
                               and partner.sequence in seqComplement)
            if partnerAssigned:
                ## keep the partner's base; just complete the pair
                nt.sequence = seqComplement[partner.sequence]
            else:
                nt.sequence = base
                if partner is not None:
                    partner.sequence = seqComplement[nt.sequence]
def setScaffoldSeq(self,sequence):
    """Apply ``sequence`` to the longest segment (taken to be the scaffold).

    Nucleotides are assigned in iteration order; every paired nucleotide's
    partner receives the complementary base.  Extra trailing bases in
    ``sequence`` are ignored.

    The longest segment must be strictly longer than the runner-up
    (asserted).  A model with a single segment is now accepted; previously
    looking up the runner-up raised IndexError.

    Raises:
        Exception: if the model has no segments, or ``sequence`` is shorter
            than the scaffold.
    """
    if len(self.segments) == 0:
        raise Exception("Model contains no segments!")

    ## get longest segment (stable sort: ties keep their original order)
    byLength = sorted( self.segments, key=len )
    seg = byLength[-1]
    if len(byLength) > 1:
        assert( len(seg) > len(byLength[-2]) )

    if len(seg) > len(sequence):
        raise Exception("Sequence (%d) too short for scaffold (%d)!" % (len(sequence),len(seg)))

    for nt,seq in zip(seg,sequence):
        assert(seq in ("A","T","C","G"))
        nt.sequence = seq
        if nt.basepair is not None:
            nt.basepair.sequence = seqComplement[nt.sequence]
def _assignAtomicIndexToNucleotides(self):
    """Store in each nucleotide the running count of atoms preceding it.

    Nucleotides are visited segment by segment in iteration order;
    ``firstAtomicIndex`` of the first one is 0 and each following one is
    offset by the length of the untransformed atoms of all previous ones.
    """
    runningTotal = 0
    for nt in (nt for seg in self.segments for nt in seg):
        nt.firstAtomicIndex = runningTotal
        runningTotal += len(nt.atoms(transform=False))
def combineRegionLists(self,loHi1,loHi2,intersect=False):
    """Combines two lists of (lo,hi) pairs specifying integer
    regions into a single list of regions.

    Each input list is first broken into runs of contiguous regions
    ("compact" regions: each lo equals the previous hi + 1).  Runs from the
    two inputs that touch or overlap are merged with
    ``combineCompactRegionLists``; runs that do not are copied unchanged.
    The result is sorted.

    With ``intersect=True`` only result regions lying between the larger of
    the two first lo values and the smaller of the two last hi values are
    kept.  If either input is empty, the other input is returned as-is (or
    an empty list when ``intersect`` is set).
    """
    ## Validate input
    for l in (loHi1,loHi2):
        ## Assert each region in lists is sorted
        for pair in l:
            assert(len(pair) == 2)
            assert(pair[0] <= pair[1])

    if len(loHi1) == 0:
        if intersect:
            return []
        else:
            return loHi2
    if len(loHi2) == 0:
        if intersect:
            return []
        else:
            return loHi1

    ## Break input into lists of compact regions
    compactRegions1,compactRegions2 = [[],[]]
    for compactRegions,loHi in zip(
            [compactRegions1,compactRegions2],
            [loHi1,loHi2]):
        tmp = []
        lastHi = loHi[0][0]-1
        for lo,hi in loHi:
            ## a gap before this region closes the current compact run
            if lo-1 != lastHi:
                compactRegions.append(tmp)
                tmp = []
            tmp.append((lo,hi))
            lastHi = hi
        if len(tmp) > 0:
            compactRegions.append(tmp)

    ## Build result
    result = []
    region = []
    i,j = [0,0]
    ## Sentinel runs with a huge lo value so that indexing past the real
    ## runs is safe and never compares as touching the current region
    compactRegions1.append([[1e10]])
    compactRegions2.append([[1e10]])
    while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
        cr1 = compactRegions1[i]
        cr2 = compactRegions2[j]

        ## initialize region
        if len(region) == 0:
            ## start from whichever pending run begins first
            if cr1[0][0] <= cr2[0][0]:
                region = cr1
                i += 1
                continue
            else:
                region = cr2
                j += 1
                continue

        ## absorb the next run from either input while it starts at or
        ## before the end of the current region; otherwise flush the region
        if region[-1][-1] >= cr1[0][0]:
            region = self.combineCompactRegionLists(region, cr1, intersect=False)
            i+=1
        elif region[-1][-1] >= cr2[0][0]:
            region = self.combineCompactRegionLists(region, cr2, intersect=False)
            j+=1
        else:
            result.extend(region)
            region = []

    assert( len(region) > 0 )
    result.extend(region)
    result = sorted(result)

    # print("loHi1:",loHi1)
    # print("loHi2:",loHi2)
    # print(result,"\n")

    if intersect:
        lo = max( [loHi1[0][0], loHi2[0][0]] )
        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
        result = [r for r in result if r[0] >= lo and r[1] <= hi]

    return result
def combineCompactRegionLists(self,loHi1,loHi2,intersect=False):
    """Combines two lists of (lo,hi) pairs specifying regions within a
    compact integer set into a single list of regions.

    Each input must be "compact": every region starts right after the
    previous one ends.  This is asserted for every pair of consecutive
    regions (the earlier check only looked at every other pair).  The
    result splits the union of both inputs at every region boundary of
    either input.  With ``intersect=True`` only the regions inside the
    overlap of the two inputs are returned.  If either input is empty, the
    other is returned as-is (or an empty list when ``intersect`` is set).

    examples:
        loHi1 = [[0,4],[5,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 4), (5, 7), (8, 9)]

        loHi1 = [[0,3],[4,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
    """
    ## Validate input
    for l in (loHi1,loHi2):
        ## Assert each region in lists is sorted
        for pair in l:
            assert(len(pair) == 2)
            assert(pair[0] <= pair[1])
        ## Assert lists are compact
        for pair1,pair2 in zip(l[:-1],l[1:]):
            assert(pair1[1]+1 == pair2[0])

    if len(loHi1) == 0:
        if intersect:
            return []
        else:
            return loHi2
    if len(loHi2) == 0:
        if intersect:
            return []
        else:
            return loHi1

    ## Find the ends of the region
    lo = min( [loHi1[0][0], loHi2[0][0]] )
    hi = max( [loHi1[-1][1], loHi2[-1][1]] )

    ## Make a list of indices where each region will be split
    splitAfter = set()
    for l,h in list(loHi2) + list(loHi1):
        if l != lo:
            splitAfter.add(l-1)
        if h != hi:
            splitAfter.add(h)
    split = sorted(splitAfter)

    ## Consecutive split points delimit the output regions
    returnList = [(i+1,j) if i != j else (i,j) for i,j in zip([lo-1]+split,split+[hi])]

    if intersect:
        lo = max( [loHi1[0][0], loHi2[0][0]] )
        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
        returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]

    return returnList
def _getDsChunksForHelixPair(self,hid1,hid2):
    """Combine the double-stranded chunks recorded for two helices.

    The chunks of each helix (``self.dsChunks``) are de-duplicated and
    sorted, then merged with ``combineRegionLists``.

    The hand-rolled merge loop that used to follow the ``return`` was
    unreachable (and referenced an unset name), so it has been removed.
    """
    chunks1,chunks2 = (sorted(set(self.dsChunks[hid])) for hid in (hid1,hid2))
    return self.combineRegionLists(chunks1,chunks2)
def getNtsBetweenCadnanoIdxInclusive(self,hid,lo,hi,isFwd=True):
    """List the nucleotides of helix ``hid`` at indices lo..hi (inclusive).

    The forward table is used unless ``isFwd`` is false.  Indices with no
    entry are skipped; nucleotides come back in increasing index order.
    """
    table = self.helixNtsFwd[hid] if isFwd else self.helixNtsRev[hid]
    return [nt
            for idx in range(lo,hi+1) if idx in table
            for nt in table[idx]]
# -------------------------- #
# Methods for querying model #
# -------------------------- #
def _getIntrahelicalNodeSeries(self,seriesLen):
    """Collect every run of ``seriesLen`` consecutive nodes along the helices.

    For each node in ``self.helices`` the node itself plus the
    ``seriesLen-1`` nodes above it are fetched with ``getNodesAbove``.
    Complete runs are stored as ``(nodes, separations)`` tuples in a set;
    runs cut short by the end of a chain are dropped.
    """
    series = set()
    for hid,hlx in self.helices.items():
        for zid,node in hlx:
            members,gaps = node.getNodesAbove(seriesLen-1, inclusive = True)
            if len(members) != seriesLen:
                continue
            series.add( (tuple(members), tuple(gaps)) )
    return series
def _getIntrahelicalBonds(self):
    """Return all intrahelical node series of length 2 (bonds)."""
    nodesPerBond = 2
    return self._getIntrahelicalNodeSeries(nodesPerBond)
def _getIntrahelicalAngles(self):
    """Return all intrahelical node series of length 3 (angles)."""
    nodesPerAngle = 3
    return self._getIntrahelicalNodeSeries(nodesPerAngle)
def _getCrossoverBonds(self):
    """Return the set of crossover bonds as ``((node, partner), 1)`` entries.

    Every crossover is stored on both of its nodes; only the direction in
    which the node's ``idx`` is smaller than its partner's is emitted.
    """
    bonds = set()
    for hid,hlx in self.helices.items():
        for zid,node in hlx:
            for xover in node.xovers:
                partner = xover[0]
                if node.idx < partner.idx:
                    bonds.add( ((node, partner), 1) )
    return bonds
def _getCrossoverAnglesAndDihedrals(self):
    """Find angle and dihedral terms spanning consecutive crossovers.

    Each helix in ``self.helices`` is scanned in order.  Whenever two
    crossover-bearing nodes follow each other without an intervening gap or
    non-"dsDNA" node, every combination of their crossovers contributes two
    angles and one dihedral, each tagged with the summed ``nodeBelowSep``
    between the two nodes (the dihedral also carries both crossovers'
    second tuple element).

    Returns:
        (angles, dihedrals): two sets of tuples.
    """
    angles,dihedrals = [set(),set()]
    for hid,hlx in self.helices.items():
        lastXoverNode = None
        bpsBetween = 0
        for zid,n in hlx:

            ## Search for a pair of crossovers
            if n.nodeBelow is None or n.type != "dsDNA":
                ## Found ssDNA or a gap; reset search
                lastXoverNode = None
                bpsBetween = 0

            if lastXoverNode is None:
                if len(n.xovers) > 0:
                    ## First node with a crossover
                    lastXoverNode = n
            else:
                ## accumulate the separation travelled since the last crossover
                if n.nodeBelow is not None:
                    bpsBetween += n.nodeBelowSep
                if len(n.xovers) > 0:
                    ## Second node with a crossover, Add dihedral(s)
                    for xo1 in lastXoverNode.xovers:
                        for xo2 in n.xovers:
                            assert( bpsBetween != 0 )
                            # if bpsBetween != 0
                            angles.add( ((xo1[0], lastXoverNode, n), bpsBetween) )
                            angles.add( ((lastXoverNode, n, xo2[0]), bpsBetween) )
                            dihedrals.add( ((xo1[0], lastXoverNode, n, xo2[0]), bpsBetween, xo1[1], xo2[1]) )
                    ## this node becomes the start of the next pair
                    lastXoverNode = n
                    bpsBetween = 0
    return angles, dihedrals
def _getBonds(self):
    """Return intrahelical bonds with the crossover bonds merged in.

    The set returned by ``_getIntrahelicalBonds`` is updated in place and
    handed back.
    """
    allBonds = self._getIntrahelicalBonds()
    crossoverBonds = self._getCrossoverBonds()
    allBonds.update( crossoverBonds )
    return allBonds
# -------------------------- #
# Methods for printing model #
# -------------------------- #
def writeNamdFiles(self,prefix,numSteps=48000):
    """Write the full set of simulation input files for ``prefix``.

    Atomic indices are assigned first, then ``writePdbPsf`` and ``writeENM``
    are run, and finally ``writeNamdFile`` is called with ``numSteps``.
    """
    self._assignAtomicIndexToNucleotides()
    for writeStructureFile in (self.writePdbPsf, self.writeENM):
        writeStructureFile(prefix)
    self.writeNamdFile(prefix,numSteps)
def writePdbPsf(self, prefix):
    """Write ``<prefix>.pdb`` and ``<prefix>.psf`` for all segments.

    Atoms are written nucleotide by nucleotide.  Besides the bonded terms
    carried by each nucleotide's atoms, the backbone link between
    consecutive nucleotides of a segment (O3' of the previous one to P of
    the current one) is added together with its angles and dihedrals.

    The residue id runs over all segments.  Ids up to 9999 are printed as a
    4-wide integer; larger ones use a letter (A for 10000-10999, B for
    11000-10999+1000, ...) followed by the remainder.  The threshold used to
    be ``< 9999``, which sent resid 9999 down the letter branch with a
    negative letter index and produced "Z999".
    """
    bonds,angles,dihedrals,impropers = [[],[],[],[]]

    with open("%s.pdb" % prefix,'w') as pdb, \
         open("%s.psf" % prefix,'w') as psf:
        ## Write headers
        pdb.write("CRYST1 1000. 1000. 1000. 90.00 90.00 90.00 P 1 1\n")
        psf.write("PSF NAMD\n\n") # create NAMD formatted psf
        psf.write("{:>8d} !NTITLE\n\n".format(0))

        ## Format strings
        ## http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
        pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s} {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f} {segname:4s}{element:>2s}\n"

        ## From vmd/plugins/molfile_plugin/src/psfplugin.c
        ## "%d %7s %10s %7s %7s %7s %f %f"
        psfFormat = "{idx:>8d} {segname:7s} {resid:<10d} {resname:7s} " + \
                    "{name:7s} {type:7s} {charge: f} {mass:f}\n"

        ## PSF ATOMS section
        natoms=0
        for seg in self.segments:
            for atoms in seg.atoms(transform=False):
                natoms += len(atoms)
        psf.write("{:>8d} !NATOM\n".format(natoms))

        atomProps = dict(idx=0,altLoc='',chain='A', iCode='',
                         occupancy=0, beta=0, charge='', resid=1)
        resnameDict = dict(A='ADE',T='THY',G='GUA',C='CYT')

        ## Loop over segments, setting common properties
        for seg in self.segments:
            atomProps['segname'] = seg.props['segname']
            prevAtoms = None

            ## Loop over nucleotides, setting common properties
            for atoms, seq in zip(seg.atoms(),seg.sequence()):
                atomProps['resname'] = resnameDict[seq]

                if atomProps['resid'] <= 9999:
                    atomProps['residString'] = "%4d" % atomProps['resid']
                else:
                    ## letter + 3-digit remainder for ids that need 5 digits
                    digit = (atomProps['resid'] // 1000) - 10
                    char = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[digit]
                    rem = atomProps['resid'] - (digit+10)*1000
                    atomProps['residString'] = "{:1s}{:<3d}".format(char,rem)

                ## Add bonds, angles, dihedrals associated with LKNA
                if prevAtoms is not None:
                    ## atomProps['idx'] is the number of atoms written so
                    ## far, i.e. the offset of the current nucleotide
                    c4idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C4'")
                    c2idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C2'")
                    h3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("H3'")
                    c3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C3'")
                    o3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("O3'")

                    Pidx = atomProps['idx'] + atoms.index("P")
                    o5idx = atomProps['idx'] + atoms.index("O5'")
                    o1pidx = atomProps['idx'] + atoms.index("O1P")
                    o2pidx = atomProps['idx'] + atoms.index("O2P")
                    c5idx = atomProps['idx'] + atoms.index("C5'")

                    bonds.append([o3idx, Pidx])

                    angles.append([c3idx, o3idx, Pidx])
                    for x in (o5idx,o1pidx,o2pidx):
                        angles.append([o3idx, Pidx, x])

                    for x in (c4idx,c2idx,h3idx):
                        dihedrals.append([x,c3idx, o3idx, Pidx])
                    for x in (o5idx,o1pidx,o2pidx):
                        dihedrals.append([c3idx, o3idx, Pidx,x])
                    dihedrals.append([o3idx, Pidx, o5idx, c5idx])

                ## Find bonds,angles,dihedrals, and impropers
                bonds.extend( atoms.bonds+atomProps['idx'] )
                angles.extend( atoms.angles+atomProps['idx'] )
                dihedrals.extend( atoms.dihedrals+atomProps['idx'] )
                impropers.extend( atoms.impropers+atomProps['idx'] )

                ## Loop over atoms
                for props in atoms.atomicProps():
                    for k,v in props.items():
                        atomProps[k] = v
                    atomProps['element'] = atomProps['name'][0]

                    ## Write lines
                    pdb.write(pdbFormat.format( **atomProps ))

                    ## psfResid = "%d%c%c" % (idx," "," "), # TODO: work with large indeces
                    ## Increment counter (the psf line uses the new value)
                    atomProps['idx'] += 1
                    psf.write(psfFormat.format( **atomProps ))

                prevAtoms = atoms
                atomProps['resid'] += 1

        ## Write out bonds, angles, dihedrals
        psf.write("\n{:>8d} !NBOND".format(len(bonds)))
        counter = 0
        for b in bonds:
            if counter % 4 == 0: psf.write("\n")
            psf.write(" {:>7d} {:>7d}".format(*b))
            counter += 1

        psf.write("\n\n{:>8d} !NTHETA".format(len(angles)))
        counter = 0
        for b in angles:
            if counter % 3 == 0: psf.write("\n")
            psf.write(" {:>7d} {:>7d} {:>7d}".format(*b))
            counter += 1

        psf.write("\n\n{:>8d} !NPHI".format(len(dihedrals)))
        counter = 0
        for b in dihedrals:
            if counter % 2 == 0: psf.write("\n")
            psf.write(" {:>7d} {:>7d} {:>7d} {:>7d}".format(*b))
            counter += 1

        psf.write("\n\n{:>8d} !NIMPHI".format(len(impropers)))
        counter = 0
        for b in impropers:
            if counter % 2 == 0: psf.write("\n")
            psf.write(" {:>7d} {:>7d} {:>7d} {:>7d}".format(*b))
            counter += 1

        psf.write("\n\n 0 !NDON: donors\n\n\n")
        psf.write("\n 0 !NACC: acceptors\n\n\n")
        psf.write("\n 0 !NNB\n\n")
        ## one zero per atom, eight per line
        for i in range(natoms//8):
            psf.write(" 0 0 0 0 0 0 0 0\n")
        for i in range(natoms-8*(natoms//8)):
            psf.write(" 0")
        psf.write("\n\n 1 0 !NGRP\n\n")
def writeENM(self, prefix):
    """Write the extra-bond (elastic network) restraints for the model.

    Two groups of restraints are produced:

    * template bonds between each nucleotide and its basepair / stacking
      partners, taken from the lattice-specific ENM template (with
      per-bond corrections for the honeycomb lattice), written to
      ``<prefix>.exb``;
    * "push" bonds between P atoms of nucleotides on neighbouring helices,
      skipping positions within ``minDist`` nucleotides of a crossover
      between the two helices.  These are appended to ``<prefix>.exb``, or,
      when ``self.useTclForces`` is set, written as a Tcl script to
      ``<prefix>.forces.tcl``.

    Relies on ``firstAtomicIndex`` having been assigned to every nucleotide.

    Raises:
        Exception: for an unsupported lattice type or ARBITRARY point type.
    """
    if self.latticeType == "square":
        enmTemplate = enmTemplateSQ
    elif self.latticeType == "honeycomb":
        enmTemplate = enmTemplateHC
    else:
        raise Exception("Lattice type '%s' not supported" % self.latticeType)

    ## Counters for the diagnostics printed below
    noStackPrime = 0
    noBasepair = 0
    with open("%s.exb" % prefix,'w') as fh:
        # natoms=0
        for seg in self.segments:
            for nt1 in seg:
                ## Collect (partner, template kind) pairs for nt1
                other = []
                nt2 = nt1.basepair
                if nt2 is not None:
                    ## index comparison avoids writing each pair twice
                    if nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                        other.append((nt2,'pair'))
                    ## NOTE(review): nesting of the 'paircross' lookup at
                    ## this level was inferred (indentation lost) -- confirm
                    nt2 = nt2.stack3prime
                    if nt2 is not None and nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                        other.append((nt2,'paircross'))
                else:
                    noBasepair += 1

                nt2 = nt1.stack3prime
                if nt2 is not None:
                    other.append((nt2,'stack'))
                    nt2 = nt2.basepair
                    if nt2 is not None and nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                        other.append((nt2,'cross'))
                else:
                    noStackPrime += 1

                a1 = nt1.atoms(transform=False)
                for nt2,key in other:
                    ## template key: kind plus first letter of both sequences
                    key = ','.join((key,nt1.sequence[0],nt2.sequence[0]))
                    for n1, n2, d in enmTemplate[key]:
                        d = float(d)
                        a2 = nt2.atoms(transform=False)
                        i,j = [a.index(n)-1 for a,n in zip((a1,a2),(n1,n2))]
                        # try:
                        #     i,j = [a.index(n)-1 for a,n in zip((a1,a2),(n1,n2))]
                        # except:
                        #     continue

                        k = 0.1
                        if self.latticeType == 'honeycomb':
                            correctionKey = ','.join((key,n1,n2))
                            assert(correctionKey in enmCorrectionsHC)
                            dk,dr = enmCorrectionsHC[correctionKey]
                            k = float(dk)
                            d += float(dr)

                        i += nt1.firstAtomicIndex
                        j += nt2.firstAtomicIndex
                        fh.write("bond %d %d %f %.2f\n" % (i,j,k,d))
    print("NO STACKS found for:", noStackPrime)
    print("NO BASEPAIRS found for:", noBasepair)

    props, origins = self.part.helixPropertiesAndOrigins()
    origins = [np.array(o) for o in origins]

    ## Push bonds
    pushBonds = []

    fwdDict = dict() # dictionaries of nts with keys [hid][zidx]
    revDict = dict()
    xo = dict() # dictionary of crossovers between [hid1][hid2]

    helixIds = [hid for s in self.segments for hid in s.nodes.keys()]
    helixIds = sorted(list(set(helixIds)))

    #- initialize dictionaries
    for h1 in helixIds:
        fwdDict[h1] = dict()
        revDict[h1] = dict()
        xo[h1] = dict()
        for h2 in helixIds:
            xo[h1][h2] = []

    #- fill dictionaries
    for seg in self.segments:
        for nt in seg:
            hid = nt.prop['helixId']

            ## Add nt to ntDict
            idx = nt.prop['idx']
            zidx = nt.prop['zIdx']
            isFwd = nt.prop['isFwd']
            if isFwd:
                fwdDict[hid][idx] = nt
            else:
                revDict[hid][idx] = nt

            ## Find crossovers and ends
            for nt2 in (nt.ntAt5prime, nt.ntAt3prime):
                if nt2 is not None:
                    hid2 = nt2.prop['helixId']
                    if hid2 != hid:
                        # xo[hid][hid2].append(zidx)
                        xo[hid][hid2].append(idx)
                        xo[hid2][hid].append(idx)

    props = self.part.getModelProperties()

    if props.get('point_type') == PointType.ARBITRARY:
        raise Exception("Not implemented")
    else:
        props, origins = self.part.helixPropertiesAndOrigins()
    neighborHelices = atomicModel._getNeighborHelixDict(self.part)

    if self.latticeType == 'honeycomb':
        # minDist = 8 # TODO: test against server
        minDist = 11 # Matches ENRG MD server... should it?
    elif self.latticeType == 'square':
        minDist = 11

    for hid1 in helixIds:
        for hid2 in neighborHelices[hid1]:
            if hid2 not in helixIds:
                # print("WARNING: writeENM: helix %d not in helixIds" % hid2) # xo[%d] dict" % (hid2,hid1))
                continue

            ## Build chunk for helix pair
            chunks = self._getDsChunksForHelixPair(hid1,hid2)

            for lo,hi in chunks:
                # pdb.set_trace()
                nts1 = self.getNtsBetweenCadnanoIdxInclusive(hid1,lo,hi)
                nts2 = self.getNtsBetweenCadnanoIdxInclusive(hid2,lo,hi)

                ## Pair up nts of the two helices; the shorter list sets
                ## the number of pairs and indexes evenly into the longer
                if len(nts1) <= len(nts2):
                    iVals = list(range(len(nts1)))
                    if len(nts1) > 1:
                        jVals = [int(round(j*(len(nts2)-1)/(len(nts1)-1))) for j in range(len(nts1))]
                    else:
                        jVals = [0]
                else:
                    if len(nts2) > 1:
                        iVals = [int(round(i*(len(nts1)-1)/(len(nts2)-1))) for i in range(len(nts2))]
                    else:
                        iVals = [0]
                    jVals = list(range(len(nts2)))

                ntPairs = [[nts1[i],nts2[j]] for i,j in zip(iVals,jVals)]

                ## Skip pairs close to nearest crossover on lo side
                xoIdx = [idx for idx in xo[hid1][hid2] if idx <= lo]
                if len(xoIdx) > 0:
                    xoIdx = max(xoIdx)
                    assert( type(lo) is int )
                    xoDist = min([len(self.getNtsBetweenCadnanoIdxInclusive(hid,xoIdx,lo-1)) for hid in (hid1,hid2)])
                    skip=minDist-xoDist
                    if skip > 0:
                        ntPairs = ntPairs[skip:]

                ## Skip pairs close to nearest crossover on hi side
                xoIdx = [idx for idx in xo[hid1][hid2] if idx >= hi]
                if len(xoIdx) > 0:
                    xoIdx = min(xoIdx)
                    xoDist = min([len(self.getNtsBetweenCadnanoIdxInclusive(hid,hi+1,xoIdx)) for hid in (hid1,hid2)])
                    skip=minDist-xoDist
                    if skip > 0:
                        ntPairs = ntPairs[:-skip]

                for ntPair in ntPairs:
                    bps = [nt.basepair for nt in ntPair]
                    if None in bps: continue
                    ## one candidate bond for the pair and one for its partners
                    for nt1,nt2 in [ntPair,bps]:
                        i,j = [nt.firstAtomicIndex for nt in (nt1,nt2)]
                        if j <= i: continue

                        if np.linalg.norm(nt1.position-nt2.position) > 45: continue

                        ai,aj = [nt.atoms(transform=False) for nt in (nt1,nt2)]
                        ## NOTE(review): the bare except silently drops the
                        ## bond on any error; presumably meant for a missing
                        ## 'P' atom -- confirm
                        try:
                            i += ai.index('P')-1
                            j += aj.index('P')-1
                            pushBonds.append(i)
                            pushBonds.append(j)
                        except:
                            pass

    ## pushBonds is a flat list: consecutive entries form one bond
    print("PUSH BONDS:", len(pushBonds)/2)
    if not self.useTclForces:
        with open("%s.exb" % prefix, 'a') as fh:
            for i,j in zip(pushBonds[::2],pushBonds[1::2]):
                fh.write("bond %d %d %f %.2f\n" % (i,j,1.0,31))
    else:
        atomList = list(set(pushBonds))
        with open("%s.forces.tcl" % prefix,'w') as fh:
            fh.write("set atomList {%s}\n\n" %
                     " ".join([str(x-1) for x in atomList]) )
            fh.write("set bonds {%s}\n" %
                     " ".join([str(x-1) for x in pushBonds]) )
            fh.write("""
foreach atom $atomList {
addatom $atom
}
proc calcforces {} {
global atomList bonds
loadcoords rv
foreach i $atomList {
set force($i) {0 0 0}
}
foreach {i j} $bonds {
set rvec [vecsub $rv($j) $rv($i)]
# lassign $rvec x y z
# set r [expr {sqrt($x*$x+$y*$y+$z*$z)}]
set r [getbond $rv($j) $rv($i)]
set f [expr {2*($r-31.0)/$r}]
vecadd $force($i) [vecscale $f $rvec]
vecadd $force($j) [vecscale [expr {-1.0*$f}] $rvec]
}
foreach i $atomList {
addforce $i $force($i)
}
}
""")
def writeNamdFile(self,prefix,numSteps=4800000):
    """Write the NAMD configuration file "<prefix>.namd".

    The file consists of fixed template sections plus three variable parts:
    the "set prefix" line, the closing "run <numSteps>" line, and a choice
    driven by self.useTclForces. When that flag is false a "margin 30" line
    is emitted; when it is true a tclForces section pointing at
    $prefix.forces.tcl is emitted instead.

    prefix   -- used for the output file name and, verbatim, as the value of
                the Tcl variable "prefix" inside the file.
    numSteps -- step count for the final "run" command (formatted with %d).
    """
    def sections():
        ## Pieces are produced lazily, in file order, so each one is only
        ## built immediately before it is written.
        yield """#############################################################
## CONFIGURATION FILE FOR ORIGAMI STRUCTURE PREDICTION ##
#############################################################
#############################################################
## ADJUSTABLE PARAMETERS ##
#############################################################\n"""
        yield "set prefix %s\n" % prefix
        yield """set nLast 0; # increment when continueing a simulation
set n [expr $nLast+1]
set out output/$prefix-$n
set temperature 300
structure $prefix.psf
coordinates $prefix.pdb
outputName $out
XSTfile $out.xst
DCDfile $out.dcd
#############################################################
## SIMULATION PARAMETERS ##
#############################################################
# Input
paraTypeCharmm on
parameters charmm36.nbfix/par_all36_na.prm
parameters charmm36.nbfix/par_water_ions_na.prm
wrapAll off
# Force-Field Parameters
exclude scaled1-4
1-4scaling 1.0
switching on
switchdist 8
cutoff 10
pairlistdist 12
"""
        ## The margin line is only written when tclForces are not in use
        if not self.useTclForces:
            yield "margin 30\n"
        yield """
# Integrator Parameters
timestep 2.0 ;# 2fs/step
rigidBonds all ;# needed for 2fs steps
nonbondedFreq 1
fullElectFrequency 3
stepspercycle 12
# PME (for full-system periodic electrostatics)
PME no
PMEGridSpacing 1.2
# Constant Temperature Control
langevin on ;# do langevin dynamics
# langevinDamping 1 ;# damping coefficient (gamma); used in original study
langevinDamping 0.1 ;# less friction for faster relaxation
langevinTemp $temperature
langevinHydrogen off ;# don't couple langevin bath to hydrogens
# output
useGroupPressure yes
xstFreq 4800
outputEnergies 4800
dcdfreq 4800
restartfreq 48000
#############################################################
## EXTRA FORCES ##
#############################################################
# ENM and intrahelical extrabonds
extraBonds on
extraBondsFile $prefix.exb
"""
        ## Optional section that loads the Tcl forces script
        if self.useTclForces:
            yield """
tclForces on
tclForcesScript $prefix.forces.tcl
"""
        yield """
#############################################################
## RUN ##
#############################################################
# Continuing a job from the restart files
cellBasisVector1 1000 0 0
cellBasisVector2 0 1000 0
cellBasisVector3 0 0 1000
if {$nLast == 0} {
temperature 300
fixedAtoms on
fixedAtomsForces on
fixedAtomsFile $prefix.pdb
fixedAtomsCol B
minimize 2400
fixedAtoms off
minimize 2400
} else {
bincoordinates output/$prefix-$nLast.restart.coor
binvelocities output/$prefix-$nLast.restart.vel
}
"""
        yield "run %d\n" % numSteps

    configName = "%s.namd" % prefix
    with open(configName,'w') as fh:
        for piece in sections():
            fh.write(piece)
def simulate(self, outputPrefix, outputDirectory='output', numSteps=4800000, numprocs=8, gpus=0, namd=None):
    """Write the NAMD input files for this model and run NAMD on them.

    The input files are produced by self.writeNamdFiles(outputPrefix,
    numSteps=numSteps). NAMD is then invoked as

        <namd> +ppn <numprocs> +netpoll +idlepoll +devices <gpus> <outputPrefix>.namd

    with CUDA_DEVICE_ORDER=PCI_BUS_ID added to the environment, and with
    stdout and stderr redirected to "<outputDirectory>/<outputPrefix>.log".
    The call blocks until NAMD exits; its exit status is not inspected and
    nothing is returned.

    outputPrefix    -- prefix of the files written and of the .namd file run.
    outputDirectory -- directory that receives the log file; created if it
                       does not exist. '' is treated as '.'.
    numSteps        -- int, forwarded to self.writeNamdFiles.
    numprocs        -- int, value passed after "+ppn".
    gpus            -- an int or an iterable of ints; passed comma-joined
                       after "+devices".
    namd            -- path to the NAMD executable. When None, the first
                       executable file named "namd2" found on PATH is used.

    Raises AssertionError if gpus, numSteps or numprocs have the wrong type,
    and Exception if the NAMD executable cannot be found or is not
    executable, or if outputDirectory exists but is not a directory.

    NOTE(review): only the log file location is derived from outputDirectory
    here; whether the generated NAMD configuration honours it depends on
    self.writeNamdFiles -- confirm.
    """
    if type(gpus) is int:
        gpus = [gpus]
    for gpu in gpus:
        assert(type(gpu) is int)
    assert(type(numSteps) is int)
    assert(type(numprocs) is int)
    if outputDirectory == '': outputDirectory='.'

    if namd is None:
        ## Look for an executable "namd2" in each PATH entry. The result must
        ## be stored in `namd`; otherwise the checks below would be handed None.
        for path in os.environ.get("PATH", "").split(os.pathsep):
            candidate = os.path.join(path.strip('"'), "namd2")
            if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                namd = candidate
                break

    ## Covers "nothing on PATH", a path that does not exist, and a path that
    ## exists but is not a regular file.
    if namd is None or not os.path.isfile(namd):
        raise Exception("NAMD was not found")
    if not os.access(namd, os.X_OK):
        raise Exception("NAMD is not executable")

    if not os.path.exists(outputDirectory):
        os.makedirs(outputDirectory)
    elif not os.path.isdir(outputDirectory):
        raise Exception("outputDirectory '%s' is not a directory!" % outputDirectory)

    self.writeNamdFiles( outputPrefix, numSteps=numSteps )

    ## Shell equivalent: $namd +ppn $procs +netpoll +idlepoll +devices $gpus $config &> $log
    gpus = ','.join([str(gpu) for gpu in gpus])
    env = os.environ.copy()
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    cmd = (namd,
           "+ppn", numprocs,
           '+netpoll','+idlepoll',
           '+devices', gpus,
           "%s.namd" % outputPrefix)
    cmd = tuple(str(x) for x in cmd)

    logfile = "%s/%s.log" % (outputDirectory, outputPrefix)
    print("Running NAMD with: %s >& %s" % (" ".join(cmd), logfile) )
    with open(logfile,'w') as fh:
        subprocess.call(cmd, env=env, stdout=fh, stderr=subprocess.STDOUT)