cadnano_segments.py 23.7 KB
Newer Older
cmaffeo2's avatar
cmaffeo2 committed
1
2
3
4
# -*- coding: utf-8 -*-
import pdb
import numpy as np
import os,sys
5
6
from glob import glob
import re
cmaffeo2's avatar
cmaffeo2 committed
7

8
from coords import readArbdCoords, readAvgArbdCoords
cmaffeo2's avatar
cmaffeo2 committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from segmentmodel import SegmentModel, SingleStrandedSegment, DoubleStrandedSegment

## Paths to the ARBD and NAMD executables.  The defaults are the original
## author's local builds; set the ARBD / NAMD environment variables to use
## binaries installed elsewhere.
arbd = os.environ.get("ARBD", "/home/cmaffeo2/development/cuda/arbd.dbg/src/arbd") # reduced the mem footprint cause vmd
namd = os.environ.get("NAMD", "/home/cmaffeo2/development/namd-bin/NAMD_Git-2017-07-06_Linux-x86_64-multicore-CUDA/namd2")

def readSequenceFile(sequenceFile='resources/cadnano2pdb.seq.m13mp18.dat'):
    """Read a sequence file into a list of upper-cased characters.

    Each line is stripped and has its spaces removed.  Lines that are then
    empty, or that start with ';' or '#', are skipped; every remaining
    character is upper-cased and appended to the result.

    Parameters
    ----------
    sequenceFile : str
        Path to the sequence file (relative paths resolve against the
        current working directory).

    Returns
    -------
    list of str
        One entry per sequence character.
    """
    seq = []
    with open(sequenceFile) as ch:
        for line in ch:
            line = line.strip().replace(" ", "")
            ## Blank lines used to crash here (line[0] on an empty string)
            if not line or line.startswith((";", "#")):
                continue
            seq.extend([c.upper() for c in line])
    return seq

## Default scaffold sequence (m13mp18).  readSequenceFile's default path is
## this same resource file; it is read once at import time, relative to the
## current working directory.
m13seq = readSequenceFile()

## TODO: separate SegmentModel from ArbdModel so multiple parts can be combined
## TODO: catch circular strands in "get_5prime" cadnano calls
## TODO: handle special motifs
##   - doubly-nicked helices
##   - helices that should be stacked across an empty region (crossovers from an end in the helix to another end in the helix)
##   - circular constructs

32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
def combineRegionLists(loHi1,loHi2,intersect=False):

    """Combine two lists of (lo,hi) pairs specifying integer regions into
    a single sorted list of regions.

    Each input is a list of inclusive ``(lo, hi)`` pairs (``lo <= hi``),
    assumed sorted by ``lo``.  Runs of abutting pairs (``hi + 1`` equal to
    the next ``lo``) are grouped into "compact" regions; overlapping
    compact regions from the two inputs are merged with
    combineCompactRegionLists, which splits them at every boundary found
    in either input.

    Parameters
    ----------
    loHi1, loHi2 : list of (int, int)
    intersect : bool
        If True, return [] when either input is empty; otherwise keep only
        the regions lying between the larger of the two first ``lo`` values
        and the smaller of the two last ``hi`` values.
        NOTE(review): this clips to the overlap of the overall spans; it does
        not remove interior gaps that only one input covers.

    Returns
    -------
    list
        Sorted regions.  When one input is empty and ``intersect`` is False,
        the other input is returned unchanged (the same object).

    example:
    loHi1 = [[0,3],[5,7]]
    loHi2 = [[2,4],[5,9]]
    out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
    """

    ## Validate input
    for pairs in (loHi1,loHi2):
        ## Assert each region in lists is sorted
        for pair in pairs:
            assert(len(pair) == 2)
            assert(pair[0] <= pair[1])

    if len(loHi1) == 0:
        return [] if intersect else loHi2
    if len(loHi2) == 0:
        return [] if intersect else loHi1

    ## Break each input into a list of compact regions (runs of abutting pairs)
    compactRegions1,compactRegions2 = [],[]
    for compactRegions,loHi in ((compactRegions1,loHi1),(compactRegions2,loHi2)):
        tmp = []
        lastHi = loHi[0][0]-1
        for lo,hi in loHi:
            if lo-1 != lastHi:
                compactRegions.append(tmp)
                tmp = []
            tmp.append((lo,hi))
            lastHi = hi
        if len(tmp) > 0:
            compactRegions.append(tmp)

    ## Sentinel compact region that starts after every real one.  It was
    ## previously [[1e10]], which failed for coordinates >= 1e10.
    sentinel = [(float('inf'), float('inf'))]
    compactRegions1.append(sentinel)
    compactRegions2.append(sentinel)

    ## Sweep both lists in order of start, merging overlapping compact regions
    result = []
    region = []
    i,j = 0,0
    while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
        cr1 = compactRegions1[i]
        cr2 = compactRegions2[j]

        ## Start a new region from whichever compact region begins first
        if len(region) == 0:
            if cr1[0][0] <= cr2[0][0]:
                region = cr1
                i += 1
            else:
                region = cr2
                j += 1
            continue

        if region[-1][-1] >= cr1[0][0]:
            region = combineCompactRegionLists(region, cr1, intersect=False)
            i+=1
        elif region[-1][-1] >= cr2[0][0]:
            region = combineCompactRegionLists(region, cr2, intersect=False)
            j+=1
        else:
            ## Overlaps neither list's next compact region: region is complete
            result.extend(region)
            region = []

    assert( len(region) > 0 )
    result.extend(region)
    result = sorted(result)

    if intersect:
        lo = max( [loHi1[0][0], loHi2[0][0]] )
        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
        result = [r for r in result if r[0] >= lo and r[1] <= hi]

    return result

def combineCompactRegionLists(loHi1,loHi2,intersect=False):

    """Combines two lists of (lo,hi) pairs specifying regions within a
    compact integer set into a single list of regions.

    Each list must be "compact": its inclusive pairs are sorted and abut
    one another (``pair[k][1] + 1 == pair[k+1][0]``).  The result spans
    from the smallest ``lo`` to the largest ``hi`` of both lists and is
    split after every boundary that appears in either list.

    example:
    loHi1 = [[0,4],[5,7]]
    loHi2 = [[2,4],[5,9]]
    out = [(0, 1), (2, 4), (5, 7), (8, 9)]

    Lists containing gaps (e.g. [[0,3],[5,7]]) fail the compactness
    assertion; use combineRegionLists for those.

    If ``intersect`` is True, an empty input gives [] and otherwise only
    the regions between the larger first ``lo`` and the smaller last
    ``hi`` are kept.  When one list is empty and ``intersect`` is False,
    the other list is returned unchanged (the same object).
    """

    ## Validate input
    for pairs in (loHi1,loHi2):
        ## Assert each region in lists is sorted
        for pair in pairs:
            assert(len(pair) == 2)
            assert(pair[0] <= pair[1])
        ## Assert lists are compact: every consecutive pair must abut.  (The
        ## old zip(l[::2], l[1::2]) skipped the (1,2), (3,4), ... neighbours.)
        for pair1,pair2 in zip(pairs,pairs[1:]):
            assert(pair1[1]+1 == pair2[0])

    if len(loHi1) == 0:
        return [] if intersect else loHi2
    if len(loHi2) == 0:
        return [] if intersect else loHi1

    ## Find the ends of the region
    lo = min( loHi1[0][0], loHi2[0][0] )
    hi = max( loHi1[-1][1], loHi2[-1][1] )

    ## Indices after which the combined region is split: just before every
    ## interior start and at every interior end, from both lists
    splitAfter = set()
    for loHi in (loHi1,loHi2):
        for l,h in loHi:
            if l != lo:
                splitAfter.add(l-1)
            if h != hi:
                splitAfter.add(h)
    split = sorted(splitAfter)

    ## Consecutive split points (i, j] become the inclusive region (i+1, j)
    returnList = [(i+1,j) if i != j else (i,j) for i,j in zip([lo-1]+split,split+[hi])]

    if intersect:
        lo = max( [loHi1[0][0], loHi2[0][0]] )
        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
        returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]

    return returnList

cmaffeo2's avatar
cmaffeo2 committed
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
class cadnano_part(SegmentModel):
    """SegmentModel constructed from a single cadnano part.

    Every virtual helix is split into segments, one per region returned by
    combineRegionLists for the helix's forward/reverse strand ends.  A region
    occupied on both strands becomes a DoubleStrandedSegment; a region
    occupied on one strand becomes a SingleStrandedSegment (built from zid2
    to zid1 for the reverse strand).  Segments are then connected along each
    helix, across crossovers, and annotated with 5'/3' ends before
    SegmentModel is initialized.
    """
    def __init__(self, part,
                 max_basepairs_per_bead = 7,
                 max_nucleotides_per_bead = 5,
                 local_twist = False
    ):
        self.part = part
        self._cadnano_part_to_segments(part)
        ## Flatten helices (hid -> [segments]) into one list in helix order
        self.segments = [seg for segs in self.helices.values() for seg in segs]
        self._add_intrahelical_connections()
        self._add_crossovers()
        self._add_prime_ends()
        SegmentModel.__init__(self, self.segments,
                              local_twist = local_twist,
                              max_basepairs_per_bead = max_basepairs_per_bead,
                              max_nucleotides_per_bead = max_nucleotides_per_bead,
                              dimensions=(5000,5000,5000))

    def _cadnano_part_to_segments(self,part):
        """Split every virtual helix of ``part`` into segments.

        Sets on self:
          origins            -- origins from part.helixPropertiesAndOrigins()
          xover_list         -- list handed to every strand set's dump()
          _5prime_list       -- (hid, isForward, idx5Prime) of every oligo
          _3prime_list       -- (hid, isForward, idx3Prime) of every oligo
          insertions         -- part.insertions()
          strand_occupancies -- hid -> [fwd zids, rev zids] covered by strands
          helices            -- hid -> list of segments, sorted by zid
          helix_ranges       -- hid -> inclusive (zid1, zid2) per segment,
                                parallel to helices[hid]

        Raises NotImplementedError for parts with ARBITRARY point type.
        """
        from cadnano.cnenum import PointType
        self.helices = helices = dict()
        self.helix_ranges = helix_ranges = dict()

        props = part.getModelProperties().copy()
        if props.get('point_type') == PointType.ARBITRARY:
            # TODO add code to encode Parts with ARBITRARY point configurations
            raise NotImplementedError("Not implemented")
        _vh_props, origins = part.helixPropertiesAndOrigins()
        self.origins = origins

        strand_list = []
        ## NOTE(review): xover_list is passed to every dump() call and later
        ## read by _add_crossovers, so dump() presumably appends to it
        self.xover_list = xover_list = []
        numHID = part.getIdNumMax() + 1

        for id_num in range(numHID):
            if part.getOffsetAndSize(id_num) is None:
                ## Placeholder for an empty virtual helix
                strand_list.append(None)
                continue
            fwd_ss, rev_ss = part.getStrandSets(id_num)
            fwd_idxs, _fwd_colors = fwd_ss.dump(xover_list)
            rev_idxs, _rev_colors = rev_ss.dump(xover_list)
            strand_list.append((fwd_idxs, rev_idxs))

        ## Get lists of 5/3prime ends
        strands5 = [o.strand5p() for o in part.oligos()]
        strands3 = [o.strand3p() for o in part.oligos()]
        self._5prime_list = [(s.idNum(),s.isForward(),s.idx5Prime()) for s in strands5]
        self._3prime_list = [(s.idNum(),s.isForward(),s.idx3Prime()) for s in strands3]

        ## Get dictionary of insertions
        self.insertions = allInsertions = part.insertions()
        self.strand_occupancies = dict()

        ## Build helices
        for hid in range(numHID):
            helices[hid] = []
            helix_ranges[hid] = []

            helixStrands = strand_list[hid]
            if helixStrands is None:
                continue

            ## (idx, length) of insertions/skips, sorted by idx
            insertions = sorted(((idx, ins.length()) for idx, ins in allInsertions[hid].items()),
                                key=lambda x: x[0])

            ## Flat lists [start0, end0, start1, end1, ...] of strand ends
            ends1,ends2 = self._helixStrandsToEnds(helixStrands)

            ## Lists of which nt sites are occupied on each strand of the helix
            strandOccupancies = [ [x for i in range(0,len(e),2)
                                   for x in range(e[i],e[i+1]+1)]
                                  for e in (ends1,ends2) ]
            self.strand_occupancies[hid] = strandOccupancies
            ## Sets for O(1) membership tests while building segments
            occ_fwd, occ_rev = map(set, strandOccupancies)

            ends1,ends2 = [ [(e[i],e[i+1]) for i in range(0,len(e),2)] for e in (ends1,ends2) ]
            for zid1,zid2 in combineRegionLists(ends1,ends2):
                zMid = int(0.5*(zid1+zid2))
                in_fwd = zMid in occ_fwd
                in_rev = zMid in occ_rev
                assert( in_fwd or in_rev )

                numBps = zid2-zid1+1
                for ins_idx,length in insertions:
                    ## TODO: ensure placement of insertions is correct
                    ##   (e.g. are insertions at the ends handled correctly?)
                    if ins_idx < zid1:
                        continue
                    if ins_idx > zid2:
                        break
                    numBps += length

                kwargs = dict(name="%d-%d" % (hid,len(helices[hid])),
                              num_nts = numBps)

                ## posargs1 runs zid1 -> zid2; posargs2 is the same span reversed
                posargs1 = dict( start_position = self._get_cadnano_position(hid,zid1),
                                 end_position   = self._get_cadnano_position(hid,zid2) )
                posargs2 = dict( start_position = posargs1['end_position'],
                                 end_position = posargs1['start_position'])

                ## TODO get sequence from cadnano api
                if in_fwd and in_rev:
                    seg = DoubleStrandedSegment(**kwargs,**posargs1)
                elif in_fwd:
                    seg = SingleStrandedSegment(**kwargs,**posargs1)
                elif in_rev:
                    seg = SingleStrandedSegment(**kwargs,**posargs2)
                else:
                    raise Exception("Segment could not be found")

                helices[hid].append( seg )
                helix_ranges[hid].append( (zid1,zid2) )

    def _get_cadnano_position(self, hid, zid):
        """Return [10*origin components..., -3.4*zid] for helix ``hid``.

        NOTE(review): the factors presumably convert cadnano's nm origins
        and base-pair index into Angstrom positions; confirm units.
        """
        return [10*a for a in self.origins[hid]] + [-3.4*zid]

    def _helixStrandsToEnds(self, helixStrands):
        """Utility method to convert cadnano strand lists into list of
        indices of terminal points.

        Returns [fwd_ends, rev_ends], each a flat list
        [start0, end0, start1, end1, ...] in which strands that abut
        (previous end + 1 == next start) are merged into one span.
        """

        endLists = [[],[]]
        for endList, strandList in zip(endLists,helixStrands):
            lastStrand = None
            for s in strandList:
                if lastStrand is None:
                    ## first strand
                    endList.append(s[0])
                elif lastStrand[1] != s[0]-1:
                    assert( s[0] > lastStrand[1] )
                    endList.extend( [lastStrand[1], s[0]] )
                lastStrand = s
            if lastStrand is not None:
                endList.append(lastStrand[1])
        return endLists

    def _helix_strands_to_segment_ranges(self, helix_strands):
        """Unfinished replacement for _helixStrandsToEnds.

        The previous draft's main loop never advanced its indices and spun
        forever; nothing in this class calls it.

        Raises NotImplementedError.
        """
        raise NotImplementedError("_helix_strands_to_segment_ranges is not implemented")

    def _get_segment_index(self, hid, zid):
        """Return the index in self.helices[hid] of the segment whose
        inclusive z-range contains ``zid``."""
        for i,(zmin,zmax) in enumerate(self.helix_ranges[hid]):
            if zmin <= zid <= zmax:
                return i
        raise Exception("Could not find segment in helix %d at position %d" % (hid,zid))

    def _get_segment(self, hid, zid):
        """Return the segment of helix ``hid`` whose z-range contains ``zid``."""
        ## TODO: rename these variables to segments
        return self.helices[hid][self._get_segment_index(hid,zid)]

    def _get_nucleotide(self, hid, zid):
        """Deprecated: use _get_segment_nucleotide instead."""
        raise Exception("Deprecated")

    def _get_segment_nucleotide(self, hid, zid):
        """ returns segments and zero-based nucleotide index

        For reversed ssDNA segments nucleotide 0 is at zmax; otherwise it
        is at zmin.  Lengths of insertions passed on the way are added.
        """
        sid = self._get_segment_index(hid,zid)
        seg = self.helices[hid][sid]
        zmin,zmax = self.helix_ranges[hid][sid]

        zMid = int(0.5*(zmin+zmax))
        occ = self.strand_occupancies[hid]
        ins = self.insertions[hid]

        if (zMid not in occ[0]) and (zMid in occ[1]):
            ## reversed ssDNA strand
            # TODO: confirm whether an insertion at zid itself should count here
            nt = zmax-zid
            nt += sum(ins[i].length() for i in range(zid,zmax) if i in ins)
        else:
            ## normal condition
            # TODO: for i in range(zmin,zid+1): ?
            nt = zid-zmin
            nt += sum(ins[i].length() for i in range(zmin,zid) if i in ins)

        return seg, nt

    ## Routines to add connections between helices

    def _add_intrahelical_connections(self):
        """Connect consecutive segments of each helix whose z-ranges abut."""
        for hid,segs in self.helices.items():
            occ = self.strand_occupancies[hid]
            ranges = self.helix_ranges[hid]
            for seg1,seg2,r1,r2 in zip(segs, segs[1:], ranges, ranges[1:]):
                if r1[1]+1 != r2[0]:
                    continue
                ## TODO: handle nicks that are at intrahelical connections(?)
                zmid1 = int(0.5*(r1[0]+r1[1]))
                zmid2 = int(0.5*(r2[0]+r2[1]))

                ## Forward strand: seg1's 3' end joins seg2's starting 5' end
                if zmid1 in occ[0] and zmid2 in occ[0]:
                    seg1.connect_end3(seg2.start5)

                ## Reverse strand: seg1's 5' end is end5 when seg1 is built
                ## zid1 -> zid2 (dsDNA) and start5 when it is a reversed ssDNA
                ## segment (built zid2 -> zid1)
                if zmid1 in occ[1] and zmid2 in occ[1]:
                    end = seg1.end5 if zmid1 in occ[0] else seg1.start5
                    if zmid2 in occ[0]:
                        seg2.connect_start3(end)
                    else:
                        ## Previously hard-coded seg1.start5 here, ignoring
                        ## the choice above when seg1 is double-stranded
                        seg2.connect_end3(end)

    def _add_crossovers(self):
        """Add a crossover between the segments at both ends of each entry
        in self.xover_list."""
        for h1,f1,z1,h2,f2,z2 in self.xover_list:
            seg1, nt1 = self._get_segment_nucleotide(h1,z1)
            seg2, nt2 = self._get_segment_nucleotide(h2,z2)
            ## TODO: use different types of crossovers
            ## f1, f2 are presumably the forward-strand flags of each end
            seg1.add_crossover(nt1,seg2,nt2,[f1,f2])

    def _add_prime_ends(self):
        """Mark every oligo's 5' and 3' end on the segment containing it."""
        for h,fwd,z in self._5prime_list:
            seg, nt = self._get_segment_nucleotide(h,z)
            seg.add_5prime(nt,fwd)

        for h,fwd,z in self._3prime_list:
            seg, nt = self._get_segment_nucleotide(h,z)
            seg.add_3prime(nt,fwd)

    def get_bead(self, hid, zid):
        """Return seg.get_nearest_bead(...) for the nucleotide at (hid, zid).

        The position is passed as the fraction nt / (num_nts - 1) along the
        segment.
        """
        ## Previously referenced the undefined names h, z (NameError)
        seg, nt = self._get_segment_nucleotide(hid,zid)
        # NOTE(review): seg is passed both as the bound instance and as the
        # first argument; confirm against get_nearest_bead's signature.
        # NOTE(review): divides by zero for single-nucleotide segments.
        # return seg.get_nearest_bead(seg,nt / seg.num_nts)
        return seg.get_nearest_bead(seg,nt / (seg.num_nts-1))
        
def read_json_file(filename):
    """Load JSON from ``filename``, tolerating single-quoted "lazy" JSON.

    The file is first parsed as strict JSON.  If that raises a ValueError
    (json.JSONDecodeError, or a text-decoding error), every single quote is
    replaced with a double quote and parsing is retried.  I/O errors such as
    FileNotFoundError propagate.

    NOTE: the fallback also rewrites apostrophes inside string values, so it
    only suits files that contain none.
    """
    import json

    try:
        with open(filename) as ch:
            return json.load(ch)
    except ValueError:
        ## Not strict JSON; fall through to the quote-rewriting retry
        pass

    with open(filename) as ch:
        content = ch.read().replace("'", '"')
    return json.loads(content)

def decode_cadnano_part(json_data):
    """Decode cadnano JSON data and return its single part.

    The v3 decoder is tried first; if it raises, a fresh Document is decoded
    with the v2 decoder instead.

    Raises Exception if the decoded document does not contain exactly one
    part.
    """
    import cadnano
    from cadnano.document import Document

    try:
        doc = Document()
        cadnano.fileio.v3decode.decode(doc, json_data)
    except Exception:
        ## Not decodable as v3 (was a bare except, which also trapped
        ## KeyboardInterrupt/SystemExit): retry as v2 on a fresh document
        doc = Document()
        cadnano.fileio.v2decode.decode(doc, json_data)

    parts = list(doc.getParts())
    if len(parts) != 1:
        raise Exception("Only documents containing a single cadnano part are implemented at this time.")
    return parts[0]

def package_archive( name, directory ):
    """Unimplemented stub.

    Presumably intended to archive ``directory`` under ``name``; it
    currently does nothing and returns None.
    """
    pass

529
def run_simulation_protocol( output_name, job_id, json_data,
                             sequence=None,
                             remove_long_bonds=False,
                             gpu = 0,
                             directory=None
                         ):
    """Run the coarse -> fine -> twist-frozen -> atomic pipeline on a design.

    Parameters
    ----------
    output_name : str
        Prefix for every stage's output ("<output_name>-0" ... "-4").
    job_id : str
        Used in the name of the temporary directory when ``directory`` is None.
    json_data :
        Decoded cadnano JSON, passed to decode_cadnano_part.
    sequence : sequence of str, optional
        Sequence for model.strands[0].  When None or empty, m13mp18 is
        applied on a best-effort basis (failures are ignored).
    remove_long_bonds : bool
        Split the coarse run into a short (5%) run whose restart coordinates
        re-seed the model, followed by the remaining 95%.
    gpu : int
        GPU index forwarded to model.simulate.
    directory : str, optional
        Working directory; created if missing.  When None, a fresh directory
        is made under /dev/shm.

    Returns
    -------
    str
        The working directory.  The process cwd is changed to it during the
        run and restored afterwards, even on error.
    """

    coarseSteps = 1e5+1
    fineSteps = 1e5+1

    d_orig = os.getcwd()
    try:
        if directory is None:
            import tempfile
            directory = tempfile.mkdtemp(prefix='origami-%s-' % job_id, dir='/dev/shm/')
        elif not os.path.exists(directory):
            os.makedirs(directory)
        os.chdir(directory)

        output_directory = "output"

        ## Read in data
        part = decode_cadnano_part(json_data)
        model = cadnano_part(part,
                             max_basepairs_per_bead = 7,
                             max_nucleotides_per_bead = 4,
                             local_twist=False)
        model._generate_strands() # TODO: move into model creation

        # TODO: use the sequence stored in the cadnano design when available
        if sequence is None or len(sequence) == 0:
            ## default m13mp18, applied best-effort
            sequence = list(m13seq)
            try:
                model.strands[0].set_sequence(sequence)
            except Exception:
                pass
        else:
            model.strands[0].set_sequence(sequence)

        for s in model.segments:
            s.randomize_unset_sequence()

        # TODO: add the output directory in to the readArbdCoords functions or make it an attribute of the model object

        ## Coarse simulation
        simargs = dict(timestep=200e-6, outputPeriod=1e5, gpu=gpu, arbd=arbd )

        if remove_long_bonds:
            output_prefix = "%s-0" % output_name
            full_output_prefix = "%s/%s" % (output_directory,output_prefix)
            ## TODO Remove long bonds
            model.simulate( outputPrefix = output_prefix, numSteps=0.05*coarseSteps, **simargs )
            coordinates = readArbdCoords('%s.0.restart' % full_output_prefix)
            model._update_segment_positions(coordinates)
            output_prefix = "%s-1" % output_name
            ## Point at the "-1" run so its restart (not "-0"'s) seeds the
            ## fine stage; this prefix was previously left stale
            full_output_prefix = "%s/%s" % (output_directory,output_prefix)
            model.simulate( outputPrefix = output_prefix, numSteps=0.95*coarseSteps, **simargs )
        else:
            output_prefix = "%s-1" % output_name
            full_output_prefix = "%s/%s" % (output_directory,output_prefix)
            model.simulate( outputPrefix = output_prefix, numSteps=coarseSteps, **simargs )
        coordinates = readArbdCoords('%s.0.restart' % full_output_prefix)

        ## Fine simulation
        output_prefix = "%s-2" % output_name
        full_output_prefix = "%s/%s" % (output_directory,output_prefix)
        simargs['timestep'] = 50e-6
        model._update_segment_positions(coordinates)
        model._clear_beads()
        model._generate_bead_model( 1, 1, local_twist=True, escapable_twist=True )
        model.simulate( outputPrefix = output_prefix, numSteps=fineSteps, **simargs )
        coordinates = readAvgArbdCoords('%s.psf' % output_prefix,'%s.pdb' % output_prefix, '%s.0.dcd' % full_output_prefix, rmsdThreshold=1)

        ## Freeze twist
        output_prefix = "%s-3" % output_name
        full_output_prefix = "%s/%s" % (output_directory,output_prefix)
        model._update_segment_positions(coordinates)
        model._clear_beads()
        model._generate_bead_model( 1, 1, local_twist=True, escapable_twist=False )
        model.simulate( outputPrefix = output_prefix, numSteps=fineSteps, **simargs )
        ## psf of this stage (was '%s.psf' % output_name, a file no stage writes)
        coordinates = readAvgArbdCoords('%s.psf' % output_prefix,'%s.pdb' % output_prefix, '%s.0.dcd' % full_output_prefix )

        ## Atomic simulation
        output_prefix = "%s-4" % output_name
        full_output_prefix = "%s/%s" % (output_directory,output_prefix)
        model._update_segment_positions(coordinates)
        model._clear_beads()
        model._generate_atomic_model(scale=0.25)
        model.atomic_simulate( outputPrefix = full_output_prefix )

        return directory
    finally:
        os.chdir(d_orig)

# pynvml.nvmlInit()
# gpus = range(pynvml.nvmlDeviceGetCount())
# pynvml.nvmlShutdown()
# gpus = [0,1,2]
# print(gpus)

if __name__ == '__main__':
    ## Self-check of the region-combining helpers on the docstring examples.
    ## (input 1, input 2, expected combined regions, inputs are compact?)
    examples = [
        ([[0,4],[5,7]], [[2,4],[5,9]], [(0, 1), (2, 4), (5, 7), (8, 9)], True),
        ([[0,3],[5,7]], [[2,4],[5,9]], [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)], False),
    ]
    for loHi1, loHi2, expected, compact in examples:
        print(loHi1)
        print(loHi2)
        result = combineRegionLists(loHi1,loHi2)
        print(result, "OK" if result == expected else "MISMATCH, expected %s" % expected)
        ## combineCompactRegionLists asserts that its inputs have no gaps, so
        ## it is only exercised on the compact example
        if compact:
            print(combineCompactRegionLists(loHi1,loHi2))

    # for f in glob('json/*'):
    #     print("Working on {}".format(f))
    #     out = re.match('json/(.*).json',f).group(1)
    #     data = read_json_file(f)
    #     run_simulation_protocol( out, "job-id", data, gpu=0 )