Skip to content
Snippets Groups Projects
Commit daffb6ea authored by Yasmin Sarita's avatar Yasmin Sarita
Browse files

faster depthwise conv

parent 3239b2b6
No related branches found
No related tags found
No related merge requests found
......@@ -124,6 +124,304 @@ __global__ void depthwise_conv(float* const __restrict__ y,
#undef x4d
}
// Depthwise 2D convolution; each thread accumulates one output pixel of one
// channel for up to 12 consecutive batch entries.
//
// Indexing used by this kernel:
//   blockIdx.x  -> group of 12 batch entries (first entry b = 12 * blockIdx.x)
//   blockIdx.y  -> channel m
//   threadIdx.x -> flattened output pixel, row-major over H_out x W_out
//
// Parameters:
//   y        output, indexed as [batch][M][H_out][W_out]
//   x        input,  indexed as [batch][M][H][W]
//   w        filters, one KH x KW filter per channel (w[m * KH * KW + k])
//   B        number of batch entries
//   M        number of channels
//   H, W     input height / width
//   KH, KW   filter height / width
//   H_out, W_out      output height / width
//   H_pad, W_pad      padding subtracted from the window origin
//   H_stride, W_stride  step between neighbouring output windows
//
// Batch entries at or beyond B, channels at or beyond M and threads beyond the
// output plane are skipped, so a partially filled last batch group or an
// oversized block never touches memory outside x / y.
__global__ void depthwise_conv12(float* const __restrict__ y,
const float* const __restrict__ x,
const float* const __restrict__ w,
const int B, const int M,
const int H, const int W, const int KH,
const int KW, const int H_out, const int W_out,
const int H_pad, const int W_pad,
const int H_stride, const int W_stride)
{
    const int num = 12;
    const int b = num * blockIdx.x;
    const int m = blockIdx.y; // current filter/channel
    const int tx = threadIdx.x;

    // Threads that do not map to a real output element have nothing to do.
    if (m >= M || tx >= H_out * W_out)
        return;

    // Top-left corner of this thread's input window (may lie in the padding).
    const int start_h = (tx / W_out) * H_stride - H_pad;
    const int start_w = (tx % W_out) * W_stride - W_pad;

    // 64-bit offsets so large tensors do not overflow 32-bit index arithmetic.
    const size_t in_plane = (size_t)H * (size_t)W;
    const size_t out_plane = (size_t)H_out * (size_t)W_out;
    const size_t in_batch_stride = (size_t)M * in_plane;
    const size_t out_batch_stride = (size_t)M * out_plane;
    const size_t in_base = (size_t)b * in_batch_stride + (size_t)m * in_plane;
    const size_t out_base = (size_t)b * out_batch_stride + (size_t)m * out_plane + (size_t)tx;

    float C[num] = { 0 };
    const float* weights = &w[(size_t)m * KH * KW];

    for (int k = 0; k < KH * KW; k++) {
        const int h_in = start_h + k / KW;
        const int w_in = start_w + k % KW;
        // Taps that fall into the padding contribute zero.
        if (h_in < 0 || h_in >= H || w_in < 0 || w_in >= W)
            continue;

        const float wk = weights[k];
        const size_t pixel = in_base + (size_t)h_in * (size_t)W + (size_t)w_in;
#pragma unroll
        for (int i = 0; i < num; i++) {
            if (b + i < B)
                C[i] += x[pixel + (size_t)i * in_batch_stride] * wk;
        }
    }

#pragma unroll
    for (int i = 0; i < num; i++) {
        if (b + i < B)
            y[out_base + (size_t)i * out_batch_stride] = C[i];
    }
}
// Depthwise 2D convolution; each thread accumulates one output pixel of one
// channel for up to 12 consecutive batch entries.
//
// Indexing used by this kernel:
//   blockIdx.x -> group of 12 batch entries (first entry b = 12 * blockIdx.x)
//   blockIdx.y * blockDim.x + threadIdx.x -> flattened (channel, output pixel)
//       index; channel = index / (H_out * W_out), pixel = index % (H_out * W_out)
//
// Parameters:
//   y        output, indexed as [batch][M][H_out][W_out]
//   x        input,  indexed as [batch][M][H][W]
//   w        filters, one KH x KW filter per channel (w[m * KH * KW + k])
//   B, M     number of batch entries / channels
//   H, W     input height / width
//   KH, KW   filter height / width
//   H_out, W_out        output height / width
//   H_pad, W_pad        padding subtracted from the window origin
//   H_stride, W_stride  step between neighbouring output windows
//
// Threads whose flattened index maps to a channel >= M return immediately, so
// the y-grid may be rounded up to a whole number of blocks.
__global__ void depthwise_convNew(float* const __restrict__ y,
const float* const __restrict__ x,
const float* const __restrict__ w,
const int B, const int M,
const int H, const int W, const int KH,
const int KW, const int H_out, const int W_out,
const int H_pad, const int W_pad,
const int H_stride, const int W_stride)
{
#define y4d(i3, i2, i1, i0) y[(i3) * (M * H_out * W_out) + (i2) * (H_out * W_out) + (i1) * (W_out) + (i0)]
#define x4d(i3, i2, i1, i0) x[(i3) * (M * H * W) + (i2) * (H * W) + (i1) * (W) + (i0)]
    const int num = 12;
    const int b = num * blockIdx.x;
    const int gid = blockIdx.y * blockDim.x + threadIdx.x;
    const int m = gid / (H_out * W_out); // current filter/channel

    // Tail threads of a rounded-up grid map past the last channel; without this
    // guard they would read w/x out of range and overwrite unrelated parts of y.
    if (m < M) {
        const int tx = gid % (H_out * W_out);
        // Top-left corner of this thread's input window (may lie in the padding).
        const int start_h = (tx / W_out) * H_stride - H_pad;
        const int start_w = (tx % W_out) * W_stride - W_pad;
        float C[num] = { 0 };
        const float* weights = &w[m * KH * KW];

        for (int k = 0; k < KH * KW; k++) {
            const int p = k / KW;
            const int q = k % KW;
            // Taps that fall into the padding contribute zero.
            if (start_h + p > -1 && start_h + p < H &&
                start_w + q > -1 && start_w + q < W) {
                const float wk = weights[k];
#pragma unroll
                for (int i = 0; i < num; i++) {
                    if (b + i < B)
                        C[i] += x4d(b + i, m, start_h + p, start_w + q) * wk;
                }
            }
        }

#pragma unroll
        for (int i = 0; i < num; i++) {
            if (b + i < B)
                y4d(b + i, m, 0, tx) = C[i];
        }
    }
#undef y4d
#undef x4d
}
// Depthwise 2D convolution, 8 batch entries per thread.
//
// blockIdx.x picks the batch group (first entry = 8 * blockIdx.x). The value
// blockIdx.y * blockDim.x + threadIdx.x is a flattened (channel, output pixel)
// index; threads whose channel is >= M do nothing.
//
//   y  output, indexed as [batch][M][H_out][W_out]
//   x  input,  indexed as [batch][M][H][W]
//   w  filters, one KH x KW filter per channel
//
// The first batch entry of the group is always processed; entries 1..7 are
// processed only while they are below B.
__global__ void depthwise_convNew8(float* const __restrict__ y,
const float* const __restrict__ x,
const float* const __restrict__ w,
const int B, const int M,
const int H, const int W, const int KH,
const int KW, const int H_out, const int W_out,
const int H_pad, const int W_pad,
const int H_stride, const int W_stride)
{
    const int lanes = 8;
    const int batch0 = lanes * blockIdx.x;
    const int chan = (blockIdx.y * blockDim.x + threadIdx.x) / (H_out * W_out);
    if (chan >= M)
        return;

    const int pix = (blockIdx.y * blockDim.x + threadIdx.x) % (H_out * W_out);
    // Origin of the input window for this output pixel (can be negative with padding).
    const int row0 = (pix / W_out) * H_stride - H_pad;
    const int col0 = (pix % W_out) * W_stride - W_pad;

    // One running sum per batch entry handled by this thread.
    float acc[lanes] = { 0 };
    const float* taps = &w[chan * KH * KW];

    for (int t = 0; t < KH * KW; t++) {
        const int r = row0 + t / KW;
        const int c = col0 + t % KW;
        // Skip filter taps that land outside the input image.
        if (r < 0 || r >= H || c < 0 || c >= W)
            continue;
#pragma unroll
        for (int j = 0; j < lanes; j++) {
            // Lane 0 is unconditional; the others are limited by the batch size.
            if (j == 0 || batch0 + j < B)
                acc[j] += x[(batch0 + j) * (M * H * W) + chan * (H * W) + r * W + c] * taps[t];
        }
    }

#pragma unroll
    for (int j = 0; j < lanes; j++) {
        if (j == 0 || batch0 + j < B)
            y[(batch0 + j) * (M * H_out * W_out) + chan * (H_out * W_out) + pix] = acc[j];
    }
}
// Depthwise 2D convolution, 12 batch entries per thread.
//
// blockIdx.x picks the batch group (first entry = 12 * blockIdx.x). The value
// blockIdx.y * blockDim.x + threadIdx.x is a flattened (channel, output pixel)
// index; threads whose channel is >= M do nothing.
//
//   y  output, indexed as [batch][M][H_out][W_out]
//   x  input,  indexed as [batch][M][H][W]
//   w  filters, one KH x KW filter per channel
//
// The first batch entry of the group is always processed; entries 1..11 are
// processed only while they are below B.
__global__ void depthwise_convNew12(float* const __restrict__ y,
const float* const __restrict__ x,
const float* const __restrict__ w,
const int B, const int M,
const int H, const int W, const int KH,
const int KW, const int H_out, const int W_out,
const int H_pad, const int W_pad,
const int H_stride, const int W_stride)
{
    const int lanes = 12;
    const int batch0 = lanes * blockIdx.x;
    const int chan = (blockIdx.y * blockDim.x + threadIdx.x) / (H_out * W_out);
    if (chan >= M)
        return;

    const int pix = (blockIdx.y * blockDim.x + threadIdx.x) % (H_out * W_out);
    // Origin of the input window for this output pixel (can be negative with padding).
    const int row0 = (pix / W_out) * H_stride - H_pad;
    const int col0 = (pix % W_out) * W_stride - W_pad;

    // One running sum per batch entry handled by this thread.
    float acc[lanes] = { 0 };
    const float* taps = &w[chan * KH * KW];

    for (int t = 0; t < KH * KW; t++) {
        const int r = row0 + t / KW;
        const int c = col0 + t % KW;
        // Skip filter taps that land outside the input image.
        if (r < 0 || r >= H || c < 0 || c >= W)
            continue;
#pragma unroll
        for (int j = 0; j < lanes; j++) {
            // Lane 0 is unconditional; the others are limited by the batch size.
            if (j == 0 || batch0 + j < B)
                acc[j] += x[(batch0 + j) * (M * H * W) + chan * (H * W) + r * W + c] * taps[t];
        }
    }

#pragma unroll
    for (int j = 0; j < lanes; j++) {
        if (j == 0 || batch0 + j < B)
            y[(batch0 + j) * (M * H_out * W_out) + chan * (H_out * W_out) + pix] = acc[j];
    }
}
void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
int vertical_pad, int horizontal_pad,
int vertical_stride, int horizontal_stride,
......@@ -176,22 +474,37 @@ void* tensorConvCutlass(void* input_ptr, void* filter_ptr,
KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride);
}*/
dim3 grid((n / 8), c);
/*
dim3 grid((n / 12), c);
dim3 block(h * w);
depthwise_conv8 <<<grid, block >>> ((float*)output->gpu_data,
depthwise_conv12 <<<grid, block >>> ((float*)output->gpu_data,
(float*)input->gpu_data, (float*)filter->gpu_data,
input->dims.dim_sizes[0], input->dims.dim_sizes[1], input->dims.dim_sizes[2], input->dims.dim_sizes[3],
KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride);
if(n % 8 > 0){
dim3 grid2((n % 8), c);
if(n % 12 > 0){
dim3 grid2((n % 12), c);
dim3 block(h * w);
depthwise_conv <<<grid, block >>> ((float*)output->gpu_data,
(float*)input->gpu_data, (float*)filter->gpu_data,
input->dims.dim_sizes[0], input->dims.dim_sizes[1], input->dims.dim_sizes[2], input->dims.dim_sizes[3],
KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride, 8 * (n/8));
}
KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride, 12 * (n/12));
}
*/
int blockSize;
if(h * w > 1023)
blockSize = 256;
else
blockSize = 128;
dim3 grid(((n + 7)/ 8), (c * h * w + blockSize - 1)/ blockSize);
dim3 block(blockSize);
depthwise_convNew8<<<grid, block>>> ((float*)output->gpu_data,
(float*)input->gpu_data, (float*)filter->gpu_data,
input->dims.dim_sizes[0], input->dims.dim_sizes[1], input->dims.dim_sizes[2], input->dims.dim_sizes[3],
KH, KW, h, w, vertical_pad, horizontal_pad, vertical_stride, horizontal_stride);
}
else {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment