diff --git a/src/ComputeForce.cu b/src/ComputeForce.cu
index ff6255045df2846ffd2641744cebdb2d53e2b5ad..d7f055c3d4c495acc01698df6b60304cdac95705 100644
--- a/src/ComputeForce.cu
+++ b/src/ComputeForce.cu
@@ -123,16 +123,17 @@ ComputeForce::ComputeForce(const Configuration& c, const int numReplicas = 1) :
//Han-Yi Chou
int nCells = decomp.nCells.x * decomp.nCells.y * decomp.nCells.z;
//int* nCells_dev;
- int3 *Cells_dev;
-
- gpuErrchk(cudaMalloc(&CellNeighborsList,sizeof(int)*27*nCells));
- //gpuErrchk(cudaMalloc(&nCells_dev,sizeof(int)));
- gpuErrchk(cudaMalloc(&Cells_dev,sizeof(int3)));
- //gpuErrchk(cudaMemcpy(nCells_dev,&nCells,1,cudaMemcpyHostToDevice);
- gpuErrchk(cudaMemcpy(Cells_dev,&(decomp.nCells),sizeof(int3),cudaMemcpyHostToDevice));
- createNeighborsList<<<256,256>>>(Cells_dev,CellNeighborsList);
- gpuErrchk(cudaFree(Cells_dev));
- cudaBindTexture(0, NeighborsTex, CellNeighborsList, 27*nCells*sizeof(int));
+ if (nCells < MAX_CELLS_FOR_CELLNEIGHBORLIST) {
+ int3 *Cells_dev;
+ gpuErrchk(cudaMalloc(&CellNeighborsList,sizeof(int)*27*nCells));
+ //gpuErrchk(cudaMalloc(&nCells_dev,sizeof(int)));
+ gpuErrchk(cudaMalloc(&Cells_dev,sizeof(int3)));
+ //gpuErrchk(cudaMemcpy(nCells_dev,&nCells,1,cudaMemcpyHostToDevice);
+ gpuErrchk(cudaMemcpy(Cells_dev,&(decomp.nCells),sizeof(int3),cudaMemcpyHostToDevice));
+ createNeighborsList<<<256,256>>>(Cells_dev,CellNeighborsList);
+ gpuErrchk(cudaFree(Cells_dev));
+ cudaBindTexture(0, NeighborsTex, CellNeighborsList, 27*nCells*sizeof(int));
+ }
}
//Calculate the number of blocks the grid should contain
diff --git a/src/ComputeForce.cuh b/src/ComputeForce.cuh
index babbe0ad40f7d5d9c387363f3914a6bfc393d817..2f818b89cd244930d35deeaf28c506e0981671d9 100644
--- a/src/ComputeForce.cuh
+++ b/src/ComputeForce.cuh
@@ -8,6 +8,9 @@
#include "TabulatedMethods.cuh"
#define BD_PI 3.1415927f
+
+#define MAX_CELLS_FOR_CELLNEIGHBORLIST (1<<25) // cell-count cutoff: below it the 27-entry-per-cell neighbor table is built; parenthesized so the macro expands safely inside larger expressions
+
texture<int, 1, cudaReadModeElementType> NeighborsTex;
texture<int, 1, cudaReadModeElementType> pairTabPotTypeTex;
texture<int2, 1, cudaReadModeElementType> pairListsTex;
@@ -234,6 +237,35 @@ int getExSum() {
return tmp;
}
//
+// Return the linear ID of the cell displaced by (dx,dy,dz) from cell_idx in a
+// grid of cells.x * cells.y * cells.z cells, or -1 if that offset is rejected.
+// IDs are z + cells.z * (y + cells.y * x); out-of-range coordinates wrap around.
+__device__
+int computeCellNeighbor( const int3 cells, const int3 cell_idx, const int dx, const int dy, const int dz )
+{
+    // Displaced coordinates, before any wrapping.
+    const int sx = cell_idx.x + dx;
+    const int sy = cell_idx.y + dy;
+    const int sz = cell_idx.z + dz;
+
+    // On a 1-cell axis the displaced coordinate must be exactly 0.
+    const bool offSingle = (cells.x == 1 && sx != 0) ||
+                           (cells.y == 1 && sy != 0) ||
+                           (cells.z == 1 && sz != 0);
+    // On a 2-cell axis it must be 0 or 1, i.e. no wrapping is allowed there
+    // (presumably so the same cell is not returned for two different offsets).
+    const bool offDouble = (cells.x == 2 && (sx < 0 || sx > 1)) ||
+                           (cells.y == 2 && (sy < 0 || sy > 1)) ||
+                           (cells.z == 2 && (sz < 0 || sz > 1));
+    if (offSingle || offDouble) return -1;
+
+    // Adding the axis length first keeps the % operand non-negative for a -1 step.
+    const int wx = (sx + cells.x) % cells.x;
+    const int wy = (sy + cells.y) % cells.y;
+    const int wz = (sz + cells.z) % cells.z;
+    return wz + cells.z * (wy + cells.y * wx);
+}
+
__global__
void createNeighborsList(const int3 *Cells,int* __restrict__ CellNeighborsList)
{
@@ -254,25 +286,7 @@ void createNeighborsList(const int3 *Cells,int* __restrict__ CellNeighborsList)
for (int dy = -1; dy <= 1; ++dy) {
for (int dz = -1; dz <= 1; ++dz) {
- int u = idx + dx;
- int v = idy + dy;
- int w = idz + dz;
-
- if (cells.x == 1 and u != 0) nID = -1;
- else if (cells.y == 1 and v != 0) nID = -1;
- else if (cells.z == 1 and w != 0) nID = -1;
- else if (cells.x == 2 and (u < 0 || u > 1)) nID = -1;
- else if (cells.y == 2 and (v < 0 || v > 1)) nID = -1;
- else if (cells.z == 2 and (w < 0 || w > 1)) nID = -1;
- else
- {
- u = (u + cells.x) % cells.x;
- v = (v + cells.y) % cells.y;
- w = (w + cells.z) % cells.z;
-
- nID = w + cells.z * (v + cells.y * u);
- }
-
+ nID = computeCellNeighbor( cells, make_int3(idx,idy,idz), dx, dy, dz );
CellNeighborsList[size_t(count+27*cID)] = nID;
++count;
//__syncthreads();
@@ -342,7 +356,20 @@ __global__ void createPairlists(Vector3* __restrict__ pos, const int num, const
int currEx = ex_pair.x;
int nextEx = (ex_pair.x >= 0) ? excludes[currEx].ind2 : -1;
- int neighbor_cell = tex1Dfetch(NeighborsTex,idx+27*cellid_i);
+ int neighbor_cell;
+ if (nCells < MAX_CELLS_FOR_CELLNEIGHBORLIST) {
+     neighbor_cell = tex1Dfetch(NeighborsTex,idx+27*cellid_i);
+ } else {
+     // No table lookup in this branch: compute the neighbor ID directly.
+     // computeCellNeighbor linearizes IDs as z + cells.z * (y + cells.y * x),
+     // so x is the slowest-varying coordinate and z the fastest.
+     const int3 cells = decomp->nCells;
+     const int3 cell_idx = make_int3(cellid_i / (cells.z * cells.y),
+                                     cellid_i / cells.z % cells.y,
+                                     cellid_i % cells.z);
+     // Decode idx into per-axis offsets in {-1,0,1}, with dz varying fastest.
+     neighbor_cell = computeCellNeighbor( cells, cell_idx,
+                                          ((idx/9) % 3) - 1, ((idx/3) % 3) - 1, (idx % 3) - 1 );
+ }
if(neighbor_cell < 0)
{