diff --git a/src/ComputeGridGrid.cu b/src/ComputeGridGrid.cu index a5002c5b4a8442b69011b23292e5cb1ba6f36a76..1ecfb9fad191cb1e70d787e384198b57b2a87883 100644 --- a/src/ComputeGridGrid.cu +++ b/src/ComputeGridGrid.cu @@ -5,16 +5,7 @@ //RBTODO handle periodic boundaries //RBTODO: add __restrict__, benchmark (Q: how to restrict member data?) -class BasePositionTransformer { - /* - Abstract class providing for transforming positions around a RB - center or not, allowing common_computeGridGridForce to be used - for both RB Grid-Grid and Grid-PMF - */ -public: - __device__ inline virtual Vector3 operator() (Vector3 pos) const { return Vector3(); } -}; -class GridPositionTransformer : public BasePositionTransformer { +class GridPositionTransformer { public: __device__ GridPositionTransformer(const Vector3 o, const Vector3 c, BaseGrid* s) : o(o), c(c), s(s) { } @@ -26,7 +17,8 @@ private: const Vector3 c; const BaseGrid* s; }; -class PmfPositionTransformer : public BasePositionTransformer { +//class PmfPositionTransformer : public BasePositionTransformer { +class PmfPositionTransformer { public: __device__ PmfPositionTransformer(const Vector3 o) : o(o) { } __device__ inline Vector3 operator() (Vector3 pos) const { @@ -36,9 +28,9 @@ private: const Vector3 o; }; - +template <typename T> __device__ -inline void common_computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const BasePositionTransformer transformer, +inline void common_computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const T& transformer, ForceEnergy* retForce, Vector3 * retTorque, int scheme) { @@ -109,16 +101,16 @@ __global__ void computeGridGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_center_u, const Vector3 center_u_minus_origin_u, ForceEnergy* retForce, Vector3 * retTorque, int scheme, BaseGrid* sys_d) { - BasePositionTransformer transformer = GridPositionTransformer(origin_rho_minus_center_u, center_u_minus_origin_u, sys_d); - common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme); + GridPositionTransformer transformer = GridPositionTransformer(origin_rho_minus_center_u, center_u_minus_origin_u, sys_d); + common_computeGridGridForce<GridPositionTransformer>(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme); } __global__ void computePmfGridForce(const RigidBodyGrid* rho, const RigidBodyGrid* u, const Matrix3 basis_rho, const Matrix3 basis_u_inv, const Vector3 origin_rho_minus_origin_u, ForceEnergy* retForce, Vector3 * retTorque, int scheme) { - BasePositionTransformer transformer = PmfPositionTransformer(origin_rho_minus_origin_u); - common_computeGridGridForce(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme); + PmfPositionTransformer transformer = PmfPositionTransformer(origin_rho_minus_origin_u); + common_computeGridGridForce<PmfPositionTransformer>(rho, u, basis_rho, basis_u_inv, transformer, retForce, retTorque, scheme); } __global__ diff --git a/src/RigidBodyController.cu b/src/RigidBodyController.cu index 85b5760332414728c72797f0bf06064fbae8d02d..3a3babffa7e077585c6fd38568e97bc79195a21c 100644 --- a/src/RigidBodyController.cu +++ b/src/RigidBodyController.cu @@ -766,10 +766,10 @@ void RigidBodyForcePair::callGridForceKernel(int pairId, int s, int scheme, Base B1, B2, getOrigin1(i) - center_u, center_u - getOrigin2(i), forces_d[i], torques_d[i], scheme, sys_d); } else { /* RB with a PMF */ - computePmfGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>> + computeGridGridForce<<< nb, NUMTHREADS, 2*sizeof(ForceEnergy)*NUMTHREADS, s>>> (&type1->RBC->grids_d[k1], &type2->RBC->grids_d[k2], - B1, B2, getOrigin1(i) - center_u, - forces_d[i], torques_d[i], scheme); + B1, B2, getOrigin1(i) - center_u, center_u-getOrigin2(i), + forces_d[i], torques_d[i], scheme, sys_d); } // retrieveForcesForGrid(i); // this is slower than approach below, unsure why