From 34556ed556c191db8595d79ee5391e62302a0cbe Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Sat, 15 Sep 2018 10:03:45 -0500
Subject: [PATCH] Improved psf/pdb writer

---
 mrdna/model/arbdmodel.py          | 115 ++++++++++++++++--------------
 mrdna/readers/cadnano_segments.py |  49 +++++++++----
 mrdna/segmentmodel.py             |  48 ++++++++-----
 3 files changed, 128 insertions(+), 84 deletions(-)

diff --git a/mrdna/model/arbdmodel.py b/mrdna/model/arbdmodel.py
index d25b1b2..8cefb88 100644
--- a/mrdna/model/arbdmodel.py
+++ b/mrdna/model/arbdmodel.py
@@ -382,19 +382,63 @@ class PointParticle(Transformable, Child):
             else:
                 raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
 
+    def _get_psfpdb_dictionary(self):
+        p = self
+        try:
+            segname = p.segname
+        except:
+            segname = "A"
+        try:
+            resname = p.resname
+        except:
+            resname = p.name[:3]
+        try:
+            resid = p.resid
+        except:
+            resid = p.idx+1
+
+        try:
+            occ = p.occupancy
+        except:
+            occ = 0
+        try:
+            beta = p.beta
+        except:
+            beta = 0
+
+        data = dict(segname = segname,
+                    resname = resname,
+                    name = str(p.name)[:4],
+                    chain = "A",
+                    resid = int(resid),
+                    idx = p.idx+1,
+                    type = p.type_.name[:4],
+                    charge = p.charge,
+                    mass = p.mass,
+                    occupancy = occ,
+                    beta = beta
+                )
+        return data
+
 
 class Group(Transformable, Parent, Child):
+
     def __init__(self, name=None, children = None, parent=None, 
                  position = np.array((0,0,0)),
                  orientation = np.array(((1,0,0),(0,1,0),(0,0,1))),
                  remove_duplicate_bonded_terms = False,
-    ):
+                 **kwargs):
+
         Transformable.__init__(self, position, orientation)
         Child.__init__(self, parent) # Initialize Child first
         Parent.__init__(self, children, remove_duplicate_bonded_terms)
         self.name = name
         self.isClone = False
 
+        for key,val in kwargs.items():
+            self.__dict__[key] = val
+
+
     def clone(self):
         return Clone(self)
         g = copy(self)
@@ -445,49 +489,32 @@ class PdbModel(Transformable, Parent):
             fh.write("CRYST1{:>9.3f}{:>9.3f}{:>9.3f}  90.00  90.00  90.00 P 1           1\n".format( *self.dimensions ))
 
             ## Write coordinates
-            formatString = "ATOM {:>6.6s} {:^4.4s}{:1.1s}{:3.3s} {:1.1s}{:>5.5s}   {:8.8s}{:8.8s}{:8.8s}{:6.2f}{:6.2f}{:2.2s}{:2d}{:>6s}\n"
+            formatString = "ATOM {idx:>6.6s} {name:^4.4s} {resname:3.3s} {chain:1.1s}{resid:>5.5s}   {x:8.8s}{y:8.8s}{z:8.8s}{occupancy:6.2f}{beta:6.2f}  {charge:2d}{segname:>6s}\n"
             for p in self.particles:
                 ## http://www.wwpdb.org/documentation/file-format-content/format33/sect9.html#ATOM
-                idx = p.idx+1
-                try:
-                    segname = p.segname
-                except:
-                    segname = "A"
-                try:
-                    resname = p.resname
-                except:
-                    resname = p.name[:3]
-                try: 
-                    resid = p.resid
-                except:
-                    resid = idx
-
-                ## TODOTODO test
+                data = p._get_psfpdb_dictionary()
+                idx = data['idx']
+
                 if np.log10(idx) >= 5:
                     idx = " *****"
                 else:
                     idx = "{:>6d}".format(idx)
+                data['idx'] = idx
 
-                # name = str(p.name)[:4]
-                name = str(p.name)[:4]
-                resname = resname
-                chain = "A"
-                charge = 0
-                occ = 0
-                beta = 0
-                # x,y,z = [x for x in p.collapsedPosition()]
                 pos = p.collapsedPosition()
                 dig = [max(int(np.log10(np.abs(x)+1e-6)//1),0)+1 for x in pos]
                 for d in dig: assert( d <= 7 )
                 # assert( np.all(dig <= 7) )
                 fs = ["{: %d.%df}" % (8,7-d) for d in dig]
                 x,y,z = [f.format(x) for f,x in zip(fs,pos)] 
+                data['x'] = x
+                data['y'] = y
+                data['z'] = z
+                assert(data['resid'] < 1e5)
+                data['charge'] = int(data['charge'])
+                data['resid'] = "{:<4d}".format(data['resid'])
+                fh.write( formatString.format(**data) )
 
-                assert(resid < 1e5)
-                resid = "{:<4d}".format(resid)
-
-                fh.write( formatString.format(
-                    idx, name, "", resname, chain, resid, x, y, z, occ, beta, "", charge, segname ))
         return
         
     def writePsf(self, filename):
@@ -506,31 +533,9 @@ class PdbModel(Transformable, Parent):
             formatString = "{idx:>8d} {segname:7.7s} {resid:<10.10s} {resname:7.7s}" + \
                            " {name:7.7s} {type:7.7s} {charge:f} {mass:f}\n"
             for p in self.particles:
-                idx = p.idx + 1
-                try:
-                    segname = p.segname
-                except:
-                    segname = "A"
-                try:
-                    resname = p.resname
-                except:
-                    resname = str(p.name)[:3]
-                try: 
-                    resid = p.resid
-                except:
-                    resid = idx
-
-                data = dict(
-                    idx     = idx,
-                    segname = segname,
-                    resid   = "%d%c%c" % (resid," "," "), # TODO: work with large indices
-                    name    = str(p.name)[:4],
-                    resname = resname,
-                    type    = p.type_.name[:4],
-                    charge  = p.charge,
-                    mass    = p.mass
-                )
-                fh.write(formatString.format( **data ))
+                data = p._get_psfpdb_dictionary()
+                data['resid'] = "%d%c%c" % (data['resid']," "," ") # TODO: work with large indices
+                fh.write( formatString.format(**data) )
             fh.write("\n")
 
             ## Write out bonds
diff --git a/mrdna/readers/cadnano_segments.py b/mrdna/readers/cadnano_segments.py
index bc69c14..708c4cb 100644
--- a/mrdna/readers/cadnano_segments.py
+++ b/mrdna/readers/cadnano_segments.py
@@ -285,26 +285,28 @@ class cadnano_part(SegmentModel):
                 zMid = int(0.5*(zid1+zid2))
                 assert( zMid in strandOccupancies[0] or zMid in strandOccupancies[1] )
 
-                numBps = zid2-zid1+1
-                for ins_idx,length in insertions:
-                    ## TODO: ensure placement of insertions is correct
-                    ##   (e.g. are insertions at the ends handled correctly?)
-                    if ins_idx < zid1:
-                        continue
-                    if ins_idx > zid2:
-                        break
-                    numBps += length
+                bp_to_zidx = []
+                insertion_dict = {idx:length for idx,length in insertions}
+                for i in range(zid1,zid2+1):
+                    if i in insertion_dict:
+                        l = insertion_dict[i]
+                    else:
+                        l = 0
+                    for j in range(i,i+1+l):
+                        bp_to_zidx.append(i)
+                numBps = len(bp_to_zidx)
 
                 # print("Adding helix with length",numBps,zid1,zid2)
 
-                kwargs = dict(name="%d-%d" % (hid,len(helices[hid])))
-
+                name = "%d-%d" % (hid,len(helices[hid]))
+                # "H%03d" % hid
+                kwargs = dict(name=name, segname=name, occupancy=hid)
 
                 posargs1 = dict( start_position = self._get_cadnano_position(hid,zid1),
                                  end_position   = self._get_cadnano_position(hid,zid2) )
                 posargs2 = dict( start_position = posargs1['end_position'],
                                  end_position = posargs1['start_position'])
-                
+
                 ## TODO get sequence from cadnano api
                 if zMid in strandOccupancies[0] and zMid in strandOccupancies[1]:
                     kwargs['num_bp'] = numBps
@@ -317,7 +319,28 @@ class cadnano_part(SegmentModel):
                     seg = SingleStrandedSegment(**kwargs,**posargs2)
                 else:
                     raise Exception("Segment could not be found")
-                                     
+
+                def callback(segment, bp_to_zidx=bp_to_zidx):
+                    for b in segment.beads:
+                        bp = int(round(b.get_nt_position(segment)))
+                        if bp < 0: bp = 0
+                        if bp >= segment.num_nt: bp = segment.num_nt-1
+                        b.beta = bp_to_zidx[bp]
+                        if 'orientation_bead' in b.__dict__:
+                            b.orientation_bead.beta = bp_to_zidx[bp]
+                seg._generate_bead_callbacks.append(callback)
+
+                def atomic_callback(nucleotide, bp_to_zidx=bp_to_zidx):
+                    nt = nucleotide
+                    segment = nucleotide.parent.segment
+                    bp = int(round(segment.contour_to_nt_pos( nt.contour_position )))
+                    if bp < 0: bp = 0
+                    if bp >= segment.num_nt: bp = segment.num_nt-1
+                    nt.beta = bp_to_zidx[bp]
+                    nt.parent.occupancy = segment.occupancy
+                seg._generate_nucleotide_callbacks.append(atomic_callback)
+
+
                 helices[hid].append( seg )
                 helix_ranges[hid].append( (zid1,zid2) )
 
diff --git a/mrdna/segmentmodel.py b/mrdna/segmentmodel.py
index 3c7c05c..2141601 100644
--- a/mrdna/segmentmodel.py
+++ b/mrdna/segmentmodel.py
@@ -196,10 +196,10 @@ class ConnectableElement():
     #     return [c for c in self.connections if c.A == loc or c.B == loc]
 
 class SegmentParticle(PointParticle):
-    def __init__(self, type_, position, name="A", segname="A", **kwargs):
+    def __init__(self, type_, position, name="A", **kwargs):
         self.name = name
         self.contour_position = None
-        PointParticle.__init__(self, type_, position, name=name, segname=segname, **kwargs)
+        PointParticle.__init__(self, type_, position, name=name, **kwargs)
         self.intrahelical_neighbors = []
         self.other_neighbors = []
         self.locations = []
@@ -355,14 +355,17 @@ class Segment(ConnectableElement, Group):
     def __init__(self, name, num_nt, 
                  start_position = None,
                  end_position = None, 
-                 segment_model = None):
+                 segment_model = None,
+                 **kwargs):
 
         if start_position is None: start_position = np.array((0,0,0))
 
-        Group.__init__(self, name, children=[])
+        Group.__init__(self, name, children=[], **kwargs)
         ConnectableElement.__init__(self, connection_locations=[], connections=[])
 
-        self.resname = name
+        if 'segname' not in kwargs:
+            self.segname = name
+        # self.resname = name
         self.start_orientation = None
         self.twist_per_nt = 0
 
@@ -381,6 +384,10 @@ class Segment(ConnectableElement, Group):
         self.start_position = start_position
         self.end_position = end_position
 
+        ## Used to assign cadnano names to beads
+        self._generate_bead_callbacks = []
+        self._generate_nucleotide_callbacks = []
+
         ## Set up interpolation for positions
         self._set_splines_from_ends()
 
@@ -477,7 +484,7 @@ class Segment(ConnectableElement, Group):
     def _generate_one_bead(self, contour_position, nts):
         raise NotImplementedError
 
-    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale):
+    def _generate_atomic_nucleotide(self, contour_position, is_fwd, seq, scale, strand_segment):
         """ Seq should include modifications like 5T, T3 Tsinglet; direction matters too """
 
         # print("Generating nucleotide at {}".format(contour_position))
@@ -507,7 +514,6 @@ class Segment(ConnectableElement, Group):
         nt_dict = canonicalNtFwd if is_fwd else canonicalNtRev
 
         atoms = nt_dict[ key ].generate() # TODO: clone?
-                        
         atoms.orientation = orientation.dot(atoms.orientation)
         if isinstance(self, SingleStrandedSegment):
             if scale is not None and scale != 1:
@@ -526,6 +532,12 @@ class Segment(ConnectableElement, Group):
                         a.position = scale*(a.position-r0) + r0
                         a.beta = 0
             atoms.position = pos
+        
+        atoms.contour_position = contour_position
+        strand_segment.add(atoms)
+
+        for callback in self._generate_nucleotide_callbacks:
+            callback(atoms)
 
         return atoms
 
@@ -820,6 +832,9 @@ class Segment(ConnectableElement, Group):
         #     pdb.set_trace()
         self._rebuild_children(tmp_children)
 
+        for callback in self._generate_bead_callbacks:
+            callback(self)
+
     def _regenerate_beads(self, max_nts_per_bead=4, ):
         ...
     
@@ -835,14 +850,16 @@ class DoubleStrandedSegment(Segment):
                  local_twist = False,
                  num_turns = None,
                  start_orientation = None,
-                 twist_persistence_length = 90 ):
+                 twist_persistence_length = 90,
+                 **kwargs):
         
         self.helical_rise = 10.44
         self.distance_per_nt = 3.4
         Segment.__init__(self, name, num_bp,
                          start_position,
                          end_position, 
-                         segment_model)
+                         segment_model,
+                         **kwargs)
         self.num_bp = self.num_nt
 
         self.local_twist = local_twist
@@ -979,14 +996,16 @@ class SingleStrandedSegment(Segment):
 
     def __init__(self, name, num_nt, start_position = None,
                  end_position = None, 
-                 segment_model = None):
+                 segment_model = None,
+                 **kwargs):
 
         if start_position is None: start_position = np.array((0,0,0))
         self.distance_per_nt = 5
         Segment.__init__(self, name, num_nt, 
                          start_position,
                          end_position, 
-                         segment_model)
+                         segment_model,
+                         **kwargs)
 
         self.start = self.start5 = Location( self, address=0, type_= "end5" ) # TODO change type_?
         self.end = self.end3 = Location( self, address=1, type_ = "end3" )
@@ -1255,12 +1274,9 @@ class Strand(Group):
                 if strand_segment_count == len(s.strand_segments) and c == 1 and not self.is_circular:
                     seq = seq+"3"
 
-                nt = seg._generate_atomic_nucleotide( c, s.is_fwd, seq, scale )
-                # if s.is_fwd:                    
-                # else:
-                #     nt = seg._generate_atomic_nucleotide( c, s.is_fwd, "A" )
-
+                nt = seg._generate_atomic_nucleotide( c, s.is_fwd, seq, scale, s )
                 s.add(nt)
+
                 ## Join last basepairs
                 if last is not None:
                     self.link_nucleotides(last,nt)
-- 
GitLab