// NOTE(review): the two lines originally here were web-page scrape residue,
// not C++: "Something went wrong on our end" / "wino_buffer.cpp 15.67 KiB".
#ifndef _WINO_BUFFER_HPP_
#define _WINO_BUFFER_HPP_
#include <ap_int.h>
#include <hls_stream.h>
#include "wino_macro.h"
#include "../testbench/debug.hpp"
//template<int dummy>
// input_feed
//
// Builds Winograd input tiles from the on-chip input buffer, applies the
// input transform B^T * d * B and streams the results out.
//
// Loop nest, outermost first:
//   1. depth tiles of 8 channels (input_depth_aligned_8 / 8 of them);
//   2. weight_outdepth_load_number repetitions of the same depth tile;
//   3. window_row_offset and window_col_offset, each in [0, window_size);
//   4. output columns, 2 * winograd_output_tile_size per step;
//   5. the 8 channels of the depth tile (pipelined DEPTH loop).
//
// input_buffer layout, as addressed below:
//   dim 0: input row mod 8, dim 1: input column mod 16,
//   dim 2: { 1 bit selecting the 8-row half of a 16-row window,
//            column/16 + depth_tile_idx * input_width_ceildiv_16 (6 bits),
//            channel within the depth tile (3 bits) }.
//   Every 16-bit element packs two signed 8-bit values (bits 7:0 and 15:8).
//
// Each step gathers a 6 x 12 window: rows
//   start_row_idx - pad_size + window_row_offset + 0..5,
// and two 6-column tiles whose first columns are
//   out_col_idx - pad_size + window_col_offset  and that plus
//   winograd_output_tile_size.
// Rows/columns outside [0, input_height) / [0, input_width) are fed as 0.
//
// The 6x6 transform uses the coefficients of the standard Winograd
// F(4x4, 3x3) input matrix B^T and is applied to 4 lanes: left tile low byte,
// left tile high byte, right tile low byte, right tile high byte.
// Output word (32*36 bits): element e = row*6 + col occupies bits
// [32e+31 : 32e] as (high-byte result : low-byte result), 16 bits each.
// input_tile_stream1 receives the left tile, input_tile_stream2 the right one.
// After all tiles, two all-zero words are written to each stream.
void input_feed(
    ap_uint<16> input_buffer[8][16][INPUT_BUFFER_DEPTH],
    hls::stream< ap_uint<32*36> > &input_tile_stream1,
    hls::stream< ap_uint<32*36> > &input_tile_stream2,
    ap_uint<16> start_row_idx,
    ap_uint<16> output_width,
    ap_uint<16> input_height,
    ap_uint<16> input_width,
    ap_uint<16> input_width_ceildiv_16,
    ap_uint<16> input_depth_aligned_8,
    ap_uint<16> pad_size,
    ap_uint<16> weight_outdepth_load_number,
    ap_uint<8> window_size,
    ap_uint<3> winograd_output_tile_size
    )
{
#pragma HLS array_partition variable=input_buffer dim=1 complete
#pragma HLS array_partition variable=input_buffer dim=2 complete
    // Two horizontally adjacent tiles are emitted per column step.
    int tile_step = winograd_output_tile_size*2;

    for(ap_uint<16> depth_tile_idx=0,depth_base_address=0;depth_tile_idx<input_depth_aligned_8/8;depth_tile_idx++,depth_base_address+=input_width_ceildiv_16)
    {
        for(int weight_load_outdepth_cnt=0;weight_load_outdepth_cnt<weight_outdepth_load_number;weight_load_outdepth_cnt++)
        {
            for(int window_row_offset=0; window_row_offset<window_size; window_row_offset++)
            {
                for(int window_col_offset=0; window_col_offset<window_size; window_col_offset++)
                {
                    // ---- Row setup: bank, half-select bit and legality of the 6 tile rows.
                    ap_uint<1> row_legal_flag[6];
#pragma HLS array_partition variable=row_legal_flag complete
                    ap_uint<3> row_bank_idx[6];
#pragma HLS array_partition variable=row_bank_idx complete
                    ap_uint<1> row_addres_offset[8];
#pragma HLS array_partition variable=row_addres_offset complete

                    // First tile row modulo 16. Banks below its low 3 bits hold rows
                    // from the other 8-row half, hence the inverted half-select bit.
                    ap_uint<4> start_row_mod16 = start_row_idx-pad_size+window_row_offset;
                    ap_uint<3> breakpoint = start_row_mod16.range(2,0);
                    for(int i=0;i<8;i++)
                    {
#pragma HLS unroll
                        if(i< breakpoint)
                            row_addres_offset[i] = ~start_row_mod16[3];
                        else
                            row_addres_offset[i] = start_row_mod16[3];
                    }
                    for(int i=0;i<6;i++)
                    {
#pragma HLS unroll
                        // Signed input row read by tile row i (negative inside the top
                        // padding). The legality test must use the same row as the bank
                        // index, window_row_offset included, just as the column test
                        // includes window_col_offset.
                        int input_row = start_row_idx.to_int() + i - pad_size.to_int() + window_row_offset;
                        row_bank_idx[i] = input_row;
                        row_legal_flag[i] = ( input_row >= 0 && input_row < input_height.to_int() );
                    }

                    // ---- Column setup: input column of each of the 12 window columns
                    // (left tile 0..5, right tile shifted by winograd_output_tile_size).
                    // Signed, so that columns inside the left padding test as negative
                    // instead of relying on unsigned wraparound.
                    ap_int<16> input_col_idx[12];
#pragma HLS array_partition variable=input_col_idx complete
                    // Left edge of the 16-column buffer window read each step.
                    ap_int<16> input_head_col_idx=-pad_size;
                    for(int i=0;i<6;i++)
                    {
#pragma HLS unroll
                        input_col_idx[i]=i-pad_size+window_col_offset;
                        input_col_idx[i+6]=i+winograd_output_tile_size-pad_size+window_col_offset;
                    }
                    // 16-bit counter: a narrower one can wrap before reaching output_width.
                    for(ap_uint<16> out_col_idx=0;out_col_idx < output_width; out_col_idx+=tile_step)
                    {
                        // Column-group address of each of the 16 column banks: banks below
                        // the head column's (mod 16) position belong to the next group.
                        ap_uint<6> input_buffer_address_by8[16];
#pragma HLS array_partition variable=input_buffer_address_by8 complete
                        ap_uint<INPUT_BUFFER_DEPTH_BITWIDTH-4> input_head_col_by16 = input_head_col_idx.range(10,4)+depth_base_address;
                        ap_uint<4> input_buffer_address_break_point = input_head_col_idx.range(3,0);
                        ap_uint<1> col_legal_flag[12];
#pragma HLS array_partition variable=col_legal_flag complete
                        for(int i=0;i<12; i++)
                        {
#pragma HLS unroll
                            int input_col = input_col_idx[i].to_int();
                            col_legal_flag[i]= ( input_col >=0 && input_col < input_width.to_int() );
                        }
                        for(int i=0;i<16;i++)
                        {
#pragma HLS unroll
                            if(i>=input_buffer_address_break_point)
                                input_buffer_address_by8[i] = input_head_col_by16;
                            else
                                input_buffer_address_by8[i] = input_head_col_by16+1;
                        }
                        DEPTH:for(ap_uint<4> depth_idx_in_tile=0;depth_idx_in_tile<8;depth_idx_in_tile++)
                        {
#pragma HLS pipeline
                            // ---- Gather: read all 8 x 16 banks for this channel.
                            ap_uint<INPUT_BUFFER_DEPTH_BITWIDTH> input_buffer_address[8][16];
#pragma HLS array_partition variable=input_buffer_address complete
                            ap_uint<16> input_buffer_val[8][16];
#pragma HLS array_partition variable=input_buffer_val complete
                            for(int j=0;j<8;j++)
                            {
#pragma HLS unroll
                                for(int i=0;i<16;i++)
                                {
#pragma HLS unroll
                                    input_buffer_address[j][i].range(INPUT_BUFFER_DEPTH_BITWIDTH-1,0) = (row_addres_offset[j],( input_buffer_address_by8[i], (ap_uint<3>) depth_idx_in_tile.range(2,0) ));
                                }
                            }
                            for(int i=0;i<8;i++)
                            {
#pragma HLS unroll
                                for(int j=0;j<16;j++)
                                {
#pragma HLS unroll
                                    input_buffer_val[i][j]=input_buffer[i][j][ (ap_uint<10>) input_buffer_address[i][j] ];
                                }
                            }

                            // ---- Map banks to the 6 tile rows, zeroing padding rows.
                            ap_uint<16> input_plane_tile_row[6][16];
#pragma HLS array_partition variable=input_plane_tile_row dim=1 complete
#pragma HLS array_partition variable=input_plane_tile_row dim=2 complete
                            for(int j=0;j<16;j++)
                            {
#pragma HLS unroll
                                for(int i=0;i<6;i++)
                                {
#pragma HLS unroll
                                    if(row_legal_flag[i])
                                    {
                                        input_plane_tile_row[i][j]=input_buffer_val[row_bank_idx[i]][j];
                                    }
                                    else
                                    {
                                        input_plane_tile_row[i][j]=0;
                                    }
                                }
                            }

                            // ---- Select the 12 window columns, zeroing padding columns.
                            ap_uint<16> input_plane_tile[6][12];
#pragma HLS array_partition variable=input_plane_tile complete
                            for(int i=0;i<6;i++)
                            {
#pragma HLS unroll
                                for(int j=0;j<12;j++)
                                {
#pragma HLS unroll
                                    if(col_legal_flag[j])
                                        input_plane_tile[i][j]=input_plane_tile_row[i][ (ap_uint<4>) input_col_idx[j].range(3,0) ];
                                    else
                                        input_plane_tile[i][j]=0;
                                }
                            }
#if DEBUG_FILE_PRINT
                            attach_streaming_content<0>(input_plane_tile,start_row_idx-pad_size,out_col_idx-pad_size,depth_tile_idx*8+depth_idx_in_tile,"input_stream_content.txt");
#endif

                            // ---- Split into 4 signed 8-bit lanes per 6x6 position:
                            // lane 0/1 = left tile low/high byte, lane 2/3 = right tile low/high byte.
                            ap_int<8> in[36][4];
                            for(int k=0;k<6;k++)
                            {
#pragma HLS unroll
                                for(int l=0;l<6;l++)
                                {
#pragma HLS unroll
                                    for(ap_uint<3> j=0;j<4;j++)
                                    {
#pragma HLS unroll
                                        in[k*6+l][j].range(7,0)= input_plane_tile[k][ j/2*6+l].range( j%2*8+7 ,j%2*8);
                                    }
                                }
                            }

                            // ---- dB = d * B: transform along each tile row (16-bit to avoid overflow).
                            ap_int<16> dB[6][6][4];
#pragma HLS array_partition variable=dB complete
                            for(int i=0;i<6;i++)
                            {
#pragma HLS unroll
                                for(int j=0;j<4;j++)
                                {
                                    dB[i][0][j]=(((ap_int<16>)in[i*6+0][j]-(ap_int<16>)in[i*6+2][j])<<2) - (ap_int<16>)in[i*6+2][j] + (ap_int<16>)in[i*6+4][j];
                                    dB[i][1][j]=(ap_int<16>)in[i*6+3][j]+(ap_int<16>)in[i*6+4][j]-(((ap_int<16>)in[i*6+1][j]+(ap_int<16>)in[i*6+2][j])<<2);
                                    dB[i][2][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+2][j])<<2)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+3][j];
                                    dB[i][3][j]=(((ap_int<16>)in[i*6+3][j]-(ap_int<16>)in[i*6+1][j])<<1)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+2][j];
                                    dB[i][4][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+3][j])<<1)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+2][j];
                                    dB[i][5][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+3][j])<<2)+(ap_int<16>)in[i*6+5][j]-(ap_int<16>)in[i*6+3][j];
                                }
                            }

                            // ---- BTDB = B^T * dB: same coefficients, applied down each column.
                            ap_int<16> BTDB[6][6][4];
#pragma HLS array_partition variable=BTDB complete
                            for(int i=0;i<6;i++)
                            {
#pragma HLS unroll
                                for(int j=0;j<4;j++)
                                {
                                    BTDB[0][i][j]=((dB[0][i][j]-dB[2][i][j])<<2) - dB[2][i][j] + dB[4][i][j];
                                    BTDB[1][i][j]=dB[3][i][j]+dB[4][i][j]-((dB[1][i][j]+dB[2][i][j])<<2);
                                    BTDB[2][i][j]=((dB[1][i][j]-dB[2][i][j])<<2)+dB[4][i][j]-dB[3][i][j];
                                    BTDB[3][i][j]=((dB[3][i][j]-dB[1][i][j])<<1)+dB[4][i][j]-dB[2][i][j];
                                    BTDB[4][i][j]=((dB[1][i][j]-dB[3][i][j])<<1)+dB[4][i][j]-dB[2][i][j];
                                    BTDB[5][i][j]=((dB[1][i][j]-dB[3][i][j])<<2)+dB[5][i][j]-dB[3][i][j];
                                }
                            }

                            // ---- Pack: (lane1:lane0) -> stream 1, (lane3:lane2) -> stream 2.
                            ap_uint<32*36> stream_out1;
                            ap_uint<32*36> stream_out2;
                            for(int i=0;i<6;i++)
                            {
#pragma HLS unroll
                                for(int j=0;j<6;j++)
                                {
#pragma HLS unroll
                                    stream_out1.range( (i*6+j)*32+31,(i*6+j)*32) =(BTDB[i][j][1],BTDB[i][j][0]);
                                    stream_out2.range( (i*6+j)*32+31,(i*6+j)*32) =(BTDB[i][j][3],BTDB[i][j][2]);
                                }
                            }
#if DEBUG_FILE_PRINT
                            attach_streaming_wino<0>(stream_out1,stream_out2,start_row_idx-pad_size,input_head_col_idx,depth_tile_idx*8+depth_idx_in_tile,"input_stream_wino.txt");
#endif
                            input_tile_stream1<<stream_out1;
                            input_tile_stream2<<stream_out2;
                        } // DEPTH

                        // Advance the 12-column window to the next pair of tiles.
                        for(int i=0;i<12; i++)
                        {
#pragma HLS unroll
                            input_col_idx[i]+=tile_step;
                        }
                        input_head_col_idx+=tile_step;
                    } // out_col_idx
                } // window_col_offset
            } // window_row_offset
        } // weight_load_outdepth_cnt
    } // depth_tile_idx

    // Two all-zero words per stream once every tile has been sent.
    // NOTE(review): their role for the consumer is not visible in this file.
    ap_uint<32*36> stream_out1=0;
    ap_uint<32*36> stream_out2=0;
    input_tile_stream1<<stream_out1;
    input_tile_stream1<<stream_out1;
    input_tile_stream2<<stream_out2;
    input_tile_stream2<<stream_out2;
}
//template<int dummy>
// load_weight_ddr_one_port
// Copies load_number consecutive 128-bit words from DDR_interface into the
// nine 64-bit banks of weight_buff. Words are consumed in groups of five:
//   word 0 -> banks 0 (bits 63:0) and 1 (bits 127:64)
//   word 1 -> banks 2, 3      word 2 -> banks 4, 5
//   word 3 -> banks 6, 7      word 4 -> bank 8 (its bits 127:64 are dropped)
// all at address {pingpong, buffer_address_offset}; the offset advances after
// each complete group. After every port_load_number words, writing moves on to
// the next buffer_idx and the offset restarts at 0.
// NOTE(review): the group counter is not reset at a buffer_idx boundary, so
// port_load_number is presumably a multiple of 5 - confirm with the caller.
void load_weight_ddr_one_port(
    ap_uint<128>* DDR_interface,
    ap_uint<64> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
    int load_number,
    int port_load_number,
    ap_uint<1> pingpong)
{
    ap_uint<4> counter=0;               // position of the current word in its 5-word group
    // Words already written to the current buffer_idx. 16 bits wide: a 9-bit
    // counter could never reach port_load_number-1 once port_load_number > 512,
    // although the 9-bit address offset allows up to 512 five-word groups.
    ap_uint<16> port_load_cnt=0;
    ap_uint<9> buffer_address_offset=0; // address within one pingpong half
    // Must start at 0: it indexes weight_buff on the very first write.
    ap_uint<2> buffer_idx=0;
    for(int address = 0; address<load_number; address++)
    {
        ap_uint<128> temp128 = DDR_interface[address];
        ap_uint<64> temp64[2];
#pragma HLS array_partition variable = temp64 complete
        for(int i=0;i<2;i++)
        {
#pragma HLS unroll
            temp64[i]=temp128.range(i*64+63,i*64);
        }
        ap_uint<10> buffer_address = (pingpong,buffer_address_offset);
        if(counter==0)
        {
            weight_buff[buffer_idx][0][buffer_address]=temp64[0];
            weight_buff[buffer_idx][1][buffer_address]=temp64[1];
        }
        else if(counter == 1)
        {
            weight_buff[buffer_idx][2][buffer_address]=temp64[0];
            weight_buff[buffer_idx][3][buffer_address]=temp64[1];
        }
        else if(counter == 2)
        {
            weight_buff[buffer_idx][4][buffer_address]=temp64[0];
            weight_buff[buffer_idx][5][buffer_address]=temp64[1];
        }
        else if(counter == 3)
        {
            weight_buff[buffer_idx][6][buffer_address]=temp64[0];
            weight_buff[buffer_idx][7][buffer_address]=temp64[1];
        }
        else if(counter == 4)
        {
            weight_buff[buffer_idx][8][buffer_address]=temp64[0];
        }

        // Address offset: restart at each buffer boundary, else advance per group.
        if(port_load_cnt==port_load_number-1)
        {
            buffer_address_offset =0;
        }
        else if(counter==4 )
        {
            buffer_address_offset++;
        }

        // Buffer selection: switch to the next buffer after port_load_number words.
        if(port_load_cnt==port_load_number-1)
        {
            buffer_idx++;
            port_load_cnt=0;
        }
        else
        {
            port_load_cnt++;
        }

        // 5-word group position.
        if(counter==4)
        {
            counter=0;
        }
        else
        {
            counter++;
        }
    }
}
//template<int dummy>
// load_weight_ddr
// Fans one weight load out to the four DDR ports: port p reads load_number
// 128-bit words starting at word DDR_offset of weight_DDRp and unpacks them
// into weight_buff[p] (pingpong half selected by `pingpong`). The ports are
// handled one after another, port 0 first. Nothing happens when skip_flag is set.
void load_weight_ddr(
    ap_uint<128>* weight_DDR0,
    ap_uint<128>* weight_DDR1,
    ap_uint<128>* weight_DDR2,
    ap_uint<128>* weight_DDR3,
    ap_uint<64> weight_buff[4][WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
    int DDR_offset,
    int load_number,
    int port_load_number,
    ap_uint<1> skip_flag,
    ap_uint<1> pingpong)
{
    if(skip_flag == 0)
    {
        load_weight_ddr_one_port(weight_DDR0 + DDR_offset, weight_buff[0], load_number, port_load_number, pingpong);
        load_weight_ddr_one_port(weight_DDR1 + DDR_offset, weight_buff[1], load_number, port_load_number, pingpong);
        load_weight_ddr_one_port(weight_DDR2 + DDR_offset, weight_buff[2], load_number, port_load_number, pingpong);
        load_weight_ddr_one_port(weight_DDR3 + DDR_offset, weight_buff[3], load_number, port_load_number, pingpong);
    }
}
//template<int dummy>
// weight_stream
// Replays one pingpong half of a port's weight buffer row_repeat_time times.
// For each address offset 0 .. weight_feed_total_size/2 - 1 it reads the nine
// 64-bit banks of every buffer_idx at address {pingpong, offset} and
// concatenates them (bank 0 in bits 63:0 ... bank 8 in bits 575:512). When
// WEIGHT_FEED_NUMBER_PER_PORT == 2, buffer 0 goes to weight_stream0 and
// buffer 1 to weight_stream1.
// pingpong is a 1-bit signed type; only its bit pattern is used (as the top
// address bit), so a value of 1 passed by a caller lands as -1 with bit set.
void weight_stream(
    ap_uint<64> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
    hls::stream<ap_uint<16*36> > & weight_stream0,
    hls::stream<ap_uint<16*36> > & weight_stream1,
#endif
    int row_repeat_time, // output_row ceil_div out tiles
    int weight_feed_total_size,
    ap_int<1> pingpong
    )
{
    int weight_feed_total_size_by2 = weight_feed_total_size/2;
    for(int i=0;i<row_repeat_time;i++)
    {
        // int counter: a 9-bit counter can never reach a bound above 511, so the
        // loop would not terminate. The address uses its low 9 bits, matching the
        // 9-bit offset used when the buffer is filled.
        for(int buffer_addr_offset=0; buffer_addr_offset<weight_feed_total_size_by2; buffer_addr_offset++)
        {
            ap_uint<64> temp18[WEIGHT_FEED_NUMBER_PER_PORT][9];
#pragma HLS array_partition variable = temp18 complete
            ap_uint<9> buffer_addr_low = buffer_addr_offset;
            ap_uint<10> buffer_addr=(pingpong,buffer_addr_low);
            for(int buffer_idx =0; buffer_idx< WEIGHT_FEED_NUMBER_PER_PORT; buffer_idx++)
            {
#pragma HLS unroll
                for(int j18=0;j18<9;j18++)
                {
#pragma HLS unroll
                    temp18[buffer_idx][j18]=weight_buff[buffer_idx][j18][buffer_addr];
                }
            }
            ap_uint<16*36> temp16x36[WEIGHT_FEED_NUMBER_PER_PORT];
#pragma HLS array_partition variable = temp16x36 complete
            for(int buffer_idx =0; buffer_idx< WEIGHT_FEED_NUMBER_PER_PORT; buffer_idx++)
            {
#pragma HLS unroll
                for(int j18=0;j18<9;j18++)
                {
#pragma HLS unroll
                    temp16x36[buffer_idx].range(j18*64+63,j18*64)=temp18[buffer_idx][j18];
                }
            }
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
            weight_stream0<<temp16x36[0];
            weight_stream1<<temp16x36[1];
#endif
        }
    }
}
//template<int dummy>
// weight_feed
// Double-buffered weight loader/streamer for the four DDR ports.
// On a call with first_flag set, the persistent state is reset and the first
// batch (DDR word offset 0) is loaded into pingpong half 0. Each of the
// weight_total_load_number iterations then:
//   1. flips the pingpong half,
//   2. prefetches the next batch (at DDR_offset) into that half, and
//   3. replays the other half (the batch loaded before) into the four stream
//      pairs, row_repeat_times times each.
// After the last batch, the offset wraps to 0, so the final iteration
// prefetches batch 0 for the next call. That prefetch is skipped when
// last_flag is set.
// DDR_offset, DDR_load_cnt, pingpong and weight_buff are static and carry
// over between calls.
// NOTE(review): with weight_total_load_number == 1, DDR_load_cnt starts at 1
// and never wraps; confirm callers always use >= 2.
void weight_feed(
    ap_uint<128>* weight_DDR0,
    ap_uint<128>* weight_DDR1,
    ap_uint<128>* weight_DDR2,
    ap_uint<128>* weight_DDR3,
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
    hls::stream<ap_uint<16*36> > & weight_stream0_0,
    hls::stream<ap_uint<16*36> > & weight_stream0_1,
    hls::stream<ap_uint<16*36> > & weight_stream1_0,
    hls::stream<ap_uint<16*36> > & weight_stream1_1,
    hls::stream<ap_uint<16*36> > & weight_stream2_0,
    hls::stream<ap_uint<16*36> > & weight_stream2_1,
    hls::stream<ap_uint<16*36> > & weight_stream3_0,
    hls::stream<ap_uint<16*36> > & weight_stream3_1,
#endif
    ap_uint<16> weight_total_load_number,
    ap_uint<16> weight_total_feed_size,
    ap_uint<16> ddr_load_length,
    ap_uint<16> ddr_load_length_per_feed,
    ap_uint<16> row_repeat_times,
    ap_uint<16> first_flag,
    ap_uint<16> last_flag
    )
{
    // 32 bits: the offset reaches (weight_total_load_number-1) * ddr_load_length,
    // a product of two 16-bit values.
    static ap_uint<32> DDR_offset;    // DDR word offset of the next batch to prefetch
    static ap_uint<16> DDR_load_cnt;  // index of the next batch to prefetch
    static ap_uint<1> pingpong;       // buffer half written by the latest load
    static ap_uint<64> weight_buff[4][WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH];

    // All three must be reset together. Resetting pingpong/DDR_load_cnt on
    // every call would discard the static state and, for an odd number of
    // loads, make the next call stream the half that does not hold the batch
    // prefetched at the end of this one.
    if(first_flag)
    {
        DDR_offset=ddr_load_length;
        DDR_load_cnt=1;
        pingpong = 0;
    }

    // Initial load of batch 0 into half 0, only on the first call. The flag
    // is tested as a boolean, the same way as the `if` above.
    ap_uint<1> skip_initial_load = (first_flag == 0);
    load_weight_ddr(
        weight_DDR0,
        weight_DDR1,
        weight_DDR2,
        weight_DDR3,
        weight_buff,
        0,
        ddr_load_length,
        ddr_load_length_per_feed,
        skip_initial_load,
        pingpong);

    for(int cnt=0;cnt<weight_total_load_number;cnt++)
    {
        pingpong = ~pingpong;

        // Prefetch the next batch; on the last call, do not fetch the wrapped batch 0.
        ap_uint<1> skip_prefetch = (last_flag != 0) && (DDR_load_cnt == 0);
        load_weight_ddr(
            weight_DDR0,
            weight_DDR1,
            weight_DDR2,
            weight_DDR3,
            weight_buff,
            DDR_offset,
            ddr_load_length,
            ddr_load_length_per_feed,
            skip_prefetch,
            pingpong);

        // Stream the half filled by the previous load.
        weight_stream(
            weight_buff[0],
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
            weight_stream0_0,
            weight_stream0_1,
#endif
            row_repeat_times, // output_row ceil_div out tiles
            weight_total_feed_size,
            ~pingpong);
        weight_stream(
            weight_buff[1],
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
            weight_stream1_0,
            weight_stream1_1,
#endif
            row_repeat_times, // output_row ceil_div out tiles
            weight_total_feed_size,
            ~pingpong);
        weight_stream(
            weight_buff[2],
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
            weight_stream2_0,
            weight_stream2_1,
#endif
            row_repeat_times, // output_row ceil_div out tiles
            weight_total_feed_size,
            ~pingpong);
        weight_stream(
            weight_buff[3],
#if WEIGHT_FEED_NUMBER_PER_PORT == 2
            weight_stream3_0,
            weight_stream3_1,
#endif
            row_repeat_times, // output_row ceil_div out tiles
            weight_total_feed_size,
            ~pingpong);

        // Advance to the next batch, wrapping to batch 0 after the last one.
        if(DDR_load_cnt == weight_total_load_number-1)
        {
            DDR_load_cnt = 0;
            DDR_offset = 0;
        }
        else
        {
            DDR_load_cnt+=1;
            DDR_offset+=ddr_load_length;
        }
    }
}
#endif