Skip to content
Snippets Groups Projects
atomicModel.py 54.84 KiB
# -*- coding: utf-8 -*-
from datetime import datetime
from cadnano.cnenum import PointType
from cadnano.strand import Strand
from math import pi,sqrt,exp,floor
import numpy as np
import random
import os, sys, subprocess
from pdb import set_trace

from coords import minimizeRmsd, rotationAboutAxis
from CanonicalNucleotideAtoms import canonicalNtFwd, canonicalNtRev, seqComplement
from CanonicalNucleotideAtoms import enmTemplateHC, enmTemplateSQ, enmCorrectionsHC

class abstractNucleotide():
    """A single nucleotide with a position, an orientation and strand links.

    Any extra keyword arguments are kept in ``self.prop``; ``atoms()`` reads
    ``self.prop['isFwd']`` to choose between the forward and reverse
    canonical templates.
    """

    def __init__(self, sequence, position, angle, ntAt5prime=None, **prop):
        self.sequence = sequence
        self.prop = prop
        self.position = np.array(position)
        # Orientation is a rotation about the z axis by `angle`
        # (units as expected by rotationAboutAxis).
        self.orientation = rotationAboutAxis([0, 0, 1], angle)

        # Backbone neighbours along the strand
        self.ntAt5prime = ntAt5prime
        self.ntAt3prime = None

        # Pairing and stacking partners, assigned later by the model builder
        self.basepair = None
        self.stack5prime = None
        self.stack3prime = None

        # Index of this nucleotide's first atom; -1 until assigned
        self.firstAtomicIndex = -1

    def set3primeNt(self, n):
        """Record ``n`` as the neighbour on the 3' side."""
        self.ntAt3prime = n

    def atoms(self, transform=True, scale=1.0):
        """Return this nucleotide's atoms from the canonical templates.

        The template key is the sequence letter, prefixed with "5" for a
        5'-terminal nucleotide, suffixed with "3" for a 3'-terminal one and
        with "singlet" when it has no neighbours at all. When ``transform``
        is true the atoms are placed using this nucleotide's orientation,
        position and ``scale``.
        """
        has5 = self.ntAt5prime is not None
        has3 = self.ntAt3prime is not None
        if has5 and has3:
            key = self.sequence
        elif has3:
            key = "5" + self.sequence
        elif has5:
            key = self.sequence + "3"
        else:
            key = self.sequence + "singlet"

        templates = canonicalNtFwd if self.prop['isFwd'] else canonicalNtRev
        atoms = templates[key]

        if not transform:
            return atoms
        # An unpaired nucleotide is flagged as single-stranded for the transform
        return atoms.transformed(self.orientation, self.position, scale,
                                 self.basepair is None)
            

class RigidBodyNode():
    """A node (bead) on a helix, linked to neighbours above/below and to
    crossover partners.

    ``idx`` is drawn from, and increments, ``helix.model.numNodes``.
    """
    def __init__(self, helix, pos, angle, type="dsDNA"):
        # NOTE: `angle` is accepted but currently unused.
        self.helix    = helix
        self.position = pos
        self.type     = type
        self.nodeAbove = None
        self.nodeBelow = None
        # Separations (bp) to the neighbouring nodes; None until connected
        self.nodeAboveSep = None
        self.nodeBelowSep = None
        self.xovers = []
        self.idx = helix.model.numNodes
        helix.model.numNodes = helix.model.numNodes + 1

    def addNodeAbove(self, node, separation):
        """Link ``node`` above this one, ``separation`` bp away."""
        self.nodeAbove = node
        self.nodeAboveSep = separation # bp

    def addNodeBelow(self, node, separation):
        """Link ``node`` below this one, ``separation`` bp away."""
        self.nodeBelow = node
        self.nodeBelowSep = separation # bp

    def addXover(self, node, polarity, double=False):
        """Record a crossover to ``node`` as a ``(node, polarity, double)`` tuple."""
        ## TODO: what is meant by polarity?
        self.xovers.append( (node,polarity,double) )

    def getNodesAbove(self,numNodes,inclusive=False):
        """Walk up to ``numNodes`` steps upward.

        Returns ``(nodeList, sepList)``, where ``sepList[k]`` is the bp
        separation crossed on step k. ``self`` is prepended to ``nodeList``
        when ``inclusive`` is true.
        """
        return self._walkNodes(numNodes, inclusive, goUp=True)

    def getNodesBelow(self,numNodes,inclusive=False):
        """Walk up to ``numNodes`` steps downward; see ``getNodesAbove``."""
        return self._walkNodes(numNodes, inclusive, goUp=False)

    def _walkNodes(self, numNodes, inclusive, goUp):
        """Shared walker for getNodesAbove/getNodesBelow."""
        assert( type(numNodes) is int and numNodes > 0 )

        nodeList = [self] if inclusive else []
        sepList = []
        n = self
        for _ in range(numNodes):
            nxt = n.nodeAbove if goUp else n.nodeBelow
            if nxt is None:
                break
            # The separation between n and nxt is stored on nxt's side facing n
            sepList.append(nxt.nodeBelowSep if goUp else nxt.nodeAboveSep)
            n = nxt
            nodeList.append(n)
        return nodeList,sepList

class Segment():
    """An ordered chain (or chains) of nucleotides belonging to one oligo.

    Nucleotides are stored as ``nodes[helixId][strandIdx] -> list`` and are
    iterated 5' -> 3' by following ``ntAt3prime`` links.
    """
    def __init__(self, segname, ntScale=1.0):
        self.ntScale = ntScale
        self.props = dict()
        self.props['segname'] = segname
        self.nodes = dict()
        self.firstNode = None

    def addNucleotide(self, helixId, strandIdx, zIdx,
                      sequence, isFwd, position, angle, ntAt5prime=None):
        """Create a nucleotide, link it after ``ntAt5prime`` and store it.

        Raises:
            Exception: if a nucleotide with the same zIdx and direction is
                already stored at ``(helixId, strandIdx)``.
        """
        ntsAtIdx = self.nodes.setdefault(helixId, dict()).setdefault(strandIdx, [])
        for n in ntsAtIdx:
            if n.prop['zIdx'] == zIdx and n.prop['isFwd'] == isFwd:
                raise Exception("Attempted to add a node in the same location (%d:%.1f:%s) twice!" % (helixId,zIdx,isFwd))

        n = abstractNucleotide(sequence, position, angle, ntAt5prime,
                               isFwd=isFwd, helixId=helixId, zIdx=zIdx, idx=strandIdx)

        if ntAt5prime is not None:
            ntAt5prime.set3primeNt(n)

        ntsAtIdx.append( n )
        return n

    def __len__(self):
        return sum(len(nts)
                   for ntsAtIdx in self.nodes.values()
                   for nts in ntsAtIdx.values())

    def __iter__(self):
        """Yield nucleotides chain by chain, 5' -> 3'.

        A circular chain is yielded once, stopping when it wraps back to its
        starting nucleotide.
        """
        for start in self._get5primeNts():
            n = start
            yield n
            while n.ntAt3prime is not None and n.ntAt3prime is not start:
                n = n.ntAt3prime
                yield n

    def atoms(self, transform=True):
        """Yield the atoms of each nucleotide, scaled by ``self.ntScale``."""
        for nt in self:
            yield nt.atoms(transform, self.ntScale)

    def sequence(self):
        """Yield each nucleotide's sequence letter in iteration order."""
        for nt in self:
            yield nt.sequence

    def _anyNucleotide(self):
        """Return the first stored nucleotide, or None if the segment is empty."""
        for ntsAtIdx in self.nodes.values():
            for nts in ntsAtIdx.values():
                if nts:
                    return nts[0]
        return None

    def _get5primeNts(self):
        """Return the nucleotides that have no 5' neighbour.

        For a purely circular segment (no 5' ends) a single arbitrary
        nucleotide is returned as the starting point; an empty segment
        returns an empty list.
        """
        firstNodes = [nt
                      for ntsAtIdx in self.nodes.values()
                      for nts in ntsAtIdx.values()
                      for nt in nts
                      if nt.ntAt5prime is None]

        if len(firstNodes) == 0:
            anyNt = self._anyNucleotide()
            if anyNt is None:
                return []
            print("WARNING: found circular segment; untested", file=sys.stderr)
            firstNodes = [anyNt]
        return firstNodes

    ## TODO deprecate
    def _findFirstNode(self):
        """Return the unique 5'-terminal nucleotide (None for an empty segment).

        Raises:
            Exception: if the segment has more than one 5' end.
        """
        firstNodes = self._get5primeNts()
        if len(firstNodes) > 1:
            raise Exception("Segment object contains two 5' ends!")
        return firstNodes[0] if firstNodes else None
    
class atomicModel():

    def __init__(self, part, ntScale=1.0, randomSeed=1):
        """Build the nucleotide model for a cadnano ``part``.

        Args:
            part: cadnano part providing helices, strands and oligos.
            ntScale: scale passed to every Segment for atom placement.
            randomSeed: seed applied to the global ``random`` module before
                the model is built.
        """
        self.part = part
        self.ntScale = ntScale

        # Containers filled by _buildModel
        self.numNodes = 0
        self.segments = []
        self.dsChunks = {}
        self.helixNtsFwd = {}
        self.helixNtsRev = {}
        self.nodeTypeCounts = None

        # Topology / parameter bookkeeping
        self.bonds, self.angles, self.dihedrals = set(), set(), set()
        self._nbParamFiles = []

        random.seed(randomSeed)
        self.useTclForces = False

        # The lattice type must be known before _buildModel runs (it is read there)
        self.latticeType = atomicModel._determineLatticeType(part)
        self._buildModel(part)
        

    def _strandOccupancies(self, helixId):
        try:
            ends1,ends2 = self._getStrandEnds(helixId)
        except:
            ## hid not in strand_list
            return []

        strandOccupancies = [ [x for i in range(0,len(e),2)
                               for x in range(e[i],e[i+1]+1)]
                              for e in (ends1,ends2) ]
        return sorted(list(set( strandOccupancies[0] + strandOccupancies[1] )))

    def _getStrandEnds(self, helixId):

        """Utility method to convert cadnano strand lists into list of
        indices of terminal points"""

        helixStrands = self.strand_list[helixId]

        endLists = [[],[]]
        for endList, strandList in zip(endLists,helixStrands):
            lastStrand = None
            for s in strandList:
                if lastStrand is None:
                    ## first strand
                    endList.append(s[0])
                elif lastStrand[1] != s[0]-1: 
                    assert( s[0] > lastStrand[1] )
                    endList.extend( [lastStrand[1], s[0]] )
                lastStrand = s
            if lastStrand is not None:
                endList.append(lastStrand[1])
        return endLists

    # -------------------------- #
    # Methods for building model #
    # -------------------------- #
    def _buildModel(self, part):
        """Create nucleotide segments from the oligos of a cadnano ``part``.

        Strategy:
        1) Create all oligos/segments, describing 5'/3' bonding & location
        2) Add basepairing information
        3) For all basepairs, find 5'/3' stacks

        Side effects: sets ``self.strand_list`` (per helix id a
        ``(fwd_idxs, rev_idxs)`` tuple, or ``None`` for unused ids). Fills
        ``self.helixNtsFwd`` / ``self.helixNtsRev`` (helix id -> cadnano index
        -> list of nucleotides), ``self.dsChunks`` and ``self.segments``.

        Raises:
            Exception: for parts with ARBITRARY point configurations.
        """
        props = part.getModelProperties().copy()
        if props.get('point_type') == PointType.ARBITRARY:
            # TODO add code to encode Parts with ARBITRARY point configurations
            raise Exception("Not implemented")
        vh_props, origins = part.helixPropertiesAndOrigins()

        ## Loop over virtual helices and build lists of strands
        self.strand_list = []
        xover_list = []

        numHID = part.getIdNumMax() + 1
        for id_num in range(numHID):
            if part.getOffsetAndSize(id_num) is None:
                ## placeholder keeps strand_list indexable by helix id
                self.strand_list.append(None)
                continue
            fwd_ss, rev_ss = part.getStrandSets(id_num)
            fwd_idxs, fwd_colors = fwd_ss.dump(xover_list)
            rev_idxs, rev_colors = rev_ss.dump(xover_list)
            self.strand_list.append((fwd_idxs, rev_idxs))

        for hid in range(numHID):
            self.helixNtsFwd[hid] = dict()
            self.helixNtsRev[hid] = dict()
            self.dsChunks[hid] = []

        ## 1) Loop over oligos/segments (descending by keyFn, i.e. longest first)
        segments = []
        segid = 1
        keyFn = lambda o: (o.length(), o._strand5p.idNum(), o._strand5p.idx5Prime())
        for oligo in reversed(sorted(part.oligos(), key=keyFn)):
            seg = Segment("D%03d" % segid, self.ntScale)
            segments.append(seg)
            segid += 1

            lastNt = None

            ## loop over strands in oligo, 5' -> 3'
            for strand in oligo.strand5p().generator3pStrand():
                hid = strand.idNum() # virtual helix ID
                x,y = origins[hid]

                keys = ['bases_per_repeat',
                        'turns_per_repeat',
                        'eulerZ','z']
                bpr,tpr,eulerZ,z = [vh_props[k][hid] for k in keys]

                twist_per_base = tpr*360./bpr
                ## override twist_per_base if square lattice
                if self.latticeType == "square":
                    twist_per_base = 360.0*3/32
                else:
                    assert( np.isclose(twist_per_base, 360.0/10.5) )

                isFwd = strand.isForward()

                lo,hi = strand.idxs()
                seqs = Strand.sequence(strand, for_export=True)
                numNt = strand.totalLength()

                ## We place nts as pairs; spacing is not always uniform along strand.
                ## Here we find the locations between which spacing will be uniform.
                compIdxs = [comp.idxs() for comp in strand.getComplementStrands()]
                chunks = self.combineRegionLists(
                    [[lo,hi]],
                    compIdxs,
                    intersect=True
                )
                self.dsChunks[hid].extend( chunks ) # Add running list of chunks for push bonds

                chunks = self.combineRegionLists(
                    [[lo,hi]],
                    chunks,
                    intersect=False
                )

                ## Find zIdx for each nucleotide; insertions spread extra nts
                ## uniformly across a chunk
                zIdxs = []
                for l,h in chunks:
                    nts = h-l+1
                    for ins in strand.insertionsOnStrand(l,h):
                        nts += ins.length()

                    if nts <= 0:
                        print("WARNING: there is something strange in the structure at helix %d (%d,%d)" % (hid,l,h),
                              file=sys.stderr)
                        continue
                    if nts == 1:
                        zIdxs.append( l )
                    else:
                        zIdxs.extend( l + (h-l) * float(i)/(nts-1) for i in range(nts) )

                assert( len(zIdxs) == numNt )

                if isFwd:
                    helixNts = self.helixNtsFwd[hid]
                else:
                    helixNts = self.helixNtsRev[hid]
                    zIdxs = zIdxs[::-1]

                ## Find strandOccupancy (cadnano) index for each nucleotide;
                ## an insertion places extra nts at the same cadnano index
                strandOccIdx = []
                for l,h in chunks:
                    for i in range(l,h+1):
                        nts = 1
                        for ins in strand.insertionsOnStrand(i,i):
                            nts += ins.length()
                        helixNts[i] = []
                        strandOccIdx.extend( [i]*nts )
                if not isFwd:
                    strandOccIdx = strandOccIdx[::-1]

                ## Add nucleotides, each linked after the previous one (its 5' neighbour)
                for s,zIdx,idx in zip(seqs,zIdxs,strandOccIdx):
                    p = (x*10, y*10, z-3.4*zIdx)
                    a = zIdx*twist_per_base + eulerZ
                    lastNt = seg.addNucleotide(hid, idx, zIdx, s, isFwd,
                                               p, a, lastNt)
                    helixNts[idx].append(lastNt)

        ## Correct the order of nts in helixNtsRev so they pair index-by-index
        ## with helixNtsFwd below. Store real lists, not one-shot iterators:
        ## they are read again here and by getNtsBetweenCadnanoIdxInclusive.
        for ntsAtIdx in self.helixNtsRev.values():
            for i in ntsAtIdx:
                ntsAtIdx[i] = list(reversed(ntsAtIdx[i]))

        ## 2) Find basepairs and stacks
        for hid in range(numHID):
            prevBasepair = None
            for i in self._strandOccupancies(hid):
                ## Check that strand index is dsDNA
                if not (i in self.helixNtsFwd[hid] and i in self.helixNtsRev[hid]):
                    continue

                for nt1,nt2 in zip(self.helixNtsFwd[hid][i],
                                   self.helixNtsRev[hid][i]):
                    self._pairBases(nt1,nt2)

                    if prevBasepair is not None and i - prevBasepair[0] <= 3: # TODO: find better way of locating stacks
                        atomicModel._stackBases( prevBasepair[1], nt1)
                        atomicModel._stackBases( nt2, prevBasepair[2])
                    prevBasepair = (i,nt1,nt2)

        self.segments.extend(segments)
        return

    def _connectNodes(self, below, above, sep):
        below.addNodeAbove(above, sep)
        above.addNodeBelow(below, sep)
        
    def _getNeighborHelixDict(part):
        props, origins = part.helixPropertiesAndOrigins()
        neighborHelices = dict()
        for i in range(len(origins)):
            neighborHelices[i] = [int(j.strip('[,]')) for j in props['neighbors'][i].split() if j != '[]']
        return neighborHelices

    def _determineLatticeType(part):
        props, origins = part.helixPropertiesAndOrigins()
        origins = [np.array(o) for o in origins]
        neighborVecs = []
        for i,js in atomicModel._getNeighborHelixDict(part).items():
            for j in js:
                neighborVecs.append( origins[j]-origins[i] )

        dots = [nv2.dot(nv1)/(np.linalg.norm(nv1)*np.linalg.norm(nv2)) for nv1 in neighborVecs for nv2 in neighborVecs]

        nDots = len(dots)
        nAngled = np.sum( np.abs(np.abs(dots)-0.5) < 0.2 )
        if float(nAngled)/float(nDots) == 0:
            return 'square'
        elif float(nAngled)/float(nDots) < 0.05:
            print('WARNING: unusual neighbors in square lattice structure')
            return 'square'
        elif float(nAngled)/float(nDots) > 0.4:
            return 'honeycomb'
        else:
            raise Exception("Could not identify lattice type")
           
    def backmap(self, simplerModel, simplerModelCoords,
                dsDnaHelixNeighborDist=50, dsDnaAllNeighborDist=30,
                ssDnaHelixNeighborDist=20, ssDnaAllNeighborDist=15):
        """Move every nucleotide to follow the beads of a coarser model.

        Each nucleotide is assigned to a bead of ``simplerModel`` on the same
        helix: the first bead whose zIdx is >= the nucleotide's zIdx, or the
        last bead. For every bead used, a rotation ``R`` is estimated from the
        bead's initial and new (``simplerModelCoords``) configuration. Each
        nucleotide's position and orientation are then transformed with the
        ``R`` of its bead.

        Args:
            simplerModel: coarse model exposing ``helices`` (helix id ->
                iterable of ``(zIdx, bead)`` with a ``nodes`` mapping keyed by
                zIdx) and ``particles``.
            simplerModelCoords: new bead coordinates, indexed by bead ``idx``.
            dsDnaHelixNeighborDist, dsDnaAllNeighborDist: neighbourhood
                cutoffs passed to ``_getNeighborhoodIds`` for dsDNA beads.
            ssDnaHelixNeighborDist, ssDnaAllNeighborDist: same, for ssDNA beads.

        Raises:
            Exception: if a nucleotide's helix is missing from ``simplerModel``.
        """
        ## Assign each nucleotide to a bead in simplerModel
        mapToSimplerModel = dict()
        sortedZIdxs = dict()  # helix id -> sorted array of bead zIdx, computed once per helix

        nts = [nt for seg in self.segments for nt in seg]
        for i, nt in enumerate(nts):
            hid = nt.prop['helixId']
            if hid not in simplerModel.helices:
                raise Exception("Could not find helix %d in simple model; using just a single cadnano structure" % hid)
            cgH = simplerModel.helices[hid]
            if hid not in sortedZIdxs:
                sortedZIdxs[hid] = np.array( sorted([z for z,b in cgH]) )
            zIdxs = sortedZIdxs[hid]

            ## First bead at or above the nucleotide's zIdx, clamped to the last bead
            zi = nt.prop['zIdx']
            pos = np.searchsorted(zIdxs, zi, side='left')
            cgZ = zIdxs[pos] if pos < len(zIdxs) else zIdxs[-1]
            mapToSimplerModel[i] = [cgH.nodes[cgZ]]

        ## Find transformation for each bead of simplerModel used in mapping
        trans = dict()
        for b in list(set([b for i,bs in mapToSimplerModel.items() for b in bs])):
            helixCutoff = dsDnaHelixNeighborDist if b.type[0] == 'd' else ssDnaHelixNeighborDist
            allCutoff = dsDnaAllNeighborDist if b.type[0] == 'd' else ssDnaAllNeighborDist

            r0 = b.initialPosition
            r1 = simplerModelCoords[b.idx]

            ## Neighbours along the helix, if close enough; otherwise the bead itself
            above = b.nodeAbove if b.nodeAbove is not None and \
                    np.sum((b.nodeAbove.initialPosition-b.initialPosition)**2) < 10**2 else b
            below = b.nodeBelow if b.nodeBelow is not None and \
                    np.sum((b.nodeBelow.initialPosition-b.initialPosition)**2) < 10**2 else b

            if type(simplerModel).__name__ == "beadModelTwist" and \
                    b.orientationNode is not None:
                ## Build orthonormal frames from the helix axis and orientation bead
                o = b.orientationNode
                assert(above != below)

                z0 = above.initialPosition - below.initialPosition
                z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]
                z0,z1 = [z/np.linalg.norm(z) for z in (z0,z1)]

                x0 = o.initialPosition - r0
                x1 = simplerModelCoords[o.idx] - r1
                x0,x1 = [x-x.dot(z)*z for x,z in zip((x0,x1),(z0,z1))] # make x orthogonal to z
                x0,x1 = [x/np.linalg.norm(x) for x in (x0,x1)]

                y0,y1 = [np.cross(z,x) for x,z in zip((x0,x1),(z0,z1))]
                R,_,_ = minimizeRmsd( (x0,y0,z0), (x1,y1,z1) )

            else:
                ## Best-fit rotation of the bead's neighbourhood
                ids = simplerModel._getNeighborhoodIds(b, simplerModelCoords, helixCutoff, allCutoff)
                posOld = np.array( [simplerModel.particles[i][0].initialPosition for i in ids] )
                posNew = np.array( [simplerModelCoords[i] for i in ids] )
                R,cOld,cNew = minimizeRmsd( posOld, posNew )

                if b.type[0] == 's':
                    ## Additionally align the local strand direction
                    assert(above != below)
                    z0 = above.initialPosition - below.initialPosition
                    z1 = simplerModelCoords[above.idx] - simplerModelCoords[below.idx]

                    zR = z0.dot(R)
                    cross = np.cross(z1,zR) / (np.linalg.norm(z1)*np.linalg.norm(zR))
                    # Clip: rounding can push the norm slightly above 1 (arcsin -> nan)
                    angle = np.arcsin(np.clip(np.linalg.norm(cross), 0.0, 1.0))*180/np.pi
                    R2 = rotationAboutAxis(cross,angle)

                    R = R.dot(R2)

            cOld,cNew = (r0,r1)
            trans[b.idx] = (R,cOld,cNew)

        ## Optionally smooth orientations

        ## Apply transformation to each nucleotide
        for i, b in enumerate(nts):
            cgb, = mapToSimplerModel[i]
            cgi = cgb.idx
            r0 = simplerModel.particles[cgi][0].initialPosition
            R,c0,c1 = trans[cgi]
            b.position = (b.position - r0).dot(R) + simplerModelCoords[cgi]
            b.orientation = b.orientation.dot( R )

        ## Sanity check: basepaired nucleotides must end up with equal orientations
        for b in nts:
            if b.basepair is not None:
                assert( np.all( b.orientation == b.basepair.orientation ) )
                
    def _pairBases(self, n1,n2):
        assert(n1.basepair is None and n2.basepair is None)
        n1.basepair = n2
        n2.basepair = n1

    def _stackBases(below, above):
        assert(below.stack3prime is None and above.stack5prime is None)
        below.stack3prime = above
        above.stack5prime = below        

    def _removeIntrahelicalConnectionsBeyond(self, cutoff):
        ## lazy progarmming: leave nts in one segment after bond break
        nts = [nt for seg in self.segments for nt in seg]
        for n1 in nts:
            n2 = n1.ntAt3prime
            if n2 is not None:
                r2 = np.sum( (n1.position - n2.position)**2 )
                if r2 > cutoff**2:
                    if n1.ntAt3prime == n2:
                        assert(n2.ntAt5prime == n1)
                        n1.ntAt3prime = None
                        n2.ntAt5prime = None
                    elif n2.ntAt3prime == n1:
                        assert(n1.ntAt5prime == n2)
                        n1.ntAt5prime = None
                        n2.ntAt3prime = None
                    else:
                        raise

    def scaffoldSeq(seqFile):
        return

    def randomizeUnassignedNtSeqs(self):
        """Give every '?' nucleotide a random base drawn from seqComplement's keys.

        Its basepair partner, if any, is set to the complement. Uses the
        global ``random`` module, so results follow the seed set in __init__.
        """
        allNts = (nt for seg in self.segments for nt in seg)
        for nt in allNts:
            if nt.sequence != '?':
                continue
            nt.sequence = random.choice(tuple(seqComplement.keys()))
            partner = nt.basepair
            if partner is not None:
                partner.sequence = seqComplement[nt.sequence]

    def setUnassignedNtSeqs(self,base):
        assert(base in ("A","T","C","G"))
        for seg in self.segments:
            for nt in seg:
                if nt.sequence == '?':
                    nt.sequence = base
                    if nt.basepair is not None:
                        nt.basepair.sequence = seqComplement[nt.sequence]

    def setScaffoldSeq(self,sequence):
        ## get longest segment
        segs = sorted( [(len(s),s) for s in self.segments], key=lambda x: x[0] )
        seg = segs[-1][1]
        assert( len(seg) > len(segs[-2][1]) )
        if len(seg) > len(sequence):
            raise Exception("Sequence (%d) too short for scaffold (%d)!" % (len(sequence),len(seg)))

        for nt,seq in zip(seg,sequence):
            assert(seq in ("A","T","C","G"))
            nt.sequence = seq
            if nt.basepair is not None:
                nt.basepair.sequence = seqComplement[nt.sequence]        

    def _assignAtomicIndexToNucleotides(self):
        natoms=0
        for seg in self.segments:
            for nt in seg:
                nt.firstAtomicIndex = natoms
                natoms += len(nt.atoms(transform=False))

    def combineRegionLists(self,loHi1,loHi2,intersect=False):

        """Combine two lists of inclusive integer (lo,hi) regions into a
        single sorted list of regions.

        Each input is split into "compact" runs: consecutive regions where
        each lo equals the previous hi + 1. Overlapping runs are merged with
        combineCompactRegionLists. The result therefore partitions the union
        of both inputs, split at every boundary present in either input.

        Args:
            loHi1, loHi2: lists of (lo, hi) pairs with lo <= hi. They appear
                to be expected in ascending order — NOTE(review): not
                validated here.
            intersect: if true, keep only regions that lie inside
                [max of the two first lo's, min of the two last hi's]. This
                is an envelope filter, not a per-region intersection.

        Returns:
            list of regions. If either input is empty, the other input is
            returned unchanged (or [] when ``intersect``).
        """

        ## Validate input
        for l in (loHi1,loHi2):
            ## Assert each region in lists is sorted
            for pair in l:
                assert(len(pair) == 2)
                assert(pair[0] <= pair[1])

        if len(loHi1) == 0:
            if intersect:
                return []
            else:
                return loHi2
        if len(loHi2) == 0:
            if intersect:
                return []
            else:
                return loHi1

        ## Break input into lists of compact regions
        compactRegions1,compactRegions2 = [[],[]]
        for compactRegions,loHi in zip(
                [compactRegions1,compactRegions2],
                [loHi1,loHi2]):
            tmp = []
            lastHi = loHi[0][0]-1
            for lo,hi in loHi:
                ## A gap before this region closes the current compact run
                if lo-1 != lastHi:
                    compactRegions.append(tmp)
                    tmp = []
                tmp.append((lo,hi))
                lastHi = hi
            if len(tmp) > 0:
                compactRegions.append(tmp)

        ## Build result
        result = []
        region = []
        i,j = [0,0]
        ## Sentinels: a run starting at 1e10 never overlaps a real region, and
        ## the loop below stops once both indices reach their sentinel
        compactRegions1.append([[1e10]])
        compactRegions2.append([[1e10]])
        while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
            cr1 = compactRegions1[i]
            cr2 = compactRegions2[j]

            ## initialize region
            if len(region) == 0:
                if cr1[0][0] <= cr2[0][0]:
                    region = cr1
                    i += 1
                    continue
                else:
                    region = cr2
                    j += 1
                    continue

            ## Merge the next run of either list if it starts at or before the
            ## current region's end; otherwise emit the region and start anew
            if region[-1][-1] >= cr1[0][0]:
                region = self.combineCompactRegionLists(region, cr1, intersect=False)
                i+=1
            elif region[-1][-1] >= cr2[0][0]:
                region = self.combineCompactRegionLists(region, cr2, intersect=False)
                j+=1
            else:
                result.extend(region)
                region = []

        assert( len(region) > 0 )
        result.extend(region)
        result = sorted(result)

        # print("loHi1:",loHi1)
        # print("loHi2:",loHi2)
        # print(result,"\n")

        if intersect:
            lo = max( [loHi1[0][0], loHi2[0][0]] )
            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
            result = [r for r in result if r[0] >= lo and r[1] <= hi]

        return result

    def combineCompactRegionLists(self,loHi1,loHi2,intersect=False):

        """Combines two lists of (lo,hi) pairs specifying regions within a
        compact integer set into a single list of regions.

        examples:
        loHi1 = [[0,4],[5,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 4), (5, 7), (8, 9)]

        loHi1 = [[0,3],[5,7]]
        loHi2 = [[2,4],[5,9]]
        out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
        """

        ## Validate input
        for l in (loHi1,loHi2):
            ## Assert each region in lists is sorted
            for pair in l:
                assert(len(pair) == 2)
                assert(pair[0] <= pair[1])
            ## Assert lists are compact
            for pair1,pair2 in zip(l[::2],l[1::2]):
                assert(pair1[1]+1 == pair2[0])

        if len(loHi1) == 0:
            if intersect:
                return []
            else:
                return loHi2
        if len(loHi2) == 0:
            if intersect:
                return []
            else:
                return loHi1

        ## Find the ends of the region
        lo = min( [loHi1[0][0], loHi2[0][0]] )
        hi = max( [loHi1[-1][1], loHi2[-1][1]] )

        ## Make a list of indices where each region will be split
        splitAfter = []
        for l,h in loHi2:
            if l != lo:
                splitAfter.append(l-1)
            if h != hi:
                splitAfter.append(h)

        for l,h in loHi1:
            if l != lo:
                splitAfter.append(l-1)
            if h != hi:
                splitAfter.append(h)
        splitAfter = sorted(list(set(splitAfter)))

        # print("splitAfter:",splitAfter)

        split=[]
        last = -2
        for s in splitAfter:
            split.append(s)
            last = s

        # print("split:",split)
        returnList = [(i+1,j) if i != j else (i,j) for i,j in zip([lo-1]+split,split+[hi])]

        if intersect:
            lo = max( [loHi1[0][0], loHi2[0][0]] )
            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
            returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]

        # print("loHi1:",loHi1)
        # print("loHi2:",loHi2)
        # print(returnList,"\n")
        return returnList

    def _getDsChunksForHelixPair(self,hid1,hid2):
        chunks = []
        chunks1,chunks2 = [sorted(list(set(self.dsChunks[hid]))) for hid in [hid1,hid2]]
        return self.combineRegionLists(chunks1,chunks2)

        i,j = [0,0]
        lastHi = -100000
        # pdb.set_trace()
        while i < len(chunks1) or j < len(chunks2):
            if i >= len(chunks1):
                tmp
            tmp = [x for x in sorted(list(set(chunks1[i]+chunks2[j]))) if x > lastHi]

            # if len(tmp) > 1:
            #     chunks.append(tmp[:2])
            # if len(tmp) > 0:
            #     lastHi = tmp[min(1,len(tmp))]
            chunks.append(tmp[:2])
            lastHi = tmp[min(1,len(tmp))]

            if chunks1[i][1] <= lastHi: i+=1
            if chunks2[j][1] <= lastHi: j+=1
        return chunks

    def getNtsBetweenCadnanoIdxInclusive(self,hid,lo,hi,isFwd=True):
        if isFwd:
            helixNts = self.helixNtsFwd[hid]
        else:
            helixNts = self.helixNtsRev[hid]

        nts = []
        for i in range(lo,hi+1):
            if i in helixNts:
                nts.extend(helixNts[i])
        return nts

    # -------------------------- #
    # Methods for querying model #
    # -------------------------- #
    def _getIntrahelicalNodeSeries(self,seriesLen):
        nodeSeries = set()
        for hid,hlx in self.helices.items():
            for zid,n in hlx:
                nodeList,sepList = n.getNodesAbove(seriesLen-1, inclusive = True)
                
                # nodeList = [n]
                # sepList = []
                # for i in range(seriesLen-1):
                #     if n.nodeAbove is None: break
                #     n = n.nodeAbove
                #     nodeList.append(n)
                #     sepList.append(n.nodeBelowSep)
                
                if len(nodeList) == seriesLen:
                    nodeList = tuple(nodeList)
                    sepList = tuple(sepList)
                    nodeSeries.add( tuple((nodeList,sepList)) )
        return nodeSeries

    def _getIntrahelicalBonds(self):
        """Return all pairs of consecutive nodes within each helix.

        Thin wrapper over ``_getIntrahelicalNodeSeries`` with a run length of 2.
        """
        pairLength = 2
        return self._getIntrahelicalNodeSeries(pairLength)

    def _getIntrahelicalAngles(self):
        """Return all runs of three consecutive nodes within each helix.

        Thin wrapper over ``_getIntrahelicalNodeSeries`` with a run length of 3.
        """
        tripletLength = 3
        return self._getIntrahelicalNodeSeries(tripletLength)

    def _getCrossoverBonds(self):
        """Return one bond entry per crossover in the model.

        Every node's ``xovers`` entries are scanned; ``xover[0]`` is the
        partner node.  A pair is recorded only from the node with the smaller
        ``idx``, so each crossover appears once.  Entries have the form
        ``((node, partner), 1)``.
        """
        # NOTE(review): intrahelical bonds carry a tuple of separations in the
        # second slot, whereas crossovers carry the bare int 1 -- confirm
        # consumers of _getBonds expect both shapes.
        crossoverBonds = set()
        for helix in self.helices.values():
            for _zid, node in helix:
                for xover in node.xovers:
                    partner = xover[0]
                    if node.idx < partner.idx:
                        crossoverBonds.add( ((node, partner), 1) )
        return crossoverBonds

    def _getCrossoverAnglesAndDihedrals(self):
        """Build angle and dihedral terms between successive crossovers on a helix.

        Each helix is walked node by node.  Whenever two crossover-bearing
        nodes are found along an uninterrupted dsDNA stretch, every
        combination of their crossovers yields:

        * angles ``((xo1_partner, first, second), bps)`` and
          ``((first, second, xo2_partner), bps)``;
        * a dihedral ``((xo1_partner, first, second, xo2_partner), bps,
          xo1[1], xo2[1])``;

        where ``bps`` is the sum of ``nodeBelowSep`` between the two nodes.
        The search restarts at any node without a ``nodeBelow`` or whose
        ``type`` is not ``"dsDNA"``.

        Returns:
            ``(angles, dihedrals)``, two sets.
        """
        angleSet = set()
        dihedralSet = set()
        for helix in self.helices.values():
            anchor = None       # most recent node carrying a crossover
            separation = 0      # base pairs accumulated since `anchor`
            for _zid, node in helix:
                if node.nodeBelow is None or node.type != "dsDNA":
                    ## Gap or ssDNA: forget any pending crossover
                    anchor, separation = None, 0

                if anchor is None:
                    if len(node.xovers) > 0:
                        anchor = node
                    continue

                if node.nodeBelow is not None:
                    separation += node.nodeBelowSep
                if len(node.xovers) == 0:
                    continue

                ## Second crossover node reached: emit terms for every combination
                for xoA in anchor.xovers:
                    for xoB in node.xovers:
                        assert( separation != 0 )
                        angleSet.add( ((xoA[0], anchor, node), separation) )
                        angleSet.add( ((anchor, node, xoB[0]), separation) )
                        dihedralSet.add( ((xoA[0], anchor, node, xoB[0]), separation, xoA[1], xoB[1]) )

                anchor, separation = node, 0

        return angleSet, dihedralSet
        
    def _getBonds(self):
        """Return the union of intrahelical and crossover bonds.

        The set returned by ``_getIntrahelicalBonds`` is extended in place with
        the crossover bonds and then returned.
        """
        intrahelical = self._getIntrahelicalBonds()
        crossovers = self._getCrossoverBonds()
        intrahelical.update( crossovers )
        return intrahelical


    # -------------------------- #
    # Methods for printing model #
    # -------------------------- #
    def writeNamdFiles(self,prefix,numSteps=48000):
        """Write every input file needed to simulate the model with NAMD.

        Atomic indices are assigned first because the structure and ENM
        writers consume them.  The method then writes the PDB/PSF pair and the
        ENM/extra-bond files, and finally the ``<prefix>.namd`` config, which
        runs ``numSteps`` steps.
        """
        self._assignAtomicIndexToNucleotides()
        for writeFiles in (self.writePdbPsf, self.writeENM):
            writeFiles(prefix)
        self.writeNamdFile(prefix,numSteps)

    def writePdbPsf(self, prefix):
        """Write the all-atom structure to ``<prefix>.pdb`` and ``<prefix>.psf``.

        Atoms are emitted segment by segment and nucleotide by nucleotide, in
        the order given by ``seg.atoms()``.  Bonds, angles, dihedrals and
        impropers within a nucleotide come from its atom template, offset by
        the running atom count.  Backbone terms that link the O3' of one
        nucleotide to the P of the next nucleotide in the same segment are
        added explicitly.

        Residue ids up to 9999 are written as plain numbers.  Larger ids do
        not fit the 4-column PDB residue field, so they are encoded as a
        letter followed by up to three digits (10000 -> "A0", 10999 -> "A999",
        11000 -> "B0", ...).
        """
        bonds,angles,dihedrals,impropers = [[],[],[],[]]
        with open("%s.pdb" % prefix,'w') as pdb, \
             open("%s.psf" % prefix,'w') as psf:

            ## Write headers
            pdb.write("CRYST1    1000.    1000.    1000.  90.00  90.00  90.00 P 1           1\n")
            psf.write("PSF NAMD\n\n") # create NAMD formatted psf
            psf.write("{:>8d} !NTITLE\n\n".format(0))

            ## Format strings
            ## http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
            pdbFormat = "ATOM{idx:>7d} {name:<4s}{altLoc:1s}{resname:<3s} {chain:1s}{residString:>4s}{iCode:1s}   {x:8.3f}{y:8.3f}{z:8.3f}{occupancy:6.2f}{beta:6.2f}      {segname:4s}{element:>2s}\n"
            ## From vmd/plugins/molfile_plugin/src/psfplugin.c
            ## "%d %7s %10s %7s %7s %7s %f %f"
            psfFormat = "{idx:>8d} {segname:7s} {resid:<10d} {resname:7s} " + \
                        "{name:7s} {type:7s} {charge: f} {mass:f}\n"

            ## PSF ATOMS section
            natoms=0
            for seg in self.segments:
                for atoms in seg.atoms(transform=False):
                    natoms += len(atoms)
            psf.write("{:>8d} !NATOM\n".format(natoms))

            atomProps = dict(idx=0,altLoc='',chain='A', iCode='',
                             occupancy=0, beta=0, charge='', resid=1)

            resnameDict = dict(A='ADE',T='THY',G='GUA',C='CYT')

            ## Loop over segments, setting common properties
            for seg in self.segments:
                atomProps['segname'] = seg.props['segname']
                prevAtoms = None

                ## Loop over nucleotides, setting common properties
                for atoms, seq in zip(seg.atoms(),seg.sequence()):
                    atomProps['resname'] = resnameDict[seq]
                    if atomProps['resid'] <= 9999:
                        ## Fits the 4-column PDB residue field (9999 included)
                        atomProps['residString'] = "%4d" % atomProps['resid']
                    else:
                        ## Letter + remainder: 10000-10999 -> A0..A999, 11000 -> B0, ...
                        ## (ids of 36000 and above run past 'Z' and raise IndexError)
                        digit = (atomProps['resid'] // 1000) - 10
                        char = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"[digit]
                        rem = atomProps['resid'] - (digit+10)*1000
                        atomProps['residString'] = "{:1s}{:<3d}".format(char,rem)

                    ## Add bonds, angles, dihedrals associated with LKNA
                    ## (the phosphodiester link from the previous nucleotide)
                    if prevAtoms is not None:
                        c4idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C4'")
                        c2idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C2'")
                        h3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("H3'")
                        c3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("C3'")
                        o3idx = atomProps['idx']-len(prevAtoms)+prevAtoms.index("O3'")
                        Pidx = atomProps['idx'] + atoms.index("P")
                        o5idx = atomProps['idx'] + atoms.index("O5'")
                        o1pidx = atomProps['idx'] + atoms.index("O1P")
                        o2pidx = atomProps['idx'] + atoms.index("O2P")
                        c5idx = atomProps['idx'] + atoms.index("C5'")

                        bonds.append([o3idx, Pidx])
                        angles.append([c3idx, o3idx, Pidx])
                        for x in (o5idx,o1pidx,o2pidx):
                            angles.append([o3idx, Pidx, x])

                        for x in (c4idx,c2idx,h3idx):
                            dihedrals.append([x,c3idx, o3idx, Pidx])
                        for x in (o5idx,o1pidx,o2pidx):
                            dihedrals.append([c3idx, o3idx, Pidx,x])
                        dihedrals.append([o3idx, Pidx, o5idx, c5idx])

                    ## Find bonds,angles,dihedrals, and impropers
                    bonds.extend( atoms.bonds+atomProps['idx'] )
                    angles.extend( atoms.angles+atomProps['idx'] )
                    dihedrals.extend( atoms.dihedrals+atomProps['idx'] )
                    impropers.extend( atoms.impropers+atomProps['idx'] )

                    ## Loop over atoms
                    for props in atoms.atomicProps():
                        for k,v in props.items():
                            atomProps[k] = v
                        atomProps['element'] = atomProps['name'][0]

                        ## Write lines; the PDB serial is written before the
                        ## counter is incremented, the PSF index after it
                        pdb.write(pdbFormat.format( **atomProps ))
                        atomProps['idx'] += 1
                        psf.write(psfFormat.format( **atomProps ))

                    prevAtoms = atoms
                    atomProps['resid'] += 1

            ## Write out bonds, angles, dihedrals
            psf.write("\n{:>8d} !NBOND".format(len(bonds)))
            counter = 0
            for b in bonds:
                if counter % 4 == 0: psf.write("\n")
                psf.write(" {:>7d} {:>7d}".format(*b))
                counter += 1

            psf.write("\n\n{:>8d} !NTHETA".format(len(angles)))
            counter = 0
            for b in angles:
                if counter % 3 == 0: psf.write("\n")
                psf.write(" {:>7d} {:>7d} {:>7d}".format(*b))
                counter += 1

            psf.write("\n\n{:>8d} !NPHI".format(len(dihedrals)))
            counter = 0
            for b in dihedrals:
                if counter % 2 == 0: psf.write("\n")
                psf.write(" {:>7d} {:>7d} {:>7d} {:>7d}".format(*b))
                counter += 1

            psf.write("\n\n{:>8d} !NIMPHI".format(len(impropers)))
            counter = 0
            for b in impropers:
                if counter % 2 == 0: psf.write("\n")
                psf.write(" {:>7d} {:>7d} {:>7d} {:>7d}".format(*b))
                counter += 1

            psf.write("\n\n       0 !NDON: donors\n\n\n")
            psf.write("\n       0 !NACC: acceptors\n\n\n")
            psf.write("\n       0 !NNB\n\n")
            ## NNB exclusion flags: one zero per atom, eight per line
            for i in range(natoms//8):
                psf.write("      0       0       0       0       0       0       0       0\n")
            for i in range(natoms-8*(natoms//8)):
                psf.write("      0")
            psf.write("\n\n       1       0 !NGRP\n\n")


    def writeENM(self, prefix):
        if self.latticeType == "square":
            enmTemplate = enmTemplateSQ
        elif self.latticeType == "honeycomb":
            enmTemplate = enmTemplateHC
        else:
            raise Exception("Lattice type '%s' not supported" % self.latticeType)
        noStackPrime = 0
        noBasepair = 0
        with open("%s.exb" % prefix,'w') as fh:
            # natoms=0
            for seg in self.segments:
                for nt1 in seg:
                    other = []

                    nt2 = nt1.basepair
                    if nt2 is not None:
                        if nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                            other.append((nt2,'pair'))
                        nt2 = nt2.stack3prime
                        if nt2 is not None and nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                            other.append((nt2,'paircross'))

                    else:
                        noBasepair += 1

                    nt2 = nt1.stack3prime
                    if nt2 is not None:
                        other.append((nt2,'stack'))
                        nt2 = nt2.basepair
                        if nt2 is not None and nt2.firstAtomicIndex > nt1.firstAtomicIndex:
                            other.append((nt2,'cross'))
                    else:
                        noStackPrime += 1

                    a1 = nt1.atoms(transform=False)
                    for nt2,key in other:
                        key = ','.join((key,nt1.sequence[0],nt2.sequence[0]))
                        for n1, n2, d in enmTemplate[key]:
                            d = float(d)
                            a2 = nt2.atoms(transform=False)
                            i,j = [a.index(n)-1 for a,n in zip((a1,a2),(n1,n2))]
                            # try:
                            #     i,j = [a.index(n)-1 for a,n in zip((a1,a2),(n1,n2))]
                            # except:
                            #     continue

                            k = 0.1
                            if self.latticeType == 'honeycomb':
                                correctionKey = ','.join((key,n1,n2))
                                assert(correctionKey in enmCorrectionsHC)
                                dk,dr = enmCorrectionsHC[correctionKey]
                                k  = float(dk)
                                d += float(dr)

                            i += nt1.firstAtomicIndex
                            j += nt2.firstAtomicIndex
                            fh.write("bond %d %d %f %.2f\n" % (i,j,k,d))


            print("NO STACKS found for:", noStackPrime)
            print("NO BASEPAIRS found for:", noBasepair)

        props, origins = self.part.helixPropertiesAndOrigins()
        origins = [np.array(o) for o in origins]

        ## Push bonds
        pushBonds = []

        fwdDict = dict()    # dictionaries of nts with keys [hid][zidx]
        revDict = dict()
        xo = dict()         # dictionary of crossovers between [hid1][hid2]

        helixIds = [hid for s in self.segments for hid in s.nodes.keys()]
        helixIds = sorted(list(set(helixIds)))

        #- initialize dictionaries
        for h1 in helixIds:
            fwdDict[h1] = dict()
            revDict[h1] = dict()
            xo[h1] = dict()
            for h2 in helixIds:
                xo[h1][h2] = []

        #- fill dictionaries
        for seg in self.segments:
            for nt in seg:
                hid = nt.prop['helixId']

                ## Add nt to ntDict
                idx = nt.prop['idx']
                zidx = nt.prop['zIdx']
                isFwd = nt.prop['isFwd']
                if isFwd:
                    fwdDict[hid][idx] = nt
                else:
                    revDict[hid][idx] = nt

                ## Find crossovers and ends
                for nt2 in (nt.ntAt5prime, nt.ntAt3prime):
                    if nt2 is not None:
                        hid2 = nt2.prop['helixId']
                        if hid2 != hid:
                            # xo[hid][hid2].append(zidx)
                            xo[hid][hid2].append(idx)
                            xo[hid2][hid].append(idx)

        props = self.part.getModelProperties()
        if props.get('point_type') == PointType.ARBITRARY:
            raise Exception("Not implemented")
        else:
            props, origins = self.part.helixPropertiesAndOrigins()
            neighborHelices = atomicModel._getNeighborHelixDict(self.part)

            if self.latticeType == 'honeycomb':
                # minDist = 8 # TODO: test against server
                minDist = 11 # Matches ENRG MD server... should it?
            elif self.latticeType == 'square':
                minDist = 11

            for hid1 in helixIds:
                for hid2 in neighborHelices[hid1]:
                    if hid2 not in helixIds:
                        # print("WARNING: writeENM: helix %d not in helixIds" % hid2) # xo[%d] dict" % (hid2,hid1))
                        continue

                    ## Build chunk for helix pair
                    chunks = self._getDsChunksForHelixPair(hid1,hid2)

                    for lo,hi in chunks:
                        # pdb.set_trace()
                        nts1 = self.getNtsBetweenCadnanoIdxInclusive(hid1,lo,hi)
                        nts2 = self.getNtsBetweenCadnanoIdxInclusive(hid2,lo,hi)

                        if len(nts1) <= len(nts2):
                            iVals = list(range(len(nts1)))
                            if len(nts1) > 1:
                                jVals = [int(round(j*(len(nts2)-1)/(len(nts1)-1))) for j in range(len(nts1))]
                            else:
                                jVals = [0]
                        else:
                            if len(nts2) > 1:
                                iVals = [int(round(i*(len(nts1)-1)/(len(nts2)-1))) for i in range(len(nts2))]
                            else:
                                iVals = [0]
                            jVals = list(range(len(nts2)))

                        ntPairs = [[nts1[i],nts2[j]] for i,j in zip(iVals,jVals)]

                        ## Skip pairs close to nearest crossover on lo side
                        xoIdx = [idx for idx in xo[hid1][hid2] if idx <= lo]
                        if len(xoIdx) > 0:
                            xoIdx = max(xoIdx)
                            assert( type(lo) is int )
                            xoDist = min([len(self.getNtsBetweenCadnanoIdxInclusive(hid,xoIdx,lo-1)) for hid in (hid1,hid2)])

                            skip=minDist-xoDist
                            if skip > 0:
                                ntPairs = ntPairs[skip:]

                        ## Skip pairs close to nearest crossover on hi side
                        xoIdx = [idx for idx in xo[hid1][hid2] if idx >= hi]
                        if len(xoIdx) > 0:
                            xoIdx = min(xoIdx)
                            xoDist = min([len(self.getNtsBetweenCadnanoIdxInclusive(hid,hi+1,xoIdx)) for hid in (hid1,hid2)])

                            skip=minDist-xoDist
                            if skip > 0:
                                ntPairs = ntPairs[:-skip]


                        for ntPair in ntPairs:
                            bps = [nt.basepair for nt in ntPair]
                            if None in bps: continue
                            for nt1,nt2 in [ntPair,bps]:
                                i,j = [nt.firstAtomicIndex for nt in (nt1,nt2)]
                                if j <= i: continue

                                if np.linalg.norm(nt1.position-nt2.position) > 45: continue

                                ai,aj = [nt.atoms(transform=False) for nt in (nt1,nt2)]
                                try:
                                    i += ai.index('P')-1
                                    j += aj.index('P')-1
                                    pushBonds.append(i)
                                    pushBonds.append(j)
                                except:
                                    pass

        print("PUSH BONDS:", len(pushBonds)/2)

        if not self.useTclForces:
            with open("%s.exb" % prefix, 'a') as fh:
                for i,j in zip(pushBonds[::2],pushBonds[1::2]):
                    fh.write("bond %d %d %f %.2f\n" % (i,j,1.0,31))
        else:
            atomList = list(set(pushBonds))
            with open("%s.forces.tcl" % prefix,'w') as fh:
                fh.write("set atomList {%s}\n\n" %
                         " ".join([str(x-1) for x in  atomList]) )
                fh.write("set bonds {%s}\n" %
                         " ".join([str(x-1) for x in pushBonds]) )
                fh.write("""
foreach atom $atomList {
    addatom $atom
}

proc calcforces {} {
    global atomList bonds
    loadcoords rv

    foreach i $atomList {
        set force($i) {0 0 0}
    }

    foreach {i j} $bonds {
        set rvec [vecsub $rv($j) $rv($i)]
        # lassign $rvec x y z
        # set r [expr {sqrt($x*$x+$y*$y+$z*$z)}]
        set r [getbond $rv($j) $rv($i)]
        set f [expr {2*($r-31.0)/$r}]
        vecadd $force($i) [vecscale $f $rvec]
        vecadd $force($j) [vecscale [expr {-1.0*$f}] $rvec]
    }

    foreach i $atomList {
        addforce $i $force($i)
    }

}
""")


    def writeNamdFile(self,prefix,numSteps=4800000):
        with open("%s.namd" % prefix,'w') as fh:
            fh.write("""#############################################################
## CONFIGURATION FILE FOR ORIGAMI STRUCTURE PREDICTION     ##
#############################################################

#############################################################
## ADJUSTABLE PARAMETERS                                   ##
#############################################################\n""")
            fh.write("set prefix %s\n" % prefix)
            fh.write("""set nLast 0;			# increment when continueing a simulation
set n [expr $nLast+1]
set out output/$prefix-$n
set temperature    300

structure          $prefix.psf
coordinates        $prefix.pdb

outputName         $out
XSTfile            $out.xst
DCDfile            $out.dcd

#############################################################
## SIMULATION PARAMETERS                                   ##
#############################################################

# Input
paraTypeCharmm	    on
parameters          charmm36.nbfix/par_all36_na.prm
parameters	    charmm36.nbfix/par_water_ions_na.prm

wrapAll             off

# Force-Field Parameters
exclude             scaled1-4
1-4scaling          1.0
switching           on
switchdist           8
cutoff              10
pairlistdist        12
""")
            if not self.useTclForces:
                fh.write("margin              30\n")
            fh.write("""
# Integrator Parameters
timestep            2.0  ;# 2fs/step
rigidBonds          all  ;# needed for 2fs steps
nonbondedFreq       1
fullElectFrequency  3
stepspercycle       12

# PME (for full-system periodic electrostatics)
PME                 no
PMEGridSpacing      1.2

# Constant Temperature Control
langevin            on    ;# do langevin dynamics
# langevinDamping     1   ;# damping coefficient (gamma); used in original study
langevinDamping     0.1   ;# less friction for faster relaxation
langevinTemp        $temperature
langevinHydrogen    off    ;# don't couple langevin bath to hydrogens

# output
useGroupPressure    yes
xstFreq             4800
outputEnergies      4800
dcdfreq             4800
restartfreq         48000

#############################################################
## EXTRA FORCES                                            ##
#############################################################

# ENM and intrahelical extrabonds
extraBonds on
extraBondsFile $prefix.exb
""")
            if self.useTclForces:
                fh.write("""
tclForces on
tclForcesScript $prefix.forces.tcl
""")
            fh.write("""
#############################################################
## RUN                                                     ##
#############################################################

# Continuing a job from the restart files
cellBasisVector1 1000 0 0
cellBasisVector2 0 1000 0
cellBasisVector3 0 0 1000

if {$nLast == 0} {
    temperature 300
    fixedAtoms on
    fixedAtomsForces on
    fixedAtomsFile $prefix.pdb
    fixedAtomsCol B
    minimize 2400

    fixedAtoms off
    minimize 2400
} else {
    bincoordinates  output/$prefix-$nLast.restart.coor
    binvelocities   output/$prefix-$nLast.restart.vel
}
""")
            fh.write("run %d\n" % numSteps)


    def simulate(self, outputPrefix, outputDirectory='output', numSteps=4800000, numprocs=8, gpus=0, namd=None):
        """Write the NAMD input files for this model and run NAMD on them.

        Args:
            outputPrefix: prefix passed to ``writeNamdFiles``.  NAMD is run on
                ``<outputPrefix>.namd`` from the current working directory.
            outputDirectory: directory (created if missing) that receives the
                log ``<outputDirectory>/<outputPrefix>.log``.  An empty string
                means the current directory.
            numSteps: number of dynamics steps written to the NAMD config.
            numprocs: value passed to NAMD's ``+ppn`` option.
            gpus: a device id or a list of device ids, passed via ``+devices``.
            namd: path to the NAMD executable.  When None, the first
                executable file named ``namd2`` on ``PATH`` is used.

        Raises:
            Exception: if NAMD cannot be found or is not executable, or if
                ``outputDirectory`` exists but is not a directory.
        """
        # NOTE(review): the generated config writes trajectories to the
        # hard-coded "output/" directory; only the log follows outputDirectory.
        if type(gpus) is int:
            gpus = [gpus]
        for gpu in gpus:
            assert(type(gpu) is int)

        assert(type(numSteps) is int)
        assert(type(numprocs) is int)

        if outputDirectory == '': outputDirectory='.'

        if namd is None:
            ## Search PATH for the first executable named namd2
            for path in os.environ.get("PATH", "").split(os.pathsep):
                candidate = os.path.join(path.strip('"'), "namd2")
                if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
                    namd = candidate
                    break

        if namd is None or not os.path.exists(namd):
            raise Exception("NAMD was not found")
        if not os.path.isfile(namd):
            raise Exception("NAMD was not found")
        if not os.access(namd, os.X_OK):
            raise Exception("NAMD is not executable")

        if not os.path.exists(outputDirectory):
            os.makedirs(outputDirectory)
        elif not os.path.isdir(outputDirectory):
            raise Exception("outputDirectory '%s' is not a directory!" % outputDirectory)

        self.writeNamdFiles( outputPrefix, numSteps=numSteps )

        ## namd2 +ppn <procs> +netpoll +idlepoll +devices <gpus> <config>
        gpus = ','.join([str(gpu) for gpu in gpus])
        env = os.environ.copy()
        env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"   # make +devices ids match PCI bus order
        cmd = (namd,
               "+ppn", numprocs,
               '+netpoll','+idlepoll',
               '+devices', gpus,
               "%s.namd" % outputPrefix)
        cmd = tuple(str(x) for x in cmd)

        logfile = "%s/%s.log" % (outputDirectory, outputPrefix)
        print("Running NAMD with: %s >& %s" % (" ".join(cmd), logfile) )
        with open(logfile,'w') as fh:
            subprocess.call(cmd, env=env, stdout=fh, stderr=subprocess.STDOUT)