From d70173e899abbeeab92bd87133b2063d4ddbb087 Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Tue, 29 Oct 2019 17:41:52 -0500
Subject: [PATCH] computeGridGridForce wraps about RB center, not about grid
 origin; may result in performance loss

---
 src/ComputeGridGrid.cu     | 63 +++++++++++++++++++++++++++++++++-----
 src/ComputeGridGrid.cuh    |  8 ++++-
 src/RigidBodyController.cu | 21 ++++++++++---
 src/RigidBodyController.h  |  2 ++
 4 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/src/ComputeGridGrid.cu b/src/ComputeGridGrid.cu
index ff2b7a4..a5002c5 100644
--- a/src/ComputeGridGrid.cu
+++ b/src/ComputeGridGrid.cu
@@ -4,9 +4,42 @@
 #include "CudaUtil.cuh"
 //RBTODO handle periodic boundaries
 //RBTODO: add __restrict__, benchmark (Q: how to restrict member data?)
-__global__
-void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_origin_u,
-			ForceEnergy* retForce, Vector3 * retTorque, int scheme, BaseGrid* sys_d) 
+
+class BasePositionTransformer {
+    /*
+      Abstract class providing for transforming positions around a RB
+      center or not, allowing common_computeGridGridForce to be used
+      for both RB Grid-Grid and Grid-PMF
+    */
+public:
+    __device__ inline virtual Vector3 operator() (Vector3 pos) const { return Vector3(); }
+};
+class GridPositionTransformer : public BasePositionTransformer {
+public:
+    __device__ GridPositionTransformer(const Vector3 o, const Vector3 c, BaseGrid* s) :
+	o(o), c(c), s(s) { }
+    __device__ inline Vector3 operator() (Vector3 pos) const {
+	return s->wrapDiff(pos + o) + c;
+    }
+private:
+    const Vector3 o;
+    const Vector3 c;
+    const BaseGrid* s;
+};
+class PmfPositionTransformer : public BasePositionTransformer {
+public:
+    __device__ PmfPositionTransformer(const Vector3 o) : o(o) { }
+    __device__ inline Vector3 operator() (Vector3 pos) const {
+	return pos + o;
+    }
+private:
+    const Vector3 o;
+};
+
+
+__device__
+inline void common_computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const BasePositionTransformer transformer,
+					ForceEnergy* retForce, Vector3 * retTorque, int scheme)
 {
 
 	extern __shared__ ForceEnergy s[];
@@ -24,10 +57,10 @@ void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, cons
 		// RBTODO: reduce registers used;
 		//   commenting out interpolateForceD still uses ~40 registers
 		//   -- the innocuous-looking fn below is responsible; consumes ~17 registers!
-		Vector3 r_pos= rho->getPosition(r_id); /* i,j,k value of voxel */
+	    Vector3 r_pos= rho->getPosition(r_id); /* i,j,k value of voxel */
 
-		r_pos = basis_rho.transform( r_pos ) + origin_rho_minus_origin_u; /* real space */
-                r_pos = sys_d->wrapDiff(r_pos); /* TODO: wrap about center of RB, not origin of u */
+	    r_pos = basis_rho.transform( r_pos );
+	    r_pos = transformer(r_pos);
 		const Vector3 u_ijk_float = basis_u_inv.transform( r_pos );
 		// RBTODO: Test for non-unit delta
 		/* Vector3 tmpf  = Vector3(0.0f); */
@@ -72,6 +105,22 @@ void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, cons
 	}
 }
 
+__global__
+void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_center_u, const Vector3 center_u_minus_origin_u,
+			ForceEnergy* retForce, Vector3 * retTorque, int scheme, BaseGrid* sys_d)
+{
+    BasePositionTransformer transformer = GridPositionTransformer(origin_rho_minus_center_u, center_u_minus_origin_u, sys_d);
+    common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
+}
+
+__global__
+void computePmfGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_origin_u,
+			 ForceEnergy* retForce, Vector3 * retTorque, int scheme)
+{
+    BasePositionTransformer transformer = PmfPositionTransformer(origin_rho_minus_origin_u);
+    common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
+}
+
 __global__
 void computePartGridForce(const Vector3* __restrict__ pos, Vector3* particleForce,
 				const int num, const int* __restrict__ particleIds, 
@@ -90,7 +139,7 @@ void computePartGridForce(const Vector3* __restrict__ pos, Vector3* particleForc
 	torque[tid] = ForceEnergy(0.f,0.f);
 	if (i < num) {
 		const int id = particleIds[i];
-		Vector3 p = sys_d->wrapDiff(pos[id]-center_u) + center_u - origin_u
+		Vector3 p = sys_d->wrapDiff(pos[id]-center_u) + center_u - origin_u;
 		const Vector3 u_ijk_float = basis_u_inv.transform( p );
 
                 ForceEnergy fe;
diff --git a/src/ComputeGridGrid.cuh b/src/ComputeGridGrid.cuh
index f2b7baa..739aa62 100644
--- a/src/ComputeGridGrid.cuh
+++ b/src/ComputeGridGrid.cuh
@@ -10,9 +10,15 @@ class BaseGrid;
 extern __global__
 void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u,
 				const Matrix3 basis_rho, const Matrix3 basis_u_inv,
-				const Vector3 origin_rho_minus_origin_u,
+				const Vector3 origin_rho_minus_center_u, const Vector3 center_u_minus_origin_u,
 				ForceEnergy* retForce, Vector3 * retTorque, int scheme, BaseGrid* sys_d);
 
+extern __global__
+void computePmfGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u,
+			 const Matrix3 basis_rho, const Matrix3 basis_u_inv,
+			 const Vector3 origin_rho_minus_origin_u,
+			 ForceEnergy* retForce, Vector3 * retTorque, int scheme);
+
 extern __global__
 void computePartGridForce(const Vector3* __restrict__ pos, Vector3* particleForce,
 				const int num, const int* __restrict__ particleIds,
diff --git a/src/RigidBodyController.cu b/src/RigidBodyController.cu
index 6c76292..85b5760 100644
--- a/src/RigidBodyController.cu
+++ b/src/RigidBodyController.cu
@@ -698,6 +698,16 @@ Vector3 RigidBodyForcePair::getOrigin2(const int i) {
 	else
 	    return o;
 }		
+Vector3 RigidBodyForcePair::getCenter2(const int i) {
+    Vector3 c;
+    if (!isPmf)
+	c = rb2->getPosition();
+    else {
+	const int k2 = gridKeyId2[i];
+	Vector3 o = type2->RBC->grids[k2].getCenter();
+    }
+    return c;
+}
 Matrix3 RigidBodyForcePair::getBasis1(const int i) {
 	const int k1 = gridKeyId1[i];
 	return rb1->getOrientation()*type1->RBC->grids[k1].getBasis();
@@ -745,20 +755,21 @@ void RigidBodyForcePair::callGridForceKernel(int pairId, int s, int scheme, Base
 	  	`––––––––––––––––––./
 		*/
 		Matrix3 B1 = getBasis1(i);
-		Vector3 c = getOrigin1(i) - getOrigin2(i);
+		// Vector3 c = getOrigin1(i) - getOrigin2(i);
+		Vector3 center_u = getCenter2(i);
 		Matrix3 B2 = getBasis2(i).inverse();
                 
 		// RBTODO: get energy
 		if (!isPmf) {								/* pair of RBs */
 			computeGridGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>>
 				(&type1->RBC->grids_d[k1], &type2->RBC->grids_d[k2],
-				 B1, B2, c,
+				 B1, B2, getOrigin1(i) - center_u, center_u - getOrigin2(i),
 				 forces_d[i], torques_d[i], scheme, sys_d);
 		} else {										/* RB with a PMF */
-			computeGridGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>>
+			computePmfGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>>
 				(&type1->RBC->grids_d[k1], &type2->RBC->grids_d[k2],
-				 B1, B2, c,
-				 forces_d[i], torques_d[i], scheme, sys_d);
+				 B1, B2, getOrigin1(i) - center_u,
+				 forces_d[i], torques_d[i], scheme);
 		}
 		// retrieveForcesForGrid(i); // this is slower than approach below, unsure why
 		
diff --git a/src/RigidBodyController.h b/src/RigidBodyController.h
index 856229b..ff99091 100644
--- a/src/RigidBodyController.h
+++ b/src/RigidBodyController.h
@@ -91,6 +91,8 @@ private:
 	Matrix3 getBasis2(const int i);
 	Vector3 getOrigin1(const int i);
 	Vector3 getOrigin2(const int i);
+	Vector3 getCenter2(const int i);
+
 
 	static GPUManager gpuman;
 };
-- 
GitLab