From 861434ba5bb2aa44fda455baf24250d98fc2242d Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Thu, 21 Mar 2019 17:03:35 -0500
Subject: [PATCH] Slightly improved pdb reader basepair and stack search

---
 mrdna/readers/segmentmodel_from_pdb.py | 138 +++++++++++++------------
 1 file changed, 74 insertions(+), 64 deletions(-)

diff --git a/mrdna/readers/segmentmodel_from_pdb.py b/mrdna/readers/segmentmodel_from_pdb.py
index 53b0c31..70032d2 100644
--- a/mrdna/readers/segmentmodel_from_pdb.py
+++ b/mrdna/readers/segmentmodel_from_pdb.py
@@ -70,37 +70,31 @@ def read_ter_from_pdb( *args ):
                 last = line
     return ter
 
-def residues_within( cutoff, coords1, coords2, sel1, sel2, universe, selection_text ):
-    assert(len(sel1.atoms) > 0 and len(sel2.atoms) > 0)
-    distance_matrix = distance_array( coords1, coords2 )
-    if sel1 == sel2:
-        ## ignore diagonal
-        distance_matrix = distance_matrix + np.eye(len(distance_matrix))*10*cutoff
-    # arr_idxs = np.where((distance_matrix > 0) * (distance_matrix < cutoff))
-    arr_idxs = np.where((distance_matrix < cutoff))
-    arr_idx1,arr_idx2 = arr_idxs
+def residues_within( cutoff, coords1, coords2, sel1, sel2 ):
+    """
+    Returns lists of index of residues within sel1 and sel2 such that
+    sel1.residue[arr_idx1[i]] is within cutoff of sel2.residue[arr_idx2[i]] for all i
 
-    atoms = universe.select_atoms(selection_text)
-    within = -np.ones( atoms.resindices[-1]+1, dtype=np.int )
+    Also returns distance_matrix
 
-    if len(arr_idx1) == len(np.unique(arr_idx1)) and len(arr_idx2) == len(np.unique(arr_idx2)):
-        within[sel1.resindices[arr_idx1]] = sel2.resindices[arr_idx2]
-    else:
-        unique,ids,counts = np.unique(arr_idx1, return_index=True, return_counts=True)
-        res1 = sel1.resindices[unique[counts == 1]]
-        res2 = sel2.resindices[arr_idx2[ids[counts == 1]]]
-        within[res1] = res2
-        for i in np.where(counts > 1)[0]:
-            ids2 = arr_idx1 == unique[i]  
-            ij = np.argmin(distance_matrix[arr_idx1[ids2],arr_idx2[ids2]])
-            res1 = sel1.resindices[arr_idx1[ids2][ij]]
-            res2 = sel2.resindices[arr_idx2[ids2][ij]]
-            within[res1] = res2
+    Ignores residues that are the same in sel1 and sel2
+    """
 
+    assert(len(sel1.atoms) > 0 and len(sel2.atoms) > 0)
+    assert(len(sel1.residues) == len(coords1))
+    assert(len(sel2.residues) == len(coords2))
 
+    distance_matrix = distance_array( coords1, coords2 )
 
-        # within[sel1.resindices[arr_idx1]] = sel2.resindices[arr_idx2]
-    return within
+    ## Ignore comparisons with self
+    res_diff = sel1.residues.resindices[:,np.newaxis] - sel2.residues.resindices[np.newaxis,:]
+    assert( res_diff.shape == distance_matrix.shape )
+    distance_matrix[ res_diff == 0 ] = distance_matrix[ res_diff == 0 ] + 10 *cutoff
+
+    arr_idxs = np.where((distance_matrix < cutoff))
+    arr_idx1,arr_idx2 = arr_idxs
+
+    return arr_idx1, arr_idx2, distance_matrix
 
 def find_base_position_orientation( u, selection_text='all' ):
     ## Find orientation and center of each nucleotide
@@ -133,45 +127,54 @@ def find_basepairs( u, centers, transforms, selection_text='all' ):
     bonds = { b:"resname {} and name {}".format(resnames[b],bp_bond_atoms[b])
               for b in bases }
 
+    all_ = u.select_atoms(selection_text)
     sel1 = u.select_atoms("({}) and (({}) or ({}))".format(selection_text,
                                                            bonds['A'],bonds['C'])) 
     sel2 = u.select_atoms("({}) and (({}) or ({}))".format(selection_text,
                                                            bonds['T'],bonds['G']))
+    allres = all_.residues.resindices
+    basepairs = -np.ones(len(allres), dtype=np.int)
+
+    ## Find likely candidates for basepairs
+    idx1, idx2, dists = residues_within( cutoff = 3.8,
+                                         coords1=sel1.positions,
+                                         coords2=sel2.positions,
+                                         sel1 = sel1, sel2 = sel2 )
     
-    possible_basepairs = residues_within( cutoff = 3.8,
-                                          coords1=sel1.positions,
-                                          coords2=sel2.positions,
-                                          sel1 = sel1, sel2 = sel2,
-                                          universe = u,
-                                          selection_text = selection_text )
-    ## Filter by expected position
-    ids = possible_basepairs >= 0
-    possible_resI = np.where( ids )[0]
-    possible_resJ = possible_basepairs[ ids ].astype(int) 
-    resI,resJ = [[],[]]
-    for i,j,R1,R2,c1,c2 in zip(possible_resI,possible_resJ,
-                            transforms[possible_resI],
-                            transforms[possible_resJ],
-                            centers[possible_resI],
-                            centers[possible_resJ]):
-        c2_expected = c1 + ref_bp_position.dot(R1)
-        # fh.write("graphics top cylinder {{{}}} {{{}}} radius 0.2 resolution 30\n".format(printv(c1),printv(c2)))
-
-        if np.linalg.norm(c2_expected-c2) < 3.5:
+    ## Find mapping from idx1/idx2 to index of same residue in all_
+    def get_idx_to_all(sel):
+        selres = sel.residues.resindices
+        return np.searchsorted( allres, selres )
+    idx1_to_all, idx2_to_all = [ get_idx_to_all(sel) for sel in (sel1,sel2) ]
+
+    ## Rank possible basepairs by expected position and angles between base orientation
+    rank = 100*np.ones(dists.shape)
+
+    for i in np.unique(idx1):
+        all1 = idx1_to_all[i]
+        R1 = transforms[all1]
+        c1 = centers[all1]
+
+        for j in idx2[ idx1 == i ]:
+            all2 = idx2_to_all[j]
+            R2 = transforms[all2]
+            c2 = centers[all2]
+
+            c2_expected = c1 + ref_bp_position.dot(R1)
+            dist = np.linalg.norm(c2_expected-c2)
             angle= R1.T.dot(np.array((0,0,1))).dot(R2.T.dot(np.array((0,0,1))))
-            if angle < -0.7:
-                resI.append(i)
-                resJ.append(j)
-
-
-    resI = np.array(resI, dtype=np.int)
-    resJ = np.array(resJ, dtype=np.int)
-
-    ## Add reciprocal basepairs
-    assert( (possible_basepairs[resJ]  == -1).all() )
-    basepairs = -np.ones(possible_basepairs.shape, dtype=np.int)
-    basepairs[resI] = resJ
-    basepairs[resJ] = resI
+            rank[i,j] = dist**2 + 135*(1+angle)**2
+
+    idx1,idx2 = np.where( rank < 25 )
+    for i in np.unique(idx1):
+        all1 = idx1_to_all[i]
+        j = np.argmin(rank[i,:])
+        if i == np.argmin(rank[:,j]):
+            all2 = idx2_to_all[j]
+            basepairs[all1] = all2
+            basepairs[all2] = all1
+        else:
+            raise Exception("Unable to detect basepair")
 
     return basepairs
 
@@ -185,11 +188,18 @@ def find_stacks( u, centers, transforms, selection_text='all' ):
     expected_stack_positions = np.array(expected_stack_positions, dtype=np.float32)
        
     sel = u.select_atoms("({}) and name C1'".format( selection_text ))
-    stacks_above = residues_within( cutoff = 3.5,
-                                    coords1=centers,
-                                    coords2=expected_stack_positions,
-                                    sel1 = sel, sel2 = sel,
-                              universe = u, selection_text = selection_text )
+    idx1, idx2, dists = residues_within( cutoff = 3.5,
+                                         coords1 = centers,
+                                         coords2 = expected_stack_positions,
+                                         sel1 = sel, sel2 = sel )
+
+    ## Convert distances to stacks
+    stacks_above = -np.ones(len(sel), dtype=np.int)
+    for i in np.unique(idx1):
+        js = idx2[ idx1 == i ]
+        j = np.argmin(dists[i])
+        stacks_above[i] = j
+
     return stacks_above
 
 def basepairs_and_stacks_to_helixmap(basepairs,stacks_above):
-- 
GitLab