diff --git a/src/ComputeGridGrid.cu b/src/ComputeGridGrid.cu
index a5002c5b4a8442b69011b23292e5cb1ba6f36a76..1ecfb9fad191cb1e70d787e384198b57b2a87883 100644
--- a/src/ComputeGridGrid.cu
+++ b/src/ComputeGridGrid.cu
@@ -5,16 +5,7 @@
 //RBTODO handle periodic boundaries
 //RBTODO: add __restrict__, benchmark (Q: how to restrict member data?)
 
-class BasePositionTransformer {
-    /*
-      Abstract class providing for transforming positions around a RB
-      center or not, allowing common_computeGridGridForce to be used
-      for both RB Grid-Grid and Grid-PMF
-    */
-public:
-    __device__ inline virtual Vector3 operator() (Vector3 pos) const { return Vector3(); }
-};
-class GridPositionTransformer : public BasePositionTransformer {
+class GridPositionTransformer {
 public:
     __device__ GridPositionTransformer(const Vector3 o, const Vector3 c, BaseGrid* s) :
 	o(o), c(c), s(s) { }
@@ -26,7 +17,8 @@ private:
     const Vector3 c;
     const BaseGrid* s;
 };
-class PmfPositionTransformer : public BasePositionTransformer {
+//class PmfPositionTransformer : public BasePositionTransformer {
+class PmfPositionTransformer {
 public:
     __device__ PmfPositionTransformer(const Vector3 o) : o(o) { }
     __device__ inline Vector3 operator() (Vector3 pos) const {
@@ -36,9 +28,9 @@ private:
     const Vector3 o;
 };
 
-
+template <typename T>
 __device__
-inline void common_computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const BasePositionTransformer transformer,
+inline void common_computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const T& transformer,
 					ForceEnergy* retForce, Vector3 * retTorque, int scheme)
 {
 
@@ -109,16 +101,16 @@ __global__
 void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_center_u, const Vector3 center_u_minus_origin_u,
 			ForceEnergy* retForce, Vector3 * retTorque, int scheme, BaseGrid* sys_d)
 {
-    BasePositionTransformer transformer = GridPositionTransformer(origin_rho_minus_center_u, center_u_minus_origin_u, sys_d);
-    common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
+    GridPositionTransformer transformer = GridPositionTransformer(origin_rho_minus_center_u, center_u_minus_origin_u, sys_d);
+    common_computeGridGridForce<GridPositionTransformer>(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
 }
 
 __global__
 void computePmfGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_origin_u,
 			 ForceEnergy* retForce, Vector3 * retTorque, int scheme)
 {
-    BasePositionTransformer transformer = PmfPositionTransformer(origin_rho_minus_origin_u);
-    common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
+    PmfPositionTransformer transformer = PmfPositionTransformer(origin_rho_minus_origin_u);
+    common_computeGridGridForce<PmfPositionTransformer>(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme);
 }
 
 __global__
diff --git a/src/RigidBodyController.cu b/src/RigidBodyController.cu
index 85b5760332414728c72797f0bf06064fbae8d02d..3a3babffa7e077585c6fd38568e97bc79195a21c 100644
--- a/src/RigidBodyController.cu
+++ b/src/RigidBodyController.cu
@@ -766,10 +766,10 @@ void RigidBodyForcePair::callGridForceKernel(int pairId, int s, int scheme, Base
 				 B1, B2, getOrigin1(i) - center_u, center_u - getOrigin2(i),
 				 forces_d[i], torques_d[i], scheme, sys_d);
 		} else {										/* RB with a PMF */
-			computePmfGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>>
+			computeGridGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>>
 				(&type1->RBC->grids_d[k1], &type2->RBC->grids_d[k2],
-				 B1, B2, getOrigin1(i) - center_u,
-				 forces_d[i], torques_d[i], scheme);
+				 B1, B2, getOrigin1(i) - center_u, center_u-getOrigin2(i),
+				 forces_d[i], torques_d[i], scheme, sys_d);
 		}
 		// retrieveForcesForGrid(i); // this is slower than approach below, unsure why