From 5d2010c56e3fdb9671ba89114f09b5a2155be429 Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Tue, 25 Feb 2020 11:53:08 -0600
Subject: [PATCH] Improved atomic pdb basepair assignment by making distance
 metric symmetric and by assigning basepairs in order of quality of fit

---
 mrdna/readers/segmentmodel_from_pdb.py | 42 +++++++++++++++++---------
 1 file changed, 27 insertions(+), 15 deletions(-)

diff --git a/mrdna/readers/segmentmodel_from_pdb.py b/mrdna/readers/segmentmodel_from_pdb.py
index f16d907..ca17030 100644
--- a/mrdna/readers/segmentmodel_from_pdb.py
+++ b/mrdna/readers/segmentmodel_from_pdb.py
@@ -154,28 +154,40 @@ def find_basepairs( u, centers, transforms, selection_text='all' ):
         all1 = idx1_to_all[i]
         R1 = transforms[all1]
         c1 = centers[all1]
+        c2_expected = c1 + ref_bp_position.dot(R1)
 
         for j in idx2[ idx1 == i ]:
             all2 = idx2_to_all[j]
             R2 = transforms[all2]
             c2 = centers[all2]
 
-            c2_expected = c1 + ref_bp_position.dot(R1)
-            dist = np.linalg.norm(c2_expected-c2)
-            angle= R1.T.dot(np.array((0,0,1))).dot(R2.T.dot(np.array((0,0,1))))
-            rank[i,j] = dist**2 + 135*(1+angle)**2
-
-    idx1,idx2 = np.where( rank < 25 )
-    for i in np.unique(idx1):
-        all1 = idx1_to_all[i]
-        j = np.argmin(rank[i,:])
-        if i == np.argmin(rank[:,j]):
-            all2 = idx2_to_all[j]
-            basepairs[all1] = all2
-            basepairs[all2] = all1
-        else:
-            raise Exception("Unable to detect basepair")
+            c1_expected = c2 + ref_bp_position.dot(R2)
 
+            dist1 = np.linalg.norm(c1_expected-c1)
+            dist2 = np.linalg.norm(c2_expected-c2)
+            angle= R1.T.dot(np.array((0,0,1))).dot(R2.T.dot(np.array((0,0,1))))
+            rank[i,j] = 0.5*dist1**2 + 0.5*dist2**2 + 135*(1+angle)**2
+
+    flatrank = rank.flatten()
+    IDX = np.argsort(rank, axis=None)
+    IDX = IDX[:np.sum(flatrank < 25)] # truncate
+    I = IDX//rank.shape[1]
+    J = IDX - I * rank.shape[1]
+    ALL1 = idx1_to_all[I]
+    ALL2 = idx2_to_all[J]
+
+    for all1,all2 in zip(ALL1,ALL2):
+        if basepairs[all1] >= 0 or basepairs[all2] >= 0:
+            continue
+        basepairs[all1] = all2
+        basepairs[all2] = all1
+
+    # checksum = np.sum((basepairs[ALL1] < 0) + (basepairs[ALL2] < 0))
+    # if checksum > 0:
+    #     print(checksum)
+    #     import pdb
+    #     pdb.set_trace()
+    #     ...
     return basepairs
 
 def find_stacks_mdanalysis( u, centers, transforms, selection_text='all' ):
-- 
GitLab