From 3efde87ab39284a9a9fb9c852f68924ba86a8ba9 Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Thu, 16 Feb 2023 13:00:12 -0600
Subject: [PATCH] Launch group site kernels with correct number of blocks

---
 src/GrandBrownTown.cu  | 12 +++++++-----
 src/GrandBrownTown.cuh | 13 ++++++++-----
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/GrandBrownTown.cu b/src/GrandBrownTown.cu
index ef6ef5d..5df60f0 100644
--- a/src/GrandBrownTown.cu
+++ b/src/GrandBrownTown.cu
@@ -650,7 +650,9 @@ void GrandBrownTown::run()
 		}
 	    }
 
-	    if (numGroupSites > 0) updateGroupSites<<<(numGroupSites/32+1),32>>>(_pos[0], groupSiteData_d, num + num_rb_attached_particles, numGroupSites, numReplicas);
+	    gpuman.sync();
+	    if (numGroupSites > 0) updateGroupSites<<<(numGroupSites*numReplicas/32+1),32>>>(_pos[0], groupSiteData_d, num + num_rb_attached_particles, numGroupSites, numReplicas);
+	    gpuman.sync();
 
 	    #ifdef USE_NCCL
 	    if (gpuman.gpus.size() > 1) {
@@ -771,7 +773,7 @@ void GrandBrownTown::run()
 	    }
 	    #endif
 
-	    if (numGroupSites > 0) distributeGroupSiteForces<false><<<(numGroupSites/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
+	    if (numGroupSites > 0) distributeGroupSiteForces<false><<<(numGroupSites*numReplicas/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
 
         }//if step == 1
 
@@ -928,7 +930,7 @@ void GrandBrownTown::run()
 	if (numGroupSites > 0) {
  	  PUSH_NVTX("Update collective coordinates",2)
 	    gpuman.sync();
-	    updateGroupSites<<<(numGroupSites/32+1),32>>>(internal->getPos_d()[0], groupSiteData_d, num + num_rb_attached_particles, numGroupSites, numReplicas);
+	    updateGroupSites<<<(numGroupSites*numReplicas/32+1),32>>>(internal->getPos_d()[0], groupSiteData_d, num + num_rb_attached_particles, numGroupSites, numReplicas);
 	    gpuman.sync();
 	  POP_NVTX
 	}
@@ -1051,9 +1053,9 @@ void GrandBrownTown::run()
 	  PUSH_NVTX("Spread collective coordinate forces to constituent particles",4)
 	    gpuman.sync();
 	    // if ((s%100) == 0) {
-	    distributeGroupSiteForces<true><<<(numGroupSites/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
+	    distributeGroupSiteForces<false><<<(numGroupSites*numReplicas/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
 	// } else {
-	//     distributeGroupSiteForces<false><<<(numGroupSites/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
+	//     distributeGroupSiteForces<false><<<(numGroupSites*numReplicas/32+1),32>>>(internal->getForceInternal_d()[0], groupSiteData_d, num+num_rb_attached_particles, numGroupSites, numReplicas);
 	// }
 	    gpuman.sync();
 	  POP_NVTX
diff --git a/src/GrandBrownTown.cuh b/src/GrandBrownTown.cuh
index 44ffe6f..2783943 100644
--- a/src/GrandBrownTown.cuh
+++ b/src/GrandBrownTown.cuh
@@ -407,6 +407,7 @@ inline Vector3 step(Vector3& r0, float kTlocal, Vector3 force, float diffusion,
 	return sys->wrap(r);
 }
 
+template<bool print=false>
 __global__
 void updateGroupSites(Vector3 pos[], int* groupSiteData, int num, int numGroupSites, int numReplicas) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
@@ -432,12 +433,14 @@ void updateGroupSites(Vector3 pos[], int* groupSiteData, int num, int numGroupSi
 	    const int aj = groupSiteData[j] + num*rep;
 	    tmp += weight * pos[aj];
 	}
-	// printf("GroupSite %d (mod %d) COM (start,finish, x,y,z): (%d,%d, %f,%f,%f)\n",i, imod, start, finish, tmp.x, tmp.y, tmp.z);
+	if (print) {
+	    printf("GroupSite %d (rep %d/%d) COM (start,finish, x,y,z): (%d,%d, %f,%f,%f)\n",i, rep, numReplicas, start, finish, tmp.x, tmp.y, tmp.z);
+	}
 	pos[num*numReplicas + i] = tmp;
     }
 }
 
-template<bool print>
+template<bool print=false>
 __global__
 void distributeGroupSiteForces(Vector3 force[], int* groupSiteData, int num, int numGroupSites, int numReplicas) {
     // TODO, handle groupsite energies
@@ -452,9 +455,9 @@ void distributeGroupSiteForces(Vector3 force[], int* groupSiteData, int num, int
 	float weight = 1.0 / (finish-start);
 
 	const Vector3 tmp = weight*force[num*numReplicas+i];
-	// if (print) {
-	//     printf("GroupSite %d Force rep %d: %f %f %f\n",i, rep, tmp.x, tmp.y, tmp.z);
-	// }
+	if (print) {
+	    printf("GroupSite %d Force rep %d/%d: %f %f %f\n",i, rep, numReplicas, tmp.x, tmp.y, tmp.z);
+	}
 
 	for (int j = start; j < finish; j++) {
 	    const int aj = groupSiteData[j] + num*rep;
-- 
GitLab