From 28d4d2776fc1e65d3665f00540605360719673fd Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Wed, 11 Mar 2020 11:38:41 -0500
Subject: [PATCH] Updated routines for removing DNA to optionally allow
 removing locations as well

---
 mrdna/segmentmodel.py | 77 +++++++++++++++++++++++++++++++++++--------
 1 file changed, 63 insertions(+), 14 deletions(-)

diff --git a/mrdna/segmentmodel.py b/mrdna/segmentmodel.py
index 99315a5..f758df5 100644
--- a/mrdna/segmentmodel.py
+++ b/mrdna/segmentmodel.py
@@ -66,9 +66,6 @@ class Location():
         self.connection = None
         self.is_3prime_side_of_connection = None
 
-        self.prev_in_strand = None
-        self.next_in_strand = None
-        
         self.combine = None     # some locations might be combined in bead model 
 
     def get_connected_location(self):
@@ -93,6 +90,12 @@ class Location():
                 raise
         return pos
 
+    def delete(self):
+        if self.connection is not None:
+            self.connection.delete()
+        if self.container is not None:
+            self.container.locations.remove(self)
+
     def __repr__(self):
         if self.on_fwd_strand:
             on_fwd = "on_fwd_strand"
@@ -515,7 +518,7 @@ class Segment(ConnectableElement, Group):
         for l,p in zip(self.locations, new_nt_positions):
             l.address = self.nt_pos_to_contour(p)
 
-    def remove_dna(self, first_nt: int, last_nt: int):
+    def remove_dna(self, first_nt:int, last_nt:int, remove_locations:bool = False):
         """ Removes nucleotides between first_nt and last_nt, inclusive """
         assert(np.isclose(np.around(first_nt),first_nt))
         assert(np.isclose(np.around(last_nt),last_nt))
@@ -530,28 +533,69 @@ class Segment(ConnectableElement, Group):
         if first_nt == last_nt:
             return
 
+        def _transform_contour_positions(contours):
+            nt = self.contour_to_nt_pos(contours)
+            first = first_nt if first_nt > 0 else -np.inf
+            last = last_nt if last_nt < self.num_nt-1 else np.inf
+            bad_ids = (nt >= first) * (nt <= last)
+            nt[nt>last] = nt[nt>last]-removed_nt
+            nt[bad_ids] = np.nan
+            return self.nt_pos_to_contour(nt)* self.num_nt/num_nt
+
         first_nt = np.around(first_nt)
         last_nt = np.around(last_nt)
 
-        nt_positions = self._get_location_positions()
-
-        bad_locations = list(filter(lambda p: p >= first_nt and p <= last_nt, nt_positions))
-        if len(bad_locations) > 0:
-            raise Exception("Attempted to remove DNA containing locations {} from {} between {} and {}".format(bad_locations,self,first_nt,last_nt))
-
         removed_nt = last_nt-first_nt+1
-        new_nt_positions = [p if p <= last_nt else p-removed_nt for p in nt_positions]
         num_nt = self.num_nt-removed_nt
 
+        c = np.array([l.address for l in self.locations])
+        new_c = _transform_contour_positions(c)
+
+        if np.any(np.isnan(new_c)) and not remove_locations:
+            raise Exception("Attempted to remove DNA containing locations from {} between {} and {}".format(self,first_nt,last_nt))
+
+        for l,v in zip(list(self.locations),new_c):
+            if np.isnan(v):
+                l.delete()
+            else:
+                l.address = v
+
+        assert( len(self.locations) == np.logical_not(np.isnan(new_c)).sum() )
+        tmp_nt = np.array([l.address*num_nt+0.5 for l in self.locations])
+        assert(np.all( (tmp_nt < 1) + (tmp_nt > num_nt-1) + np.isclose((np.round(tmp_nt) - tmp_nt), 0) ))
+
         if self.sequence is not None and len(self.sequence) == self.num_nt:
             self.sequence = [s for s,i in zip(self.sequence,range(self.num_nt)) 
                                 if i < first_nt or i > last_nt]
             assert( len(self.sequence) == num_nt )
 
+        ## Rescale splines
+        def _rescale_splines(u, get_coord_function, set_spline_function):
+            new_u = _transform_contour_positions(u)
+            ids = np.logical_not(np.isnan(new_u))
+            new_u = new_u[ids]
+            new_p = get_coord_function(u[ids])
+            set_spline_function( new_u, new_p )
+
+        _rescale_splines(self.position_spline_params[1],
+                         self.contour_to_position,
+                         self.set_splines)
+
+        if self.quaternion_spline_params is not None:
+            def _contour_to_quaternion(contours):
+                os = [self.contour_to_orientation(c) for c in contours]
+                qs = [quaternion_from_matrix( o ) for o in os]
+                return np.array(qs)
+            _rescale_splines(self.quaternion_spline_params[1],
+                             self.contour_to_quaternion,
+                             self.set_splines)
+
         self.num_nt = num_nt
 
-        for l,p in zip(self.locations, new_nt_positions):
-            l.address = self.nt_pos_to_contour(p)
+    def delete(self):
+        for c,loc,other in self.get_connections_and_locations():
+            c.delete()
+        self.parent.segments.remove(self)
 
     def __filter_contours(contours, positions, position_filter, contour_filter):
         u = contours
@@ -1763,9 +1807,12 @@ class SegmentModel(ArbdModel):
                 except:
                     pass
                 self.segments.append(newseg)
+                newseg.parent = self
             if include_strands:
                 for s in other.strands:
-                    self.strands.append(deepcopy(s))
+                    newstrand = deepcopy(s)
+                    self.strands.append(newstrand)
+                    newstrand.parent = self
         else:
             for s in other.segments:
                 try:
@@ -1773,9 +1820,11 @@ class SegmentModel(ArbdModel):
                 except:
                     pass
                 self.segments.append(s)
+                s.parent = self
             if include_strands:
                 for s in other.strands:
                     self.strands.append(s)
+                    s.parent = self
         self._clear_beads()
 
     def update(self, segment , copy=False):
-- 
GitLab