From 2bd05afb75cc67c87f280fbf90aadc8f6644d0b5 Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Thu, 14 Jun 2018 14:37:58 -0500
Subject: [PATCH] Improved scripts for building cadnano structures; seems it
 will be necessary to make contour go from nt[0]-0.5 to nt[-1]+0.5

---
 cadnano_segments.py | 498 ++++++++++++++++++++++++++++++++++++++++----
 run.py              |  15 +-
 segmentmodel.py     | 237 ++++++++++++++-------
 3 files changed, 632 insertions(+), 118 deletions(-)

diff --git a/cadnano_segments.py b/cadnano_segments.py
index 88961bb..e9d0f57 100644
--- a/cadnano_segments.py
+++ b/cadnano_segments.py
@@ -29,6 +29,167 @@ m13seq = readSequenceFile(sequenceFile='resources/cadnano2pdb.seq.m13mp18.dat')
 ##   - helices that should be stacked across an empty region (crossovers from and end in the helix to another end in the helix)
 ##   - circular constructs
 
+def combineRegionLists(loHi1,loHi2,intersect=False):
+
+    """Combines two lists of (lo,hi) pairs specifying integer
+    regions into a single list of regions.  """
+
+    ## Validate input
+    for l in (loHi1,loHi2):
+        ## Assert each region in lists is sorted
+        for pair in l:
+            assert(len(pair) == 2)
+            assert(pair[0] <= pair[1])
+
+    if len(loHi1) == 0:
+        if intersect:
+            return []
+        else:
+            return loHi2
+    if len(loHi2) == 0:
+        if intersect:
+            return []
+        else:
+            return loHi1
+
+    ## Break input into lists of compact regions
+    compactRegions1,compactRegions2 = [[],[]]
+    for compactRegions,loHi in zip(
+            [compactRegions1,compactRegions2],
+            [loHi1,loHi2]):
+        tmp = []
+        lastHi = loHi[0][0]-1
+        for lo,hi in loHi:
+            if lo-1 != lastHi:
+                compactRegions.append(tmp)
+                tmp = []
+            tmp.append((lo,hi))
+            lastHi = hi
+        if len(tmp) > 0:
+            compactRegions.append(tmp)
+
+    ## Build result
+    result = []
+    region = []
+    i,j = [0,0]
+    compactRegions1.append([[1e10]])
+    compactRegions2.append([[1e10]])
+    while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
+        cr1 = compactRegions1[i]
+        cr2 = compactRegions2[j]
+
+        ## initialize region
+        if len(region) == 0:
+            if cr1[0][0] <= cr2[0][0]:
+                region = cr1
+                i += 1
+                continue
+            else:
+                region = cr2
+                j += 1
+                continue
+
+        if region[-1][-1] >= cr1[0][0]:
+            region = combineCompactRegionLists(region, cr1, intersect=False)
+            i+=1
+        elif region[-1][-1] >= cr2[0][0]:
+            region = combineCompactRegionLists(region, cr2, intersect=False)
+            j+=1
+        else:
+            result.extend(region)
+            region = []
+
+    assert( len(region) > 0 )
+    result.extend(region)
+    result = sorted(result)
+
+    # print("loHi1:",loHi1)
+    # print("loHi2:",loHi2)
+    # print(result,"\n")
+
+    if intersect:
+        lo = max( [loHi1[0][0], loHi2[0][0]] )
+        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
+        result = [r for r in result if r[0] >= lo and r[1] <= hi]
+
+    return result
+
+def combineCompactRegionLists(loHi1,loHi2,intersect=False):
+
+    """Combines two lists of (lo,hi) pairs, each tiling a gap-free
+    ("compact") run of integers, into a single list of regions that is
+    split wherever either input list has a region boundary.
+
+    loHi1, loHi2: lists of inclusive (lo,hi) integer pairs in ascending
+        order; within a list each region must start right after the
+        previous one ends.
+    intersect: if True, only regions lying inside the span shared by
+        both inputs are returned; otherwise the union span is covered.
+
+    Returns a sorted list of (lo,hi) tuples.  If one input is empty the
+    other is returned unchanged (or [] when intersect is True).
+
+    Raises AssertionError if a pair is malformed or a list is not compact.
+
+    examples:
+    loHi1 = [[0,4],[5,7]]
+    loHi2 = [[2,4],[5,9]]
+    out = [(0, 1), (2, 4), (5, 7), (8, 9)]
+
+    loHi1 = [[0,3],[4,7]]
+    loHi2 = [[2,4],[5,9]]
+    out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
+    """
+
+    ## Validate input
+    for l in (loHi1,loHi2):
+        ## Assert each region in lists is sorted
+        for pair in l:
+            assert(len(pair) == 2)
+            assert(pair[0] <= pair[1])
+        ## Assert lists are compact (every adjacent pair of regions)
+        for pair1,pair2 in zip(l[:-1],l[1:]):
+            assert(pair1[1]+1 == pair2[0])
+
+    ## Nothing to combine when either list is empty
+    if len(loHi1) == 0:
+        if intersect:
+            return []
+        else:
+            return loHi2
+    if len(loHi2) == 0:
+        if intersect:
+            return []
+        else:
+            return loHi1
+
+    ## Find the ends of the region
+    lo = min( [loHi1[0][0], loHi2[0][0]] )
+    hi = max( [loHi1[-1][1], loHi2[-1][1]] )
+
+    ## Collect the indices after which the combined span is split:
+    ## before every region start and after every region end, except at
+    ## the outer ends of the span
+    splitAfter = set()
+    for l,h in list(loHi2)+list(loHi1):
+        if l != lo:
+            splitAfter.add(l-1)
+        if h != hi:
+            splitAfter.add(h)
+    split = sorted(splitAfter)
+
+    ## Consecutive split points delimit the output regions; each region
+    ## starts one past the previous split point
+    returnList = [(i+1,j) if i != j else (i,j) for i,j in zip([lo-1]+split,split+[hi])]
+
+    if intersect:
+        ## Keep only regions inside the span covered by both inputs
+        lo = max( [loHi1[0][0], loHi2[0][0]] )
+        hi = min( [loHi1[-1][1], loHi2[-1][1]] )
+        returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]
+
+    return returnList
+
 class cadnano_part(SegmentModel):
     def __init__(self, part, 
                  max_basepairs_per_bead = 7, 
@@ -92,6 +253,7 @@ class cadnano_part(SegmentModel):
 
         ## Get dictionary of insertions 
         self.insertions = allInsertions = part.insertions()
+        self.strand_occupancies = dict()
 
         ## Build helices 
         for hid in range(numHID):
@@ -116,14 +278,16 @@ class cadnano_part(SegmentModel):
             strandOccupancies = [ [x for i in range(0,len(e),2) 
                                    for x in range(e[i],e[i+1]+1)] 
                                   for e in (ends1,ends2) ]
+            self.strand_occupancies[hid] = strandOccupancies
 
-            for i in range( len(reqNodeZids)-1 ):
-                zid1,zid2 = reqNodeZids[i:i+2]
+            ends1,ends2 = [ [(e[i],e[i+1]) for i in range(0,len(e),2)] for e in (ends1,ends2) ]
 
-                ## Check that there are nts between zid1 and zid2 before adding nodes
+            regions = combineRegionLists(ends1,ends2)
+            if hid == 43:
+                import pdb
+            for zid1,zid2 in regions:
                 zMid = int(0.5*(zid1+zid2))
-                if zMid not in strandOccupancies[0] and zMid not in strandOccupancies[1]:
-                    continue
+                assert( zMid in strandOccupancies[0] or zMid in strandOccupancies[1] )
 
                 numBps = zid2-zid1+1
                 for ins_idx,length in insertions:
@@ -131,28 +295,33 @@ class cadnano_part(SegmentModel):
                     ##   (e.g. are insertions at the ends handled correctly?)
                     if ins_idx < zid1:
                         continue
-                    if ins_idx >= zid2:
+                    if ins_idx > zid2:
                         break
                     numBps += length
 
                 print("Adding helix with length",numBps,zid1,zid2)
 
                 kwargs = dict(name="%d-%d" % (hid,len(helices[hid])),
-                              num_nts = numBps,
-                              start_position = self._get_cadnano_position(hid,zid1),
-                              end_position   = self._get_cadnano_position(hid,zid2) )
+                              num_nts = numBps)
+
+                posargs1 = dict( start_position = self._get_cadnano_position(hid,zid1),
+                                 end_position   = self._get_cadnano_position(hid,zid2) )
+                posargs2 = dict( start_position = posargs1['end_position'],
+                                 end_position = posargs1['start_position'])
                 
                 ## TODO get sequence from cadnano api
                 if zMid in strandOccupancies[0] and zMid in strandOccupancies[1]:
-                    seg = DoubleStrandedSegment(**kwargs)
-                elif zMid in strandOccupancies[0] or zMid in strandOccupancies[1]:
-                    seg = SingleStrandedSegment(**kwargs)
-
+                    seg = DoubleStrandedSegment(**kwargs,**posargs1)
+                elif zMid in strandOccupancies[0]:
+                    seg = SingleStrandedSegment(**kwargs,**posargs1)
+                elif zMid in strandOccupancies[1]:
+                    seg = SingleStrandedSegment(**kwargs,**posargs2)
+                else:
+                    raise Exception("Segment could not be found")
                                      
                 helices[hid].append( seg )
                 helix_ranges[hid].append( (zid1,zid2) )
 
-
     def _get_cadnano_position(self, hid, zid):
         return [10*a for a in self.origins[hid]] + [-3.4*zid]
 
@@ -175,6 +344,191 @@ class cadnano_part(SegmentModel):
                 endList.append(lastStrand[1])
         return endLists
 
+    def combineRegionLists(self,loHi1,loHi2,intersect=False):
+
+        """Combines two lists of (lo,hi) pairs specifying integer
+        regions into a single list of regions.  """
+
+        ## Validate input
+        for l in (loHi1,loHi2):
+            ## Assert each region in lists is sorted
+            for pair in l:
+                assert(len(pair) == 2)
+                assert(pair[0] <= pair[1])
+
+        if len(loHi1) == 0:
+            if intersect:
+                return []
+            else:
+                return loHi2
+        if len(loHi2) == 0:
+            if intersect:
+                return []
+            else:
+                return loHi1
+
+        ## Break input into lists of compact regions
+        compactRegions1,compactRegions2 = [[],[]]
+        for compactRegions,loHi in zip(
+                [compactRegions1,compactRegions2],
+                [loHi1,loHi2]):
+            tmp = []
+            lastHi = loHi[0][0]-1
+            for lo,hi in loHi:
+                if lo-1 != lastHi:
+                    compactRegions.append(tmp)
+                    tmp = []
+                tmp.append((lo,hi))
+                lastHi = hi
+            if len(tmp) > 0:
+                compactRegions.append(tmp)
+
+        ## Build result
+        result = []
+        region = []
+        i,j = [0,0]
+        compactRegions1.append([[1e10]])
+        compactRegions2.append([[1e10]])
+        while i < len(compactRegions1)-1 or j < len(compactRegions2)-1:
+            cr1 = compactRegions1[i]
+            cr2 = compactRegions2[j]
+
+            ## initialize region
+            if len(region) == 0:
+                if cr1[0][0] <= cr2[0][0]:
+                    region = cr1
+                    i += 1
+                    continue
+                else:
+                    region = cr2
+                    j += 1
+                    continue
+
+            if region[-1][-1] >= cr1[0][0]:
+                region = self.combineCompactRegionLists(region, cr1, intersect=False)
+                i+=1
+            elif region[-1][-1] >= cr2[0][0]:
+                region = self.combineCompactRegionLists(region, cr2, intersect=False)
+                j+=1
+            else:
+                result.extend(region)
+                region = []
+
+        assert( len(region) > 0 )
+        result.extend(region)
+        result = sorted(result)
+
+        # print("loHi1:",loHi1)
+        # print("loHi2:",loHi2)
+        # print(result,"\n")
+
+        if intersect:
+            lo = max( [loHi1[0][0], loHi2[0][0]] )
+            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
+            result = [r for r in result if r[0] >= lo and r[1] <= hi]
+
+        return result
+
+    def combineCompactRegionLists(self,loHi1,loHi2,intersect=False):
+        """Combines two lists of (lo,hi) pairs, each tiling a gap-free
+        ("compact") run of integers, into a single list of regions that is
+        split wherever either input list has a region boundary.
+
+        loHi1, loHi2: lists of inclusive (lo,hi) integer pairs in ascending
+            order; within a list each region must start right after the
+            previous one ends.
+        intersect: if True, only regions lying inside the span shared by
+            both inputs are returned; otherwise the union span is covered.
+
+        Returns a sorted list of (lo,hi) tuples.  If one input is empty the
+        other is returned unchanged (or [] when intersect is True).
+
+        Raises AssertionError if a pair is malformed or a list is not compact.
+
+        examples:
+        loHi1 = [[0,4],[5,7]]
+        loHi2 = [[2,4],[5,9]]
+        out = [(0, 1), (2, 4), (5, 7), (8, 9)]
+
+        loHi1 = [[0,3],[4,7]]
+        loHi2 = [[2,4],[5,9]]
+        out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
+        """
+
+        ## Validate input
+        for l in (loHi1,loHi2):
+            ## Assert each region in lists is sorted
+            for pair in l:
+                assert(len(pair) == 2)
+                assert(pair[0] <= pair[1])
+            ## Assert lists are compact (every adjacent pair of regions)
+            for pair1,pair2 in zip(l[:-1],l[1:]):
+                assert(pair1[1]+1 == pair2[0])
+
+        ## Nothing to combine when either list is empty
+        if len(loHi1) == 0:
+            if intersect:
+                return []
+            else:
+                return loHi2
+        if len(loHi2) == 0:
+            if intersect:
+                return []
+            else:
+                return loHi1
+
+        ## Find the ends of the region
+        lo = min( [loHi1[0][0], loHi2[0][0]] )
+        hi = max( [loHi1[-1][1], loHi2[-1][1]] )
+
+        ## Collect the indices after which the combined span is split:
+        ## before every region start and after every region end, except at
+        ## the outer ends of the span
+        splitAfter = set()
+        for l,h in list(loHi2)+list(loHi1):
+            if l != lo:
+                splitAfter.add(l-1)
+            if h != hi:
+                splitAfter.add(h)
+        split = sorted(splitAfter)
+
+        ## Consecutive split points delimit the output regions; each region
+        ## starts one past the previous split point
+        returnList = [(i+1,j) if i != j else (i,j) for i,j in zip([lo-1]+split,split+[hi])]
+
+        if intersect:
+            ## Keep only regions inside the span covered by both inputs
+            lo = max( [loHi1[0][0], loHi2[0][0]] )
+            hi = min( [loHi1[-1][1], loHi2[-1][1]] )
+            returnList = [r for r in returnList if r[0] >= lo and r[1] <= hi]
+
+        return returnList
+
+
+    def _helix_strands_to_segment_ranges(self, helix_strands):
+        """Utility method to convert cadnano strand lists into list of
+        indices of terminal points (unfinished: raises NotImplementedError
+        when both strand lists are non-empty, returns None otherwise)"""
+        def _join(strands):
+            ## Merge strands that abut (end+1 == next start) into [first,last]
+            ends = []
+            lastEnd = None
+            for start,end in strands:
+                if lastEnd is None:
+                    ends.append([start])
+                elif lastEnd != start-1:
+                    ends[-1].append(lastEnd)
+                    ends.append([start])
+                lastEnd = end
+            if lastEnd is not None:
+                ends[-1].append(lastEnd)
+            return ends
+
+        s1,s2 = [_join(s) for s in helix_strands]
+        ## TODO: merge s1 and s2 into segment ranges; fail loudly until then
+        if len(s1) > 0 and len(s2) > 0:
+            raise NotImplementedError("merging of strand ranges is not implemented")
+
     def _get_segment(self, hid, zid):
         ## TODO: rename these variables to segments
         segs = self.helices[hid]
@@ -186,6 +540,7 @@ class cadnano_part(SegmentModel):
         raise Exception("Could not find segment in helix %d at position %d" % (hid,zid))
                 
     def _get_nucleotide(self, hid, zid):
+        raise Exception("Deprecated")
         seg = self._get_segment(hid,zid)
         sid = self.helices[hid].index(seg)
         zmin,zmax = self.helix_ranges[hid][sid]
@@ -199,42 +554,80 @@ class cadnano_part(SegmentModel):
                 nt += self.insertions[hid][i].length()
         return nt
 
+    def _get_segment_nucleotide(self, hid, zid):
+        """ returns segments and zero-based nucleotide index """
+        seg = self._get_segment(hid,zid)
+        sid = self.helices[hid].index(seg)
+        zmin,zmax = self.helix_ranges[hid][sid]
+
+        zMid = int(0.5*(zmin+zmax))
+        occ = self.strand_occupancies[hid]
+
+        ## TODO combine if/else when nested TODO is resolved
+        if (zMid not in occ[0]) and (zMid in occ[1]):
+            ## reversed ssDNA strand
+            nt = zmax-zid
+            # TODO: for i in range(zmin,zid+1): ?
+            for i in range(zid,zmax):
+                if i in self.insertions[hid]:
+                    nt += self.insertions[hid][i].length()
+        else:
+            ## normal condition
+            nt = zid-zmin
+            # TODO: for i in range(zmin,zid+1): ?
+            for i in range(zmin,zid):
+                if i in self.insertions[hid]:
+                    nt += self.insertions[hid][i].length()
+
+        ## Find insertions
+        return seg, nt
+
+        
+
     """ Routines to add connnections between helices """
     def _add_intrahelical_connections(self):
         for hid,segs in self.helices.items():
+            occ = self.strand_occupancies[hid]
             for i in range(len(segs)-1):
-                
-                seg1,seg2 = [segs[j] for j in (i,i+1)]
-                r1,r2 = [self.helix_ranges[hid][j] for j in (i,i+1)]
-                if r1[1]+1 == r2[0]:
-                    ## TODO: connect correct ends 
-                    seg1.connect_end3(seg2.start5)
-
+                    seg1,seg2 = [segs[j] for j in (i,i+1)]
+                    r1,r2 = [self.helix_ranges[hid][j] for j in (i,i+1)]
+                    if r1[1]+1 == r2[0]:
+                        ## TODO: handle nicks that are at intrahelical connections(?)
+                        zmid1 = int(0.5*(r1[0]+r1[1]))
+                        zmid2 = int(0.5*(r2[0]+r2[1]))
+                        ## TODO: validate
+                        if zmid1 in occ[0] and zmid2 in occ[0]:
+                            seg1.connect_end3(seg2.start5)
+
+                        if zmid1 in occ[1] and zmid2 in occ[1]:
+                            if zmid1 in occ[0]:
+                                seg2.connect_end3(seg1.end5)
+                            else:
+                                seg2.connect_end3(seg1.start5)
+                        
     def _add_crossovers(self):
         for h1,f1,z1,h2,f2,z2 in self.xover_list:
-            seg1 = self._get_segment(h1,z1)
-            seg2 = self._get_segment(h2,z2)
-            nt1 = self._get_nucleotide(h1,z1)
-            nt2 = self._get_nucleotide(h2,z2)
+            if (h1 == 52 or h2 == 52) and z1 == 221:
+                import pdb
+                pdb.set_trace()
+            seg1, nt1 = self._get_segment_nucleotide(h1,z1)
+            seg2, nt2 = self._get_segment_nucleotide(h2,z2)
             ## TODO: use different types of crossovers
             seg1.add_crossover(nt1,seg2,nt2,[f1,f2])
 
     def _add_prime_ends(self):
         for h,fwd,z in self._5prime_list:
-            seg = self._get_segment(h,z)
-            nt = self._get_nucleotide(h,z)
+            seg, nt = self._get_segment_nucleotide(h,z)
             print(seg.name,nt,fwd)
             seg.add_5prime(nt,fwd)
 
         for h,fwd,z in self._3prime_list:
-            seg = self._get_segment(h,z)
-            nt = self._get_nucleotide(h,z)
+            seg, nt = self._get_segment_nucleotide(h,z)
             seg.add_3prime(nt,fwd) 
    
     def get_bead(self, hid, zid):
-        # get segment, get nucleotide,
-        seg = self._get_segment(hid,zid)
-        nt = self._get_nucleotide(hid,zid)
+        """Return seg.get_nearest_bead(...) for the nucleotide at helix hid, index zid."""
+        seg, nt = self._get_segment_nucleotide(hid,zid)
         # return seg.get_nearest_bead(seg,nt / seg.num_nts)
         return seg.get_nearest_bead(seg,nt / (seg.num_nts-1))
         
@@ -285,8 +678,8 @@ def run_simulation_protocol( output_name, job_id, json_data,
                              directory=None
                          ):
 
-    coarseSteps = 1e7
-    fineSteps = 1e6
+    coarseSteps = 1e5+1
+    fineSteps = 1e5+1
 
     ret = None
     d_orig = os.getcwd()
@@ -319,10 +712,12 @@ def run_simulation_protocol( output_name, job_id, json_data,
         if sequence is None or len(sequence) == 0:
             ## default m13mp18
             sequence = list(m13seq)
-        # print(sequence)
-        # for i in sequence:
-        #     print( i, i in ('A','T','C','G') )
-        model.strands[0].set_sequence(sequence)
+            try:
+                model.strands[0].set_sequence(sequence)
+            except:
+                ...
+        else:
+            model.strands[0].set_sequence(sequence)
 
         for s in model.segments:
             s.randomize_unset_sequence()
@@ -396,10 +791,27 @@ def run_simulation_protocol( output_name, job_id, json_data,
 # gpus = [0,1,2]
 # print(gpus)
 
-
 if __name__ == '__main__':
-    for f in glob('json/*'):
-        print("Working on {}".format(f))
-        out = re.match('json/(.*).json',f).group(1)
-        data = read_json_file(f)
-        run_simulation_protocol( out, "job-id", data, gpu=0 )
+    loHi1 = [[0,4],[5,7]]
+    loHi2 = [[2,4],[5,9]]
+    out = [(0, 1), (2, 4), (5, 7), (8, 9)]
+    print(loHi1)
+    print(loHi2)
+    print(combineRegionLists(loHi1,loHi2))
+    print(combineCompactRegionLists(loHi1,loHi2))
+
+    loHi1 = [[0,3],[5,7]]
+    loHi2 = [[2,4],[5,9]]
+    out = [(0, 1), (2, 3), (4, 4), (5, 7), (8, 9)]
+    print(loHi1)
+    print(loHi2)
+    print(combineRegionLists(loHi1,loHi2))
+    ## loHi1 has a gap (3 -> 5), so the compact variant gets a gap-free list
+    print(combineCompactRegionLists([[0,3],[4,7]],loHi2))
+
+    
+    # for f in glob('json/*'):
+    #     print("Working on {}".format(f))
+    #     out = re.match('json/(.*).json',f).group(1)
+    #     data = read_json_file(f)
+    #     run_simulation_protocol( out, "job-id", data, gpu=0 )
diff --git a/run.py b/run.py
index 851dccd..e27f4cd 100644
--- a/run.py
+++ b/run.py
@@ -1,12 +1,21 @@
 # -*- coding: utf-8 -*-
+import sys
 from glob import glob
 import re
 from cadnano_segments import read_json_file, run_simulation_protocol
 
 if __name__ == '__main__':
-
-    for f in glob('json/*'):
+    if len(sys.argv) > 1:
+        json_files = sys.argv[1:]
+    else:
+        json_files = glob('json/*')
+    for f in json_files:
         print("Working on {}".format(f))
-        out = re.match('json/(.*).json',f).group(1)
-        data = read_json_file(f)
+        ## Name the output after the file itself so paths outside json/ work
+        out = re.sub(r'\.json$', '', f.split('/')[-1])
+        try:
+            data = read_json_file(f)
+        except Exception:
+            print("WARNING: skipping unreadable json file {}".format(f))
+            continue
         run_simulation_protocol( out, "job-id", data, gpu=0 )
diff --git a/segmentmodel.py b/segmentmodel.py
index b961485..58a2767 100644
--- a/segmentmodel.py
+++ b/segmentmodel.py
@@ -47,6 +47,7 @@ class Location():
         self.type_ = type_
         self.particle = None
         self.connection = None
+        self.is_3prime_side_of_connection = None
 
         self.prev_in_strand = None
         self.next_in_strand = None
@@ -59,8 +60,9 @@ class Location():
         else:
             return self.connection.other(self)
 
-    def set_connection(self,connection):
+    def set_connection(self, connection, is_3prime_side_of_connection):
         self.connection = connection # TODO weakref? 
+        self.is_3prime_side_of_connection = is_3prime_side_of_connection
 
     def __repr__(self):
         if self.on_fwd_strand:
@@ -106,7 +108,27 @@ class ConnectableElement():
             else:
                 counter[l] = 1
         assert( np.all( [counter[l] == 1 for l in locs] ) )
-        return locs                
+        return locs
+
+    def get_location_at(self, address, on_fwd_strand=True, new_type="crossover"):
+        loc = None
+        if (self.num_nts == 1):
+            # import pdb
+            # pdb.set_trace()
+            ## Assumes that intrahelical connections have been made before crossovers
+            for l in self.locations:
+                if l.on_fwd_strand == on_fwd_strand and l.connection is None:
+                    assert(loc is None)
+                    loc = l
+            assert( loc is not None )
+        else:
+            for l in self.locations:
+                if l.address == address and l.on_fwd_strand == on_fwd_strand:
+                    assert(loc is None)
+                    loc = l
+        if loc is None:
+            loc = Location( self, address=address, type_=new_type, on_fwd_strand=on_fwd_strand )
+        return loc
 
     def get_connections_and_locations(self, connection_type=None, exclude=[]):
         """ Returns a list with each entry of the form:
@@ -125,9 +147,13 @@ class ConnectableElement():
                     raise Exception("Object contains connection that fails to refer to object")
         return ret
 
-    def _connect(self, other, connection):
+    def _connect(self, other, connection, in_3prime_direction=None):
         ## TODO fix circular references        
         A,B = [connection.A, connection.B]
+        if in_3prime_direction is not None:
+            A.is_3prime_side_of_connection = not in_3prime_direction
+            B.is_3prime_side_of_connection = in_3prime_direction
+            
         A.connection = B.connection = connection
         self.connections.append(connection)
         other.connections.append(connection)
@@ -296,10 +322,19 @@ class Segment(ConnectableElement, Group):
         for c,loc,other in self.get_connections_and_locations():
             loc.particle = None
 
-    def contour_to_nt_pos(self,contour_pos):
-        return contour_pos*(self.num_nts-1)
+    def contour_to_nt_pos(self, contour_pos, round_nt=False):
+        nt = contour_pos*(self.num_nts-1)
+        if round_nt:
+            assert( (np.around(nt) - nt)**2 < 1e-3 )
+            nt = np.around(nt)
+        return nt
+
     def nt_pos_to_contour(self,nt_pos):
-        return nt_pos/(self.num_nts-1)
+        if self.num_nts == 1:
+            assert(nt_pos == 0)
+            return 0
+        else:
+            return nt_pos/(self.num_nts-1)
 
     def contour_to_position(self,s):
         p = interpolate.splev( s, self.position_spline_params )
@@ -439,25 +474,36 @@ class Segment(ConnectableElement, Group):
             yield c
 
     ## TODO rename
-    def get_strand_segment(self, nt_pos, is_fwd):
+    def get_strand_segment(self, nt_pos, is_fwd, move_at_least=0.5):
         """ Walks through locations, checking for crossovers """
+        # if self.name in ("6-1","1-1"):
+        #     import pdb
+        #     pdb.set_trace()
+        move_at_least = 0
 
         ## Iterate through locations
         locations = sorted(self.locations, key=lambda l:(l.address,not l.on_fwd_strand), reverse=(not is_fwd))
+        # print(locations)
+
         for l in locations:
-            pos = self.contour_to_nt_pos(l.address)
+            pos = self.contour_to_nt_pos(l.address, round_nt=True)
 
             ## DEBUG
 
             ## Skip locations encountered before our strand
-            tol = 0.1
-            if is_fwd:
-                if pos-nt_pos <= tol: continue 
-            elif pos-nt_pos >= -tol: continue
+            # tol = 0.1
+            # if is_fwd:
+            #     if pos-nt_pos <= tol: continue 
+            # elif   nt_pos-pos <= tol: continue
+            if (pos-nt_pos)*(2*is_fwd-1) < move_at_least: continue
+            ## TODO: remove move_at_least
+            if np.isclose(pos,nt_pos):
+                if l.is_3prime_side_of_connection: continue
 
             ## Stop if we found the 3prime end
             if l.on_fwd_strand == is_fwd and l.type_ == "3prime":
-                return pos, None, None, None
+                print("  found end at",l)
+                return pos, None, None, None, None
 
             ## Check location connections
             c = l.connection
@@ -466,16 +512,19 @@ class Segment(ConnectableElement, Group):
 
             ## Found a location on the same strand?
             if l.on_fwd_strand == is_fwd:
-                # print("  passing through",l)
-                # print("from {}, connection {} to {}".format(contour_pos,l,B))
-                Bpos = B.container.contour_to_nt_pos(B.address)
-                return pos, B.container, Bpos, B.on_fwd_strand
+                print("  passing through",l)
+                print("from {}, connection {} to {}".format(nt_pos,l,B))
+                Bpos = B.container.contour_to_nt_pos(B.address, round_nt=True)
+                return pos, B.container, Bpos, B.on_fwd_strand, 0.5
                 
             ## Stop at other strand crossovers so basepairs line up
             elif c.type_ == "crossover":
-                # print("  pausing at",l)
-                return pos, l.container, pos+(2*is_fwd-1), is_fwd
+                if nt_pos == pos: continue
+                print("  pausing at",l)
+                return pos, l.container, pos+(2*is_fwd-1), is_fwd, 0
 
+        import pdb
+        pdb.set_trace()
         raise Exception("Shouldn't be here")
         # print("Shouldn't be here")
         ## Made it to the end of the segment without finding a connection
@@ -584,10 +633,13 @@ class Segment(ConnectableElement, Group):
 
         if True:
             print("WARNING: DEBUG")
+            ## Remove duplicates, preserving order
             tmp = []
             for c in new_children:
                 if c not in tmp:
                     tmp.append(c)
+                else:
+                    print("  duplicate particle found!")
             new_children = tmp
 
         for b in new_children:
@@ -606,7 +658,7 @@ class Segment(ConnectableElement, Group):
 
     def _generate_beads(self, bead_model, max_basepairs_per_bead, max_nucleotides_per_bead):
 
-        """ Generate beads (positions, types, etcl) and bonds, angles, dihedrals, exclusions """
+        """ Generate beads (positions, types, etc) and bonds, angles, dihedrals, exclusions """
         ## TODO: decide whether to remove bead_model argument
         ##       (currently unused)
 
@@ -616,7 +668,6 @@ class Segment(ConnectableElement, Group):
         # existing_beads = [l.particle for l in locs if l.particle is not None]
         existing_beads = {l.particle for l in self.locations if l.particle is not None}
         existing_beads = sorted( list(existing_beads), key=lambda b: b.get_contour_position(self) )
-
         
         if len(existing_beads) != len(set(existing_beads)):
             pdb.set_trace()
@@ -627,13 +678,13 @@ class Segment(ConnectableElement, Group):
         ## TODOTODO: test 1 nt segments?
         if len(existing_beads) == 0 or existing_beads[0].get_contour_position(self) > 0:
             if len(existing_beads) > 0:            
-                assert(existing_beads[0].get_nt_position(self) > 1.5)
+                assert(existing_beads[0].get_nt_position(self) >= 0.5)
 
             b = self._generate_one_bead(0, 0)
             existing_beads = [b] + existing_beads
         if existing_beads[-1].get_contour_position(self) < 1:
-            # assert((1-existing_beads[0].get_contour_position(self))*(self.num_nts-1) > 1.5)
-            assert(self.num_nts-1-existing_beads[0].get_nt_position(self) > 1.5)
+            # assert((1-existing_beads[0].get_contour_position(self))*(self.num_nts-1) >= 0.5)
+            assert(self.num_nts-1-existing_beads[0].get_nt_position(self) >= 0.5)
             b = self._generate_one_bead(1, 0)
             existing_beads.append(b)
         assert(len(existing_beads) > 1)
@@ -764,26 +815,23 @@ class DoubleStrandedSegment(Segment):
         ## Validate other, nt, other_nt
         ##   TODO
 
-        ## Create locations, connections and add to segments
-        c = self.nt_pos_to_contour(nt)
-        assert(c >= 0 and c <= 1)
-        
-        def get_loc(seg, address, on_fwd_strand):
-            loc = None
-            for l in seg.locations:
-                if l.address == address and l.on_fwd_strand == on_fwd_strand:
-                    assert(loc is None)
-                    loc = l
-            if loc is None:
-                loc = Location( seg, address=address, type_="crossover", on_fwd_strand=on_fwd_strand )
-            return loc
+        if isinstance(other,SingleStrandedSegment):
+            other.add_crossover(other_nt, self, nt, strands_fwd[::-1])
+        else:
 
-        loc = get_loc(self, c, strands_fwd[0])
+            ## Create locations, connections and add to segments
+            c = self.nt_pos_to_contour(nt)
+            assert(c >= 0 and c <= 1)
 
-        c = other.nt_pos_to_contour(other_nt)
-        assert(c >= 0 and c <= 1)
-        other_loc = get_loc( other, c, strands_fwd[1] )
-        self._connect(other, Connection( loc, other_loc, type_="crossover" ))
+            loc = self.get_location_at(c, strands_fwd[0])
+
+            c = other.nt_pos_to_contour(other_nt)
+            assert(c >= 0 and c <= 1)
+            other_loc = other.get_location_at(c, strands_fwd[1])
+            self._connect(other, Connection( loc, other_loc, type_="crossover" ))
+            loc.is_3prime_side_of_connection = not strands_fwd[0]
+            other_loc.is_3prime_side_of_connection = not strands_fwd[1]
+            
 
     ## Real work
     def _connect_ends(self, end1, end2, type_, force_connection):
@@ -794,8 +842,10 @@ class DoubleStrandedSegment(Segment):
             assert( end.type_ in ("end3","end5") )
         assert( end1.type_ != end2.type_ )
         ## Create and add connection
-        end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ) )
-
+        if end2.type_ == "end3":
+            end1.container._connect( end2.container, Connection( end1, end2, type_=type_ ), in_3prime_direction=True )
+        else:
+            end2.container._connect( end1.container, Connection( end2, end1, type_=type_ ), in_3prime_direction=True )
     def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
         return int(contour*self.num_nts // max_basepairs_per_bead)
 
@@ -820,7 +870,6 @@ class DoubleStrandedSegment(Segment):
                                     contour_position=contour_position )
         self._add_bead(bead)
         return bead
-        
 
 class SingleStrandedSegment(Segment):
 
@@ -842,10 +891,10 @@ class SingleStrandedSegment(Segment):
         for l in (self.start5,self.end3):
             self.locations.append(l)
 
-    def connect_3end(self, end5, force_connection=False):
+    def connect_end3(self, end5, force_connection=False):
         self._connect_end( end5,  _5_to_3 = False, force_connection = force_connection )
 
-    def connect_5end(self, end3, force_connection=False):
+    def connect_5end(self, end3, force_connection=False): # TODO: change name or possibly deprecate
         self._connect_end( end3,  _5_to_3 = True, force_connection = force_connection )
 
     def _connect_end(self, other, _5_to_3, force_connection):
@@ -853,11 +902,47 @@ class SingleStrandedSegment(Segment):
         if _5_to_3 == True:
             my_end = self.end5
             assert( other.type_ == "end3" )
+            conn = Connection( my_end, other, type_="intrahelical" )
+            self._connect( other.container, conn, in_3prime_direction=True )
         else:
             my_end = self.end3
             assert( other.type_ == "end5" )
+            conn = Connection( other, my_end, type_="intrahelical" )
+            other.container._connect( self, conn, in_3prime_direction=True )
 
-        self._connect( other.container, Connection( my_end, other, type_="intrahelical" ) )
+    def add_crossover(self, nt, other, other_nt, strands_fwd=[True,False]):
+        """ Connect an end of this single-stranded segment to `other` at the given nucleotides.
+
+        nt/other_nt are nucleotide positions on self/other; nt must map to contour 0 or 1 (asserted).
+        Two single strands are joined end-to-end through connect_end3; otherwise an "sscrossover"
+        Connection is made with the end3-side location passed first. Only strands_fwd[1] is read
+        (strand of a non-single-stranded `other`). TODO: validate other, nt, other_nt; fix direction. """
+        c1 = self.nt_pos_to_contour(nt)
+        ## Ensure connections occur at ends, otherwise the structure doesn't make sense
+        assert(np.isclose(c1,0) or np.isclose(c1,1))
+        loc = self.get_location_at(c1, True)
+
+        c2 = other.nt_pos_to_contour(other_nt)
+        if isinstance(other,SingleStrandedSegment):
+            ## Ensure connections occur at opposing ends
+            assert(np.isclose(c2,0) or np.isclose(c2,1))
+            other_loc = other.get_location_at( c2, True )
+            assert( loc.type_ in ("end3","end5"))
+            assert( other_loc.type_ in ("end3","end5"))
+            if loc.type_ == "end3":
+                self.connect_end3( other_loc )
+            else:
+                assert( other_loc.type_ == "end3" )
+                other.connect_end3( loc ) ## pass our end5 Location (connect_end3 checks .type_), not the segment
+
+        else:
+            assert( loc.type_ in ("end3","end5"))
+            assert(c2 >= 0 and c2 <= 1)
+            other_loc = other.get_location_at( c2, strands_fwd[1] )
+            if loc.type_ == "end3":
+                self._connect(other, Connection( loc, other_loc, type_="sscrossover" ), in_3prime_direction=True )
+            else:
+                other._connect(self, Connection( other_loc, loc, type_="sscrossover" ), in_3prime_direction=True )
 
     def _get_num_beads(self, contour, max_basepairs_per_bead, max_nucleotides_per_bead):
         return int(contour*self.num_nts // max_nucleotides_per_bead)
@@ -925,11 +1010,17 @@ class Strand(Group):
     def add_dna(self, segment, start, end, is_fwd):
         """ start/end should be provided expressed as contour_length, is_fwd tuples """
         if not (segment.contour_to_nt_pos(np.abs(start-end)) > 0.9):
+            import pdb
             pdb.set_trace()
         for s in self.strand_segments:
             if s.segment == segment and s.is_fwd == is_fwd:
-                assert( s.start not in (start,end) )
-                assert( s.end not in (start,end) )
+                # assert( s.start not in (start,end) )
+                # assert( s.end not in (start,end) )
+                if s.start in (start,end) or s.end in (start,end):
+                    import pdb
+                    pdb.set_trace()
+                    print("  CIRCULAR DNA")
+
         s = StrandInSegment( segment, start, end, is_fwd )
         self.add( s )
         self.num_nts += s.num_nts
@@ -1307,27 +1398,29 @@ class SegmentModel(ArbdModel):
             s1,s2 = [l.container for l in (A,B)]
             if A.particle is not None and B.particle is not None:
                 continue
-            assert( A.particle is None )
-            assert( B.particle is None )
+            # assert( A.particle is None )
+            # assert( B.particle is None )
 
             ## TODO: offload the work here to s1/s2 (?)
             a1,a2 = [l.address   for l in (A,B)]
 
-            b = s1.get_nearest_bead(a1)
-            if b is not None and s1.contour_to_nt_pos(np.abs(b.contour_position-a1)) < 1.9:
-                ## combine beads
-                b.contour_position = 0.5*(b.contour_position + a1) # avg position
-                A.particle = b
-            else:
-                A.particle = s1._generate_one_bead(a1,0)
-
-            b = s2.get_nearest_bead(a2)
-            if b is not None and s2.contour_to_nt_pos(np.abs(b.contour_position-a2)) < 1.9:
-                ## combine beads
-                b.contour_position = 0.5*(b.contour_position + a2) # avg position
-                B.particle = b
-            else:
-                B.particle = s2._generate_one_bead(a2,0)
+            if A.particle is None:
+                b = s1.get_nearest_bead(a1)
+                if b is not None and s1.contour_to_nt_pos(np.abs(b.contour_position-a1)) < 1.9:
+                    ## combine beads
+                    b.contour_position = 0.5*(b.contour_position + a1) # avg position
+                    A.particle = b
+                else:
+                    A.particle = s1._generate_one_bead(a1,0)
+
+            if B.particle is None:
+                b = s2.get_nearest_bead(a2)
+                if b is not None and s2.contour_to_nt_pos(np.abs(b.contour_position-a2)) < 1.9:
+                    ## combine beads
+                    b.contour_position = 0.5*(b.contour_position + a2) # avg position
+                    B.particle = b
+                else:
+                    B.particle = s2._generate_one_bead(a2,0)
 
         """ Some tests """
         for c,A,B in self.get_connections("intrahelical"):
@@ -1758,27 +1851,27 @@ class SegmentModel(ArbdModel):
         #     if c[0]
 
         """ Build strands from connectivity of helices """
-        def _recursively_build_strand(strand, segment, pos, is_fwd, mycounter=0):
+        def _recursively_build_strand(strand, segment, pos, is_fwd, mycounter=0, move_at_least=0.5):
             mycounter+=1
             if mycounter > 1000:
                 import pdb
                 pdb.set_trace()
             s,seg = [strand, segment]
 
-            end_pos, next_seg, next_pos, next_dir = seg.get_strand_segment(pos, is_fwd)
+            end_pos, next_seg, next_pos, next_dir, move_at_least = seg.get_strand_segment(pos, is_fwd, move_at_least)
             s.add_dna(seg, pos, end_pos, is_fwd)
 
             if next_seg is not None:
                 # print("  next_dir: {}".format(next_dir))
-                _recursively_build_strand(s, next_seg, next_pos, next_dir, mycounter)
+                _recursively_build_strand(s, next_seg, next_pos, next_dir, mycounter, move_at_least)
 
         for seg in self.segments:
             locs = seg.get_5prime_locations()
             if locs is None: continue
             # for pos, is_fwd in locs:
             for l in locs:
-                # print("Tracing",l)
-                pos = seg.contour_to_nt_pos(l.address)
+                print("Tracing",l)
+                pos = seg.contour_to_nt_pos(l.address, round_nt=True)
                 is_fwd = l.on_fwd_strand
                 s = Strand()
                 _recursively_build_strand(s, seg, pos, is_fwd)
-- 
GitLab