From 9c5c3cdfaf396af2c2a6416058d36b76d788025a Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Tue, 3 Jul 2018 12:43:11 -0500
Subject: [PATCH] Removed mutable default arguments

---
 arbdmodel.py    | 12 ++++++------
 segmentmodel.py | 30 ++++++++++++++++++------------
 2 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/arbdmodel.py b/arbdmodel.py
index 4405da7..0b0a9a1 100644
--- a/arbdmodel.py
+++ b/arbdmodel.py
@@ -57,11 +57,11 @@ class Transformable():
         return obj
 
 class Parent():
-    def __init__(self, children=[], remove_duplicate_bonded_terms=False):
+    def __init__(self, children=None, remove_duplicate_bonded_terms=False):
         self.children = []
-        for x in children:
-            self.add(x)
-        # self.children = children
+        if children is not None:
+            for x in children:
+                self.add(x)
         
         self.remove_duplicate_bonded_terms = remove_duplicate_bonded_terms
         self.bonds = []
@@ -384,7 +384,7 @@ class PointParticle(Transformable, Child):
 
 
 class Group(Transformable, Parent, Child):
-    def __init__(self, name=None, children = [], parent=None, 
+    def __init__(self, name=None, children = None, parent=None, 
                  position = np.array((0,0,0)),
                  orientation = np.array(((1,0,0),(0,1,0),(0,0,1))),
                  remove_duplicate_bonded_terms = False,
@@ -427,7 +427,7 @@ class Group(Transformable, Parent, Child):
         
 class PdbModel(Transformable, Parent):
 
-    def __init__(self, children=[], dimensions=None, remove_duplicate_bonded_terms=False):
+    def __init__(self, children=None, dimensions=None, remove_duplicate_bonded_terms=False):
         Transformable.__init__(self,(0,0,0))
         Parent.__init__(self, children, remove_duplicate_bonded_terms)
         self.dimensions = dimensions
diff --git a/segmentmodel.py b/segmentmodel.py
index 96657b2..a26496b 100644
--- a/segmentmodel.py
+++ b/segmentmodel.py
@@ -112,13 +112,15 @@ class Connection():
 # class ConnectableElement(Transformable):
 class ConnectableElement():
     """ Abstract base class """
-    ## TODO: eliminate mutable default arguments
-    def __init__(self, connection_locations=[], connections=[]):
+    def __init__(self, connection_locations=None, connections=None):
+        if connection_locations is None: connection_locations = []
+        if connections is None: connections = []
+
         ## TODO decide on names
         self.locations = self.connection_locations = connection_locations
         self.connections = connections
 
-    def get_locations(self, type_=None, exclude=[]):
+    def get_locations(self, type_=None, exclude=()):
         locs = [l for l in self.connection_locations if (type_ is None or l.type_ == type_) and l.type_ not in exclude]
         counter = dict()
         for l in locs:
@@ -149,7 +151,7 @@ class ConnectableElement():
             loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
         return loc
 
-    def get_connections_and_locations(self, connection_type=None, exclude=[]):
+    def get_connections_and_locations(self, connection_type=None, exclude=()):
         """ Returns a list with each entry of the form:
             connection, location_in_self, location_in_other """
         type_ = connection_type
@@ -255,8 +257,9 @@ class SegmentParticle(PointParticle):
 
             ## depth-first search
             ## TODO cache distances to nearby locations?
-            def descend_search_tree(seg, contour_in_seg, distance=0, visited_segs=[]):
+            def descend_search_tree(seg, contour_in_seg, distance=0, visited_segs=None):
                 nonlocal cutoff
+                if visited_segs is None: visited_segs = []
 
                 if seg == target_seg:
                     # pdb.set_trace()
@@ -330,10 +333,12 @@ class Segment(ConnectableElement, Group):
                               )
 
     def __init__(self, name, num_nts, 
-                 start_position = np.array((0,0,0)),
+                 start_position = None,
                  end_position = None, 
                  segment_model = None):
 
+        if start_position is None: start_position = np.array((0,0,0))
+
         Group.__init__(self, name, children=[])
         ConnectableElement.__init__(self, connection_locations=[], connections=[])
 
@@ -810,7 +815,7 @@ class DoubleStrandedSegment(Segment):
             end3 = end3.end3
         self._connect_ends( self.end5, end3, type_, force_connection = force_connection )
 
-    def add_crossover(self, nt, other, other_nt, strands_fwd=[True,False], nt_on_5prime=True):
+    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False), nt_on_5prime=True):
         """ Add a crossover between two helices """
         ## Validate other, nt, other_nt
         ##   TODO
@@ -885,10 +890,11 @@ class SingleStrandedSegment(Segment):
     """ Class that describes a segment of ssDNA. When built from
     cadnano models, should not span helices """
 
-    def __init__(self, name, num_nts, start_position = np.array((0,0,0)),
+    def __init__(self, name, num_nts, start_position = None,
                  end_position = None, 
                  segment_model = None):
 
+        if start_position is None: start_position = np.array((0,0,0))
         self.distance_per_nt = 5
         Segment.__init__(self, name, num_nts, 
                          start_position,
@@ -923,7 +929,7 @@ class SingleStrandedSegment(Segment):
             conn = Connection( other, my_end, type_="intrahelical" )
             other.container._connect( self, conn, in_3prime_direction=True )
 
-    def add_crossover(self, nt, other, other_nt, strands_fwd=[True,False], nt_on_5prime=True):
+    def add_crossover(self, nt, other, other_nt, strands_fwd=(True,False), nt_on_5prime=True):
         """ Add a crossover between two helices """
         ## Validate other, nt, other_nt
         ##   TODO
@@ -1176,7 +1182,7 @@ class SegmentModel(ArbdModel):
         self.useNonbondedScheme( nbDnaScheme )
 
 
-    def get_connections(self,type_=None,exclude=[]):
+    def get_connections(self,type_=None,exclude=()):
         """ Find all connections in model, without double-counting """
         added=set()
         ret=[]
@@ -1186,7 +1192,7 @@ class SegmentModel(ArbdModel):
             ret.extend( list(sorted(items,key=lambda x: x[1].address)) )
         return ret
     
-    def _recursively_get_beads_within_bonds(self,b1,bonds,done=[]):
+    def _recursively_get_beads_within_bonds(self,b1,bonds,done=()):
         ret = []
         done = list(done)
         done.append(b1)
@@ -1663,7 +1669,7 @@ class SegmentModel(ArbdModel):
         """ Add intrahelical exclusions """
         if self.DEBUG: print("Adding intrahelical exclusions")
         beads = dists.keys()
-        def _recursively_get_beads_within(b1,d,done=[]):
+        def _recursively_get_beads_within(b1,d,done=()):
             ret = []
             for b2,sep in dists[b1].items():
                 if b2 in done: continue
-- 
GitLab