From a4b4e60ab95b1c8d855611c0003d492753d40b7a Mon Sep 17 00:00:00 2001
From: Chris Maffeo <cmaffeo2@illinois.edu>
Date: Tue, 11 Feb 2025 14:15:29 -0600
Subject: [PATCH] Enable force grids

---
 src/Configuration.cpp | 28 ++++++++++++++++++++++++++++
 src/Configuration.h   |  1 +
 2 files changed, 29 insertions(+)

diff --git a/src/Configuration.cpp b/src/Configuration.cpp
index a572f45..6d1cf92 100644
--- a/src/Configuration.cpp
+++ b/src/Configuration.cpp
@@ -382,12 +382,14 @@ Configuration::Configuration(const char* config_file, int simNum, bool debug) :
 	    }
 		if (partForceXGridFile[i].length() != 0) {
 			part[i].forceXGrid = new BaseGrid(partForceXGridFile[i].val());
+			if (partForceGridScale[i] != nullptr) part[i].forceXGrid->scale( partForceGridScale[i][0] );
 			printf("Loaded `%s'.\n", partForceXGridFile[i].val());
 			printf("Grid size %s.\n", part[i].forceXGrid->getExtent().toString().val());
 		}
 
 		if (partForceYGridFile[i].length() != 0) {
 			part[i].forceYGrid = new BaseGrid(partForceYGridFile[i].val());
+			if (partForceGridScale[i] != nullptr) part[i].forceYGrid->scale( partForceGridScale[i][1] );
 			printf("Loaded `%s'.\n", partForceYGridFile[i].val());
 			printf("Grid size %s.\n", part[i].forceYGrid->getExtent().toString().val());
 		}
@@ -396,6 +398,10 @@ Configuration::Configuration(const char* config_file, int simNum, bool debug) :
 			part[i].forceZGrid = new BaseGrid(partForceZGridFile[i].val());
 			printf("Loaded `%s'.\n", partForceZGridFile[i].val());
 			printf("Grid size %s.\n", part[i].forceZGrid->getExtent().toString().val());
+			if (partForceGridScale[i] != nullptr) {
+			    printf("Scaling forceGridZ `%s' by %f.\n", partForceZGridFile[i].val(), partForceGridScale[i][2] );
+			    part[i].forceZGrid->scale( partForceGridScale[i][2] );
+			}
 		}
 
 		if (partDiffusionGridFile[i].length() != 0) {
@@ -755,6 +761,17 @@ void Configuration::copyToCUDA() {
 		    b->diffusionGrid = NULL;
 		}
 		
+		// Copy the diffusion grid
+		if (part[i].forceXGrid != nullptr) {
+		    b->forceXGrid = part[i].forceXGrid->copy_to_cuda();
+		}
+		if (part[i].forceYGrid != nullptr) {
+		    b->forceYGrid = part[i].forceYGrid->copy_to_cuda();
+		}
+		if (part[i].forceZGrid != nullptr) {
+		    b->forceZGrid = part[i].forceZGrid->copy_to_cuda();
+		}
+
 		//b->pmf = pmf;
 		gpuErrchk(cudaMalloc(&part_addr[i], sizeof(BrownianParticleType)));
 		gpuErrchk(cudaMemcpyAsync(part_addr[i], b, sizeof(BrownianParticleType),
@@ -936,6 +953,8 @@ int Configuration::readParameters(const char * config_file) {
 	partForceXGridFile = new String[numParts];
 	partForceYGridFile = new String[numParts];
 	partForceZGridFile = new String[numParts];
+        partForceGridScale  = new float*[numParts];
+
 	partDiffusionGridFile = new String[numParts];
 	partReservoirFile = new String[numParts];
 	partRigidBodyGrid.resize(numParts);
@@ -958,6 +977,7 @@ int Configuration::readParameters(const char * config_file) {
         {
             partGridFile[i] = NULL;
             partGridFileScale[i] = NULL;
+            partForceGridScale[i] = nullptr;
             //part[i].numPartGridFiles = -1;
         }
         //for(int i = 0; i < numParts; ++i)
@@ -1085,6 +1105,14 @@ int Configuration::readParameters(const char * config_file) {
 		} else if (param == String("forceZGridFile")) {
 		    if (currPart < 0) exit(1);
 		    partForceZGridFile[currPart] = value;
+		} else if (param == String("forceGridScale")) {
+		    if (currPart < 0) exit(1);
+		    int tmp;
+		    stringToArray<float>(&value, tmp, &partForceGridScale[currPart]);
+		    if (tmp != 3) {
+			printf("ERROR: Expected three floating point scale values for x,y,z, but got `%s'.\n", param.val());
+			exit(1);
+		    }
 		} else if (param == String("diffusionGridFile")) {
 		    if (currPart < 0) exit(1);
 		    partDiffusionGridFile[currPart] = value;
diff --git a/src/Configuration.h b/src/Configuration.h
index 7fb589d..4be2a25 100644
--- a/src/Configuration.h
+++ b/src/Configuration.h
@@ -237,6 +237,7 @@ public:
 	String* partForceXGridFile;
 	String* partForceYGridFile;
 	String* partForceZGridFile;
+	float **partForceGridScale;
 	String* partTableFile;
 	String* partReservoirFile;
 	int* partTableIndex0;
-- 
GitLab