// wino_buffer.cpp


#ifndef _WINO_BUFFER_HPP_
#define _WINO_BUFFER_HPP_

#include <ap_int.h>
#include <hls_stream.h>
#include "wino_macro.h"

#include "../testbench/debug.hpp"

//template<int dummy>
/*
 * input_feed
 * ----------
 * Streams Winograd-transformed input tiles out of the on-chip input buffer.
 *
 * Loop nest (outermost first):
 *   depth_tile_idx            : input_depth_aligned_8/8 iterations; the buffer
 *                               base address advances by input_width_ceildiv_16
 *                               per depth tile.
 *   weight_load_outdepth_cnt  : the same depth tile is re-streamed
 *                               weight_outdepth_load_number times.
 *   window_row/col_offset     : window_size x window_size shifts of the tile.
 *   out_col_idx               : steps of tile_step = 2*winograd_output_tile_size
 *                               across output_width.
 *   depth_idx_in_tile         : 8 buffer addresses per depth tile (pipelined).
 *
 * For each innermost iteration the function:
 *   1. reads one 16-bit word from each of the 8 x 16 banks of input_buffer,
 *   2. picks 6 rows (row_bank_idx) and 12 columns (input_col_idx). These form
 *      two 6x6 tiles whose first columns are winograd_output_tile_size apart.
 *      Rows outside [0, input_height) and columns outside [0, input_width)
 *      are replaced by 0,
 *   3. splits each 16-bit word into its low and high byte (two signed 8-bit
 *      lanes per tile, 4 lanes in total),
 *   4. applies B^T * d * B in 16-bit arithmetic to every lane. The
 *      coefficients are those of the Winograd F(4x4,3x3) input transform,
 *   5. packs the first tile's two lanes into input_tile_stream1 and the second
 *      tile's two lanes into input_tile_stream2. Each is one ap_uint<32*36>
 *      word: 36 tile positions x {lane1, lane0} 16-bit values.
 * After all tiles, two all-zero words are written to each stream.
 * NOTE(review): presumably padding/terminator words expected by the
 * consumer -- confirm against the downstream kernel.
 *
 * Parameters:
 *   input_buffer                 8 row banks x 16 column banks of 16-bit words.
 *   input_tile_stream1/2         output streams (see step 5).
 *   start_row_idx                first tile row index before pad_size is subtracted.
 *   output_width                 output columns to cover (loop bound of out_col_idx).
 *   input_height, input_width    bounds used to zero out-of-range rows/columns.
 *   input_width_ceildiv_16       buffer address stride between depth tiles.
 *   input_depth_aligned_8        outer loop runs input_depth_aligned_8/8 times.
 *   pad_size                     subtracted from every row/column index.
 *   weight_outdepth_load_number  repeat count for each depth tile.
 *   window_size                  extent of the row/column window offset loops.
 *   winograd_output_tile_size    column distance between the two tiles.
 *
 * The row-validity test includes window_row_offset so that it checks the row
 * actually fetched (as row_bank_idx and the column test already did). Column
 * indices are signed, so left-padding columns fail the ">= 0" test directly
 * instead of relying on unsigned wrap-around.
 */
void input_feed(
	ap_uint<16> input_buffer[8][16][INPUT_BUFFER_DEPTH],
	hls::stream< ap_uint<32*36> > &input_tile_stream1,
	hls::stream< ap_uint<32*36> > &input_tile_stream2,
	ap_uint<16> start_row_idx,
	ap_uint<16> output_width,
	ap_uint<16> input_height,
	ap_uint<16> input_width,
	ap_uint<16> input_width_ceildiv_16,
	ap_uint<16> input_depth_aligned_8,
	ap_uint<16> pad_size,
	ap_uint<16> weight_outdepth_load_number,
	ap_uint<8> window_size,
	ap_uint<3> winograd_output_tile_size
)
{
	#pragma HLS array_partition variable=input_buffer dim=1 complete
	#pragma HLS array_partition variable=input_buffer dim=2 complete

	// Two horizontally adjacent tiles are produced per column step.
	int tile_step =  winograd_output_tile_size*2;

	for(ap_uint<16> depth_tile_idx=0,depth_base_address=0;depth_tile_idx<input_depth_aligned_8/8;depth_tile_idx++,depth_base_address+=input_width_ceildiv_16)
	{
		for(int weight_load_outdepth_cnt=0;weight_load_outdepth_cnt<weight_outdepth_load_number;weight_load_outdepth_cnt++)
		{
			for(int window_row_offset=0; window_row_offset<window_size; window_row_offset++)
			{
				for(int window_col_offset=0; window_col_offset<window_size; window_col_offset++)
				{
					ap_uint<1> row_legal_flag[6];
					#pragma HLS array_partition variable=row_legal_flag complete

					ap_uint<3> row_bank_idx[6];
					#pragma HLS array_partition variable=row_bank_idx complete

					ap_uint<1> row_addres_offset[8];
					#pragma HLS array_partition variable=row_addres_offset complete

					// First fetched row modulo 16. Bits [2:0] are the bank holding
					// it; bit 3 is the MSB of that bank's address.
					ap_uint<4> start_row_mod16 = start_row_idx-pad_size+window_row_offset;
					ap_uint<3> breakpoint = start_row_mod16.range(2,0);

					// Banks numbered below the breakpoint hold rows from the next
					// group of 8, so their address MSB is flipped.
					for(int i=0;i<8;i++)
					{
						#pragma HLS unroll
						if(i< breakpoint)
							row_addres_offset[i] = ~start_row_mod16[3];
						else
							row_addres_offset[i] = start_row_mod16[3];
					}

					// Bank and validity of each of the 6 tile rows. Both use the
					// row that is actually fetched, including window_row_offset.
					for(int i=0;i<6;i++)
					{
						#pragma HLS unroll
						row_bank_idx[i] = start_row_idx+i-pad_size+window_row_offset;
						row_legal_flag[i] = ( start_row_idx+i-pad_size+window_row_offset >=0 && start_row_idx+i-pad_size+window_row_offset < input_height);
					}

					// Column index of each tile column: [0,6) is tile 0, [6,12) is
					// tile 1, shifted right by winograd_output_tile_size. Signed so
					// left-padding columns are negative.
					ap_int<12> input_col_idx[12];
					#pragma HLS array_partition variable=input_col_idx complete
					// Leftmost column of the 16-column span read from the banks.
					ap_int<16> input_head_col_idx=-pad_size;

					for(int i=0;i<6;i++)
					{
						#pragma HLS unroll
						input_col_idx[i]=i-pad_size+window_col_offset;
						input_col_idx[i+6]=i+winograd_output_tile_size-pad_size+window_col_offset;
					}

					for(ap_uint<10> out_col_idx=0;out_col_idx < output_width; out_col_idx+=tile_step)
					{
						ap_uint<6> input_buffer_address_by8[16];
						#pragma HLS array_partition variable=input_buffer_address_by8 complete
						ap_uint<INPUT_BUFFER_DEPTH_BITWIDTH-4>  input_head_col_by16 = input_head_col_idx.range(10,4)+depth_base_address;
						ap_uint<4>  input_buffer_address_break_point = input_head_col_idx.range(3,0);

						ap_uint<1> col_legal_flag[12];
						#pragma HLS array_partition variable=col_legal_flag complete
						for(int i=0;i<12; i++)
						{
							#pragma HLS unroll
							col_legal_flag[i]= ( input_col_idx[i] >=0 && input_col_idx[i] < input_width);
						}

						// Column banks before the break point already belong to the
						// next group of 16 columns.
						for(int i=0;i<16;i++)
						{
							#pragma HLS unroll
							if(i>=input_buffer_address_break_point)
								input_buffer_address_by8[i] = input_head_col_by16;
							else
								input_buffer_address_by8[i] = input_head_col_by16+1;
						}

						DEPTH:for(ap_uint<4> depth_idx_in_tile=0;depth_idx_in_tile<8;depth_idx_in_tile++)
						{
							#pragma HLS pipeline
							// Bank address = {row MSB (1b), column group (6b), depth_idx_in_tile (3b)}.
							ap_uint<INPUT_BUFFER_DEPTH_BITWIDTH> input_buffer_address[8][16];
							#pragma HLS array_partition variable=input_buffer_address complete

							ap_uint<16> input_buffer_val[8][16];
							#pragma HLS array_partition variable=input_buffer_val complete
							for(int j=0;j<8;j++)
							{
								#pragma HLS unroll
								for(int i=0;i<16;i++)
								{
									#pragma HLS unroll
									input_buffer_address[j][i].range(INPUT_BUFFER_DEPTH_BITWIDTH-1,0) = (row_addres_offset[j],( input_buffer_address_by8[i], (ap_uint<3>) depth_idx_in_tile.range(2,0) ));
								}
							}

							// One read from every bank in parallel.
							for(int i=0;i<8;i++)
							{
								#pragma HLS unroll
								for(int j=0;j<16;j++)
								{
									#pragma HLS unroll
									input_buffer_val[i][j]=input_buffer[i][j][ (ap_uint<10>) input_buffer_address[i][j] ];
								}
							}

							// Reorder row banks into tile-row order, zeroing invalid rows.
							ap_uint<16> input_plane_tile_row[6][16];
							#pragma HLS array_partition variable=input_plane_tile_row dim=1 complete
							#pragma HLS array_partition variable=input_plane_tile_row dim=2 complete
							for(int j=0;j<16;j++)
							{
								#pragma HLS unroll
								for(int i=0;i<6;i++)
								{
									#pragma HLS unroll
									if(row_legal_flag[i])
									{
										input_plane_tile_row[i][j]=input_buffer_val[row_bank_idx[i]][j];
									}
									else
									{
										input_plane_tile_row[i][j]=0;
									}
								}
							}

							// Pick the 12 tile columns by bank (low 4 bits), zeroing invalid columns.
							ap_uint<16> input_plane_tile[6][12];
							#pragma HLS array_partition variable=input_plane_tile complete
							for(int i=0;i<6;i++)
							{
								#pragma HLS unroll
								for(int j=0;j<12;j++)
								{
									#pragma HLS unroll
									if(col_legal_flag[j])
										input_plane_tile[i][j]=input_plane_tile_row[i][  (ap_uint<4>) input_col_idx[j].range(3,0) ];
									else
										input_plane_tile[i][j]=0;
								}
							}

							#if DEBUG_FILE_PRINT
							attach_streaming_content<0>(input_plane_tile,start_row_idx-pad_size,out_col_idx-pad_size,depth_tile_idx*8+depth_idx_in_tile,"input_stream_content.txt");
							#endif

							// in[k*6+l][lane]: lanes 0/1 are the low/high byte of tile 0,
							// lanes 2/3 the low/high byte of tile 1.
							ap_int<8> in[36][4];
							for(int k=0;k<6;k++)
							{
								#pragma HLS unroll
								for(int l=0;l<6;l++)
								{
									#pragma HLS unroll
									for(ap_uint<3> j=0;j<4;j++)
									{
										#pragma HLS unroll
										in[k*6+l][j].range(7,0)=  input_plane_tile[k][ j/2*6+l].range( j%2*8+7 ,j%2*8);
									}
								}
							}

							// First pass: d * B along each tile row.
							ap_int<16>  dB[6][6][4];
							#pragma HLS array_partition variable=dB complete
							for(int i=0;i<6;i++)
							{
								#pragma HLS unroll
								for(int j=0;j<4;j++)
								{
									dB[i][0][j]=(((ap_int<16>)in[i*6+0][j]-(ap_int<16>)in[i*6+2][j])<<2) - (ap_int<16>)in[i*6+2][j] + (ap_int<16>)in[i*6+4][j];
									dB[i][1][j]=(ap_int<16>)in[i*6+3][j]+(ap_int<16>)in[i*6+4][j]-(((ap_int<16>)in[i*6+1][j]+(ap_int<16>)in[i*6+2][j])<<2);
									dB[i][2][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+2][j])<<2)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+3][j];
									dB[i][3][j]=(((ap_int<16>)in[i*6+3][j]-(ap_int<16>)in[i*6+1][j])<<1)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+2][j];
									dB[i][4][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+3][j])<<1)+(ap_int<16>)in[i*6+4][j]-(ap_int<16>)in[i*6+2][j];
									dB[i][5][j]=(((ap_int<16>)in[i*6+1][j]-(ap_int<16>)in[i*6+3][j])<<2)+(ap_int<16>)in[i*6+5][j]-(ap_int<16>)in[i*6+3][j];
								}
							}

							// Second pass: B^T * (d * B) along each tile column.
							ap_int<16>  BTDB[6][6][4];
							#pragma HLS array_partition variable=BTDB complete
							for(int i=0;i<6;i++)
							{
								#pragma HLS unroll
								for(int j=0;j<4;j++)
								{
									BTDB[0][i][j]=((dB[0][i][j]-dB[2][i][j])<<2) - dB[2][i][j] + dB[4][i][j];
									BTDB[1][i][j]=dB[3][i][j]+dB[4][i][j]-((dB[1][i][j]+dB[2][i][j])<<2);
									BTDB[2][i][j]=((dB[1][i][j]-dB[2][i][j])<<2)+dB[4][i][j]-dB[3][i][j];
									BTDB[3][i][j]=((dB[3][i][j]-dB[1][i][j])<<1)+dB[4][i][j]-dB[2][i][j];
									BTDB[4][i][j]=((dB[1][i][j]-dB[3][i][j])<<1)+dB[4][i][j]-dB[2][i][j];
									BTDB[5][i][j]=((dB[1][i][j]-dB[3][i][j])<<2)+dB[5][i][j]-dB[3][i][j];
								}
							}

							// Tile 0 lanes -> stream1, tile 1 lanes -> stream2.
							ap_uint<32*36> stream_out1;
							ap_uint<32*36> stream_out2;
							for(int i=0;i<6;i++)
							{
								#pragma HLS unroll
								for(int j=0;j<6;j++)
								{
									#pragma HLS unroll
									stream_out1.range( (i*6+j)*32+31,(i*6+j)*32) =(BTDB[i][j][1],BTDB[i][j][0]);
									stream_out2.range( (i*6+j)*32+31,(i*6+j)*32) =(BTDB[i][j][3],BTDB[i][j][2]);
								}
							}
							#if DEBUG_FILE_PRINT
							attach_streaming_wino<0>(stream_out1,stream_out2,start_row_idx-pad_size,input_head_col_idx,depth_tile_idx*8+depth_idx_in_tile,"input_stream_wino.txt");
							#endif
							input_tile_stream1<<stream_out1;
							input_tile_stream2<<stream_out2;
						}// DEPTH

						// Advance both tiles to the next column step.
						for(int i=0;i<12; i++)
						{
							#pragma HLS unroll
							input_col_idx[i]+=tile_step;
						}
						input_head_col_idx+=tile_step;
					}// out_col_idx
				}// window_col_offset
			}// window_row_offset
		}// weight_load_outdepth_cnt
	}// depth_tile_idx

	// Trailing all-zero words: two per stream.
	ap_uint<32*36> stream_out1=0;
	ap_uint<32*36> stream_out2=0;
	input_tile_stream1<<stream_out1;
	input_tile_stream1<<stream_out1;
	input_tile_stream2<<stream_out2;
	input_tile_stream2<<stream_out2;
}



//template<int dummy>
/*
 * load_weight_ddr_one_port
 * ------------------------
 * Copies load_number 128-bit words from DDR_interface into weight_buff.
 *
 * Each DDR word is split into two 64-bit halves. Five consecutive words fill
 * the 9 slots weight_buff[buffer_idx][0..8][buffer_address]:
 *   word 0 -> slots 0,1   word 1 -> slots 2,3   word 2 -> slots 4,5
 *   word 3 -> slots 6,7   word 4 -> slot 8 (its upper 64 bits are unused).
 * buffer_address = {pingpong, buffer_address_offset}. The offset advances after
 * every 5th word and resets to 0 after port_load_number words. At that point
 * buffer_idx moves to the next feed buffer.
 *
 * NOTE(review): counter is not reset at a buffer_idx boundary, so
 * port_load_number is presumably a multiple of 5. buffer_idx is 2 bits wide,
 * so load_number is presumably <= WEIGHT_FEED_NUMBER_PER_PORT * port_load_number.
 * Confirm both with the caller.
 *
 * Parameters:
 *   DDR_interface     source words (already offset by the caller).
 *   weight_buff       destination feed buffers for this port.
 *   load_number       number of 128-bit words to read.
 *   port_load_number  words per feed buffer before switching buffer_idx.
 *   pingpong          selects the upper/lower half of each buffer (address MSB).
 */
void load_weight_ddr_one_port(
	ap_uint<128>* DDR_interface,
	ap_uint<64> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
	int load_number,
	int port_load_number,
	ap_uint<1> pingpong)
{
	ap_uint<4> counter=0;               // position (0..4) inside the current 5-word group
	ap_uint<9> port_load_cnt=0;         // words already written to the current buffer_idx
	ap_uint<9> buffer_address_offset=0; // low 9 bits of the buffer address
	// Must start at 0: ap_uint default construction leaves the value
	// indeterminate, and this is used as the first array index.
	ap_uint<2> buffer_idx=0;

	for(int address = 0; address<load_number; address++)
	{
		ap_uint<128> temp128 = DDR_interface[address];
		ap_uint<64> temp64[2];
		#pragma HLS array_partition  variable = temp64 complete

		for(int i=0;i<2;i++)
		{
			#pragma HLS unroll
			temp64[i]=temp128.range(i*64+63,i*64);
		}

		ap_uint<10> buffer_address = (pingpong,buffer_address_offset);

		if(counter==0)
		{
			weight_buff[buffer_idx][0][buffer_address]=temp64[0];
			weight_buff[buffer_idx][1][buffer_address]=temp64[1];
		}
		else if(counter == 1)
		{
			weight_buff[buffer_idx][2][buffer_address]=temp64[0];
			weight_buff[buffer_idx][3][buffer_address]=temp64[1];
		}
		else if(counter == 2)
		{
			weight_buff[buffer_idx][4][buffer_address]=temp64[0];
			weight_buff[buffer_idx][5][buffer_address]=temp64[1];
		}
		else if(counter == 3)
		{
			weight_buff[buffer_idx][6][buffer_address]=temp64[0];
			weight_buff[buffer_idx][7][buffer_address]=temp64[1];
		}
		else if(counter == 4)
		{
			// Only 9 slots: the upper half of the 5th word is dropped.
			weight_buff[buffer_idx][8][buffer_address]=temp64[0];
		}

		// End of this feed buffer takes priority over the group advance.
		if(port_load_cnt==port_load_number-1)
		{
			buffer_address_offset =0;
		}
		else if(counter==4 )
		{
			buffer_address_offset++;
		}

		if(port_load_cnt==port_load_number-1)
		{
			buffer_idx++;
			port_load_cnt=0;
		}
		else
		{
			port_load_cnt++;
		}

		if(counter==4)
		{
			counter=0;
		}
		else
		{
			counter++;
		}
	}
}

//template<int dummy>
/*
 * load_weight_ddr
 * ---------------
 * Fills the weight buffers of all four DDR ports.
 *
 * Port k reads load_number words starting DDR_offset words into weight_DDRk.
 * It writes them into weight_buff[k] using load_weight_ddr_one_port, with the
 * same port_load_number and pingpong half.
 * When skip_flag is set, no DDR reads are issued and weight_buff is untouched.
 */
void load_weight_ddr(
	ap_uint<128>* weight_DDR0,
	ap_uint<128>* weight_DDR1,
	ap_uint<128>* weight_DDR2,
	ap_uint<128>* weight_DDR3,
	ap_uint<64> weight_buff[4][WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
	int DDR_offset,
	int load_number,
	int port_load_number,
	ap_uint<1> skip_flag,
	ap_uint<1> pingpong)
{
	if(skip_flag == 0)
	{
		// Ports are handled one after another, in order 0..3.
		load_weight_ddr_one_port(weight_DDR0 + DDR_offset, weight_buff[0], load_number, port_load_number, pingpong);
		load_weight_ddr_one_port(weight_DDR1 + DDR_offset, weight_buff[1], load_number, port_load_number, pingpong);
		load_weight_ddr_one_port(weight_DDR2 + DDR_offset, weight_buff[2], load_number, port_load_number, pingpong);
		load_weight_ddr_one_port(weight_DDR3 + DDR_offset, weight_buff[3], load_number, port_load_number, pingpong);
	}
}

//template<int dummy>
/*
 * weight_stream
 * -------------
 * Replays one pingpong half of a port's weight buffers onto the weight streams.
 *
 * The outer loop runs row_repeat_time times. Each pass walks buffer offsets
 * 0 .. weight_feed_total_size/2 - 1. For each offset and feed buffer it
 * concatenates the 9 stored 64-bit words (slot j18 -> bits [64*j18+63 : 64*j18])
 * into one ap_uint<16*36> word. It then writes feed buffer 0 to weight_stream0
 * and feed buffer 1 to weight_stream1. Streams exist only when
 * WEIGHT_FEED_NUMBER_PER_PORT == 2; otherwise nothing is written.
 *
 * Address = {pingpong, 9-bit offset}, so one half holds 512 entries. The loop
 * counter is a plain int that is narrowed to 9 bits for addressing. A 9-bit
 * counter would wrap and never terminate if weight_feed_total_size/2 > 511.
 */
void weight_stream(
	ap_uint<64> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH],
	#if WEIGHT_FEED_NUMBER_PER_PORT == 2
	hls::stream<ap_uint<16*36> > & weight_stream0,
	hls::stream<ap_uint<16*36> > & weight_stream1,
	#endif
	int row_repeat_time, // output_row ceil_div out tiles
	int weight_feed_total_size,
	ap_int<1> pingpong
)
{
	int weight_feed_total_size_by2 = weight_feed_total_size/2;
	for(int i=0;i<row_repeat_time;i++)
	{
		for(int addr_cnt=0; addr_cnt<weight_feed_total_size_by2; addr_cnt++)
		{
			ap_uint<9> buffer_addr_offset = addr_cnt;
			ap_uint<10> buffer_addr=(pingpong,buffer_addr_offset);

			ap_uint<64> temp18[WEIGHT_FEED_NUMBER_PER_PORT][9];
			#pragma HLS array_partition variable = temp18 complete

			for(int buffer_idx =0; buffer_idx< WEIGHT_FEED_NUMBER_PER_PORT; buffer_idx++)
			{
				#pragma HLS unroll
				for(int j18=0;j18<9;j18++)
				{
					#pragma HLS unroll
					temp18[buffer_idx][j18]=weight_buff[buffer_idx][j18][buffer_addr];
				}
			}

			ap_uint<16*36> temp16x36[WEIGHT_FEED_NUMBER_PER_PORT];
			#pragma HLS array_partition variable = temp16x36 complete

			for(int buffer_idx =0; buffer_idx< WEIGHT_FEED_NUMBER_PER_PORT; buffer_idx++)
			{
				#pragma HLS unroll
				for(int j18=0;j18<9;j18++)
				{
					#pragma HLS unroll
					temp16x36[buffer_idx].range(j18*64+63,j18*64)=temp18[buffer_idx][j18];
				}
			}
			#if WEIGHT_FEED_NUMBER_PER_PORT == 2
			weight_stream0<<temp16x36[0];
			weight_stream1<<temp16x36[1];
			#endif
		}
	}
}


//template<int dummy>
/*
 * weight_feed
 * -----------
 * Double-buffered weight loader/streamer for the four weight DDR ports.
 *
 * State kept in function-local statics across calls:
 *   DDR_offset    DDR word offset of the next chunk to load,
 *   DDR_load_cnt  index of that chunk (wraps after weight_total_load_number-1),
 *   pingpong      buffer half most recently loaded,
 *   weight_buff   the double-buffered weights themselves.
 *
 * On a call with first_flag set, the schedule is reset and chunk 0 (DDR offset
 * 0) is loaded into half 0. Each of the weight_total_load_number iterations
 * then flips pingpong and loads the next chunk into that half. It then streams
 * the other half (~pingpong) from all four ports.
 * The load is skipped when last_flag is set and the chunk to load is chunk 0.
 * When the same weight_total_load_number is used every call, the last iteration
 * loads chunk 0 again (unless skipped). The next call (first_flag == 0)
 * therefore starts by streaming that prefetched half. This requires pingpong and
 * DDR_load_cnt to persist, which is why the reset block below is braced; an
 * unbraced `if` used to reset them on every call.
 */
void weight_feed(
	ap_uint<128>* weight_DDR0,
	ap_uint<128>* weight_DDR1,
	ap_uint<128>* weight_DDR2,
	ap_uint<128>* weight_DDR3,
	#if WEIGHT_FEED_NUMBER_PER_PORT == 2
	hls::stream<ap_uint<16*36> > & weight_stream0_0,
	hls::stream<ap_uint<16*36> > & weight_stream0_1,
	hls::stream<ap_uint<16*36> > & weight_stream1_0,
	hls::stream<ap_uint<16*36> > & weight_stream1_1,
	hls::stream<ap_uint<16*36> > & weight_stream2_0,
	hls::stream<ap_uint<16*36> > & weight_stream2_1,
	hls::stream<ap_uint<16*36> > & weight_stream3_0,
	hls::stream<ap_uint<16*36> > & weight_stream3_1,
	#endif
	ap_uint<16> weight_total_load_number,
	ap_uint<16> weight_total_feed_size,
	ap_uint<16> ddr_load_length,
	ap_uint<16> ddr_load_length_per_feed,
	ap_uint<16> row_repeat_times,
	ap_uint<16> first_flag,
	ap_uint<16> last_flag
)
{
	static ap_uint<16> DDR_offset;
	static ap_uint<16> DDR_load_cnt;
	static ap_uint<1> pingpong;

	static ap_uint<64> weight_buff[4][WEIGHT_FEED_NUMBER_PER_PORT][9][WEIGHT_BUFFER_DEPTH];

	// Reset the schedule only on the first call; later calls continue from
	// the state left by the previous one.
	if(first_flag)
	{
		DDR_offset=ddr_load_length;
		DDR_load_cnt=1;
		pingpong = 0;
	}

	// Initial load of chunk 0, performed only on the first call. The skip
	// flag uses the same non-zero test as the reset above.
	load_weight_ddr(
	weight_DDR0,
	weight_DDR1,
	weight_DDR2,
	weight_DDR3,
	weight_buff,
	0,
	ddr_load_length,
	ddr_load_length_per_feed,
	first_flag == 0,
	pingpong);

	for(int cnt=0;cnt<weight_total_load_number;cnt++)
	{
		// Load the next chunk into one half, then stream the other half.
		pingpong = ~pingpong;

		load_weight_ddr(
		weight_DDR0,
		weight_DDR1,
		weight_DDR2,
		weight_DDR3,
		weight_buff,
		DDR_offset,
		ddr_load_length,
		ddr_load_length_per_feed,
		last_flag & (DDR_load_cnt==0) ,
		pingpong);

		weight_stream(
		weight_buff[0],
		#if WEIGHT_FEED_NUMBER_PER_PORT == 2
		weight_stream0_0,
		weight_stream0_1,
		#endif
		row_repeat_times, // output_row ceil_div out tiles
		weight_total_feed_size,
		~pingpong);

		weight_stream(
		weight_buff[1],
		#if WEIGHT_FEED_NUMBER_PER_PORT == 2
		weight_stream1_0,
		weight_stream1_1,
		#endif
		row_repeat_times, // output_row ceil_div out tiles
		weight_total_feed_size,
		~pingpong);

		weight_stream(
		weight_buff[2],
		#if WEIGHT_FEED_NUMBER_PER_PORT == 2
		weight_stream2_0,
		weight_stream2_1,
		#endif
		row_repeat_times, // output_row ceil_div out tiles
		weight_total_feed_size,
		~pingpong);

		weight_stream(
		weight_buff[3],
		#if WEIGHT_FEED_NUMBER_PER_PORT == 2
		weight_stream3_0,
		weight_stream3_1,
		#endif
		row_repeat_times, // output_row ceil_div out tiles
		weight_total_feed_size,
		~pingpong);

		// Advance to the next chunk, wrapping back to chunk 0 after the last.
		if(DDR_load_cnt == weight_total_load_number-1)
		{
			DDR_load_cnt = 0;
			DDR_offset = 0;
		}
		else
		{
			DDR_load_cnt+=1;
			DDR_offset+=ddr_load_length;
		}
	}
}

#endif