From 21aa6634736240f107ffa075b51ccbdc59842f9b Mon Sep 17 00:00:00 2001 From: Chris Maffeo <cmaffeo2@illinois.edu> Date: Thu, 3 Mar 2016 15:23:41 -0600 Subject: [PATCH] cherry-pick: made matrix3 diag-aware; ~5% speedup of tabKernel Conflicts: ComputeForce.cuh TabulatedPotential.h --- BaseGrid.h | 7 +++++++ ComputeForce.cuh | 1 - TabulatedPotential.h | 20 ++++++++++++++++++-- useful.cu | 45 ++++++++++++++++++++++++++------------------ useful.h | 25 +++++++++++++++++++----- 5 files changed, 72 insertions(+), 26 deletions(-) diff --git a/BaseGrid.h b/BaseGrid.h index ce171db..129b6a2 100644 --- a/BaseGrid.h +++ b/BaseGrid.h @@ -617,6 +617,13 @@ public: } // Wrap vector distance, -0.5*l <= x < 0.5*l && ... + /* HOST DEVICE inline Vector3 wrapDiff(Vector3 r) const { */ + /* Vector3 l = basisInv.transform(r); */ + /* l.x = wrapDiff(l.x, nx); */ + /* l.y = wrapDiff(l.y, ny); */ + /* l.z = wrapDiff(l.z, nz); */ + /* return basis.transform(l); */ + /* } */ HOST DEVICE inline Vector3 wrapDiff(Vector3 r) const { Vector3 l = basisInv.transform(r); l.x = wrapDiff(l.x, nx); diff --git a/ComputeForce.cuh b/ComputeForce.cuh index 98b7fa3..bda2c57 100644 --- a/ComputeForce.cuh +++ b/ComputeForce.cuh @@ -416,7 +416,6 @@ __global__ void computeTabulatedKernel(Vector3* force, Vector3* pos, int* type, if (get_energy && aj > ai) atomicAdd( &(g_energies[ai]), fe[tid].e ); } - } diff --git a/TabulatedPotential.h b/TabulatedPotential.h index 4fb3b59..d3f6f5f 100644 --- a/TabulatedPotential.h +++ b/TabulatedPotential.h @@ -46,7 +46,7 @@ public: Vector3 computeForce(Vector3 r); - HOST DEVICE inline EnergyForce compute(Vector3 r) { + HOST DEVICE inline EnergyForce computeOLD(Vector3 r) { float d = r.length(); Vector3 rUnit = -r/d; int home = int(floorf((d - r0)/dr)); @@ -54,7 +54,7 @@ public: if (home >= n) return EnergyForce(e0, Vector3(0.0f)); float homeR = home*dr + r0; float w = (d - homeR)/dr; - + // Interpolate. float energy = v3[home]*w*w*w + v2[home]*w*w + v1[home]*w + v0[home]; Vector3 force = -(3.0f*v3[home] * w * w @@ -63,6 +63,22 @@ public: return EnergyForce(energy,force); } + HOST DEVICE inline EnergyForce compute(Vector3 r) { + float d = r.length(); + float w = (d - r0)/dr; + int home = int( floorf(w) ); + w = w - home; + if (home < 0) return EnergyForce(v0[0], Vector3(0.0f)); + if (home >= n) return EnergyForce(e0, Vector3(0.0f)); + + // Interpolate. + float energy = v3[home]*w*w*w + v2[home]*w*w + v1[home]*w + v0[home]; + Vector3 force = (-(3.0f*v3[home] *w*w + + 2.0f*v2[home] *w + + + v1[home])/(d*dr)) * r; + return EnergyForce(energy,force); + } + // private: public: float* pot; diff --git a/useful.cu b/useful.cu index 090665f..9e1a118 100644 --- a/useful.cu +++ b/useful.cu @@ -355,6 +355,7 @@ Matrix3::Matrix3(float s) { ezx = 0.0f; ezy = 0.0f; ezz = s; + isDiag = true; } Matrix3::Matrix3(float xx, float xy, float xz, float yx, float yy, float yz, float zx, float zy, float zz) { @@ -367,6 +368,7 @@ Matrix3::Matrix3(float xx, float xy, float xz, float yx, float yy, float yz, flo ezx = zx; ezy = zy; ezz = zz; + setIsDiag(); } Matrix3::Matrix3(float x, float y, float z) { @@ -379,6 +381,7 @@ Matrix3::Matrix3(float x, float y, float z) { ezx = 0.0f; ezy = 0.0f; ezz = z; + isDiag = true; } Matrix3::Matrix3(const Vector3& ex, const Vector3& ey, const Vector3& ez) { @@ -391,7 +394,7 @@ Matrix3::Matrix3(const Vector3& ex, const Vector3& ey, const Vector3& ez) { exz = ez.x; eyz = ez.y; ezz = ez.z; - + setIsDiag(); } Matrix3::Matrix3(const float* d) { @@ -403,8 +406,9 @@ Matrix3::Matrix3(const float* d) { eyz = d[5]; ezx = d[6]; ezy = d[7]; - ezz = d[8]; -} + ezz = d[8]; + setIsDiag(); +} const Matrix3 Matrix3::operator*(float s) const { Matrix3 m; @@ -417,7 +421,7 @@ const Matrix3 Matrix3::operator*(float s) const { m.ezx = s*ezx; m.ezy = s*ezy; m.ezz = s*ezz; - + m.isDiag = isDiag; return m; } @@ -435,6 +439,7 @@ const Matrix3 Matrix3::operator*(const Matrix3& m) const { ret.exz = exx*m.exz + exy*m.eyz + exz*m.ezz; ret.eyz = eyx*m.exz + eyy*m.eyz + eyz*m.ezz; ret.ezz = ezx*m.exz + ezy*m.eyz + ezz*m.ezz; + ret.setIsDiag(); return ret; } @@ -449,29 +454,33 @@ const Matrix3 Matrix3::operator-() const { m.ezx = -ezx; m.ezy = -ezy; m.ezz = -ezz; - + m.isDiag = isDiag; return m; } Matrix3 Matrix3::inverse() const { Matrix3 m; - float det = exx*(eyy*ezz-eyz*ezy) - exy*(eyx*ezz-eyz*ezx) + exz*(eyx*ezy-eyy*ezx); - - m.exx = (eyy*ezz - eyz*ezy)/det; - m.exy = -(exy*ezz - exz*ezy)/det; - m.exz = (exy*eyz - exz*eyy)/det; - m.eyx = -(eyx*ezz - eyz*ezx)/det; - m.eyy = (exx*ezz - exz*ezx)/det; - m.eyz = -(exx*eyz - exz*eyx)/det; - m.ezx = (eyx*ezy - eyy*ezx)/det; - m.ezy = -(exx*ezy - exy*ezx)/det; - m.ezz = (exx*eyy - exy*eyx)/det; - + if (isDiag) { + m = Matrix3(1.0f/exx,1.0f/eyy,1.0f/ezz); + } else { + float det = exx*(eyy*ezz-eyz*ezy) - exy*(eyx*ezz-eyz*ezx) + exz*(eyx*ezy-eyy*ezx); + m.exx = (eyy*ezz - eyz*ezy)/det; + m.exy = -(exy*ezz - exz*ezy)/det; + m.exz = (exy*eyz - exz*eyy)/det; + m.eyx = -(eyx*ezz - eyz*ezx)/det; + m.eyy = (exx*ezz - exz*ezx)/det; + m.eyz = -(exx*eyz - exz*eyx)/det; + m.ezx = (eyx*ezy - eyy*ezx)/det; + m.ezy = -(exx*ezy - exy*ezx)/det; + m.ezz = (exx*eyy - exy*eyx)/det; + m.isDiag = isDiag; + } return m; } float Matrix3::det() const { - return exx*(eyy*ezz-eyz*ezy) - exy*(eyx*ezz-eyz*ezx) + exz*(eyx*ezy-eyy*ezx); + return isDiag ? exx*eyy*ezz : + exx*(eyy*ezz-eyz*ezy) - exy*(eyx*ezz-eyz*ezx) + exz*(eyx*ezy-eyy*ezx); } diff --git a/useful.h b/useful.h index 1e52796..75a5d78 100644 --- a/useful.h +++ b/useful.h @@ -284,6 +284,7 @@ public: m.ezx = exz; m.ezy = eyz; m.ezz = ezz; + m.isDiag = isDiag; return m; } @@ -294,9 +295,15 @@ public: HOST DEVICE inline Vector3 transform(const Vector3& v) const { Vector3 w; - w.x = exx*v.x + exy*v.y + exz*v.z; - w.y = eyx*v.x + eyy*v.y + eyz*v.z; - w.z = ezx*v.x + ezy*v.y + ezz*v.z; + if (isDiag) { + w.x = exx*v.x; + w.y = eyy*v.y; + w.z = ezz*v.z; + } else { + w.x = exx*v.x + exy*v.y + exz*v.z; + w.y = eyx*v.x + eyy*v.y + eyz*v.z; + w.z = ezx*v.x + ezy*v.y + ezz*v.z; + } return w; } @@ -314,9 +321,18 @@ public: ret.exz = exx*m.exz + exy*m.eyz + exz*m.ezz; ret.eyz = eyx*m.exz + eyy*m.eyz + eyz*m.ezz; ret.ezz = ezx*m.exz + ezy*m.eyz + ezz*m.ezz; + ret.setIsDiag(); return ret; } + HOST DEVICE void setIsDiag() { + isDiag = (exy == 0 && exz == 0 && + eyx == 0 && eyz == 0 && + ezx == 0 && ezy == 0) ? true : false; + } + + + Vector3 ex() const; Vector3 ey() const; Vector3 ez() const; @@ -328,11 +344,10 @@ public: float exx, exy, exz; float eyx, eyy, eyz; float ezx, ezy, ezz; + bool isDiag; }; - - Matrix3 operator*(float s, Matrix3 m); Matrix3 operator/(Matrix3 m, float s); -- GitLab