Skip to content
Snippets Groups Projects
wino_cell.hpp 6.49 KiB
Newer Older
  • Learn to ignore specific revisions
  • xliu79's avatar
    xliu79 committed
    #include <ap_int.h>
    #include <hls_stream.h>
    
    
    
    
    
    // Bit widths used by the cells in this file.
    // IN_WIDTH  : width of one activation element (ap_int<IN_WIDTH> in[]).
    // W_WIDTH   : width of one weight element (ap_int<W_WIDTH> G1[]).
    // OUT_WIDTH : width of one accumulator entry in out_buffer.
    #define IN_WIDTH 8
    #define W_WIDTH 8
    #define OUT_WIDTH 32
    
    // Input-transform stage widths. B_WIDTH_OUT sizes dB[][] and BT_WIDTH_OUT
    // sizes BTdB[][]. The *_IN macros are not referenced in the code visible
    // here; presumably they document the operand width of each stage.
    #define B_WIDTH_IN 8
    #define B_WIDTH_OUT 12
    #define BT_WIDTH_IN 12
    #define BT_WIDTH_OUT 16
    
    // Element-wise product stage. GTGG_WIDTH_OUT sizes UV[][].
    // GTGG_WIDTH_IN / GTGG_WIDTH_W are not referenced in the code visible here.
    #define GTGG_WIDTH_IN 16
    #define GTGG_WIDTH_W 16
    #define GTGG_WIDTH_OUT 24
    
    // First output-transform stage. A_WIDTH_OUT sizes UVA[][].
    #define A_WIDTH_IN 24
    #define A_WIDTH_OUT 28
    
    // Second output-transform stage. AT_WIDTH_OUT sizes ATUVA[].
    #define AT_WIDTH_IN 28
    #define AT_WIDTH_OUT 32
    
    
    
    /**
     * One 6x6 processing cell of a streamed multiply-accumulate array.
     *
     * On every loop iteration the cell
     *   1. pops one packed word of 36 activations from top_stream_in and one
     *      packed word of 36 weights from left_stream_in,
     *   2. pushes both words unchanged to bottom_stream_out / right_stream_out,
     *   3. accumulates a 4x4 result tile into out_buffer[0..15][counter].
     *
     * Element k of a packed word occupies bits [k*WIDTH + WIDTH-1 : k*WIDTH];
     * element (row, col) of the 6x6 tile is element row*6+col.
     *
     * Data path (the coefficients match the B^T and A^T matrices of Winograd
     * F(4x4, 3x3)), with d the 6x6 activation tile and G1 the 6x6 weight tile:
     *   dB    = d * B
     *   BTdB  = B^T * dB
     *   UV    = (wino_flag ? BTdB : d) .* G1      (element-wise)
     *   UVA   = UV * A                            (6x6 -> 6x4)
     *   ATUVA = A^T * UVA                         (6x4 -> 4x4, row-major)
     * Note that the A / A^T stages are applied for both values of wino_flag.
     *
     * @param top_stream_in     packed activations in (36 x IN_WIDTH bits).
     * @param bottom_stream_out pass-through copy of every activation word.
     * @param left_stream_in    packed weights in (36 x W_WIDTH bits).
     * @param right_stream_out  pass-through copy of every weight word.
     * @param out_buffer        accumulators; entry [t][counter] += ATUVA[t].
     * @param wino_flag         1: multiply the transformed tile BTdB by G1,
     *                          0: multiply the raw activations by G1.
     * @param number            iteration count; it is used directly as the second
     *                          index of out_buffer and is not range-checked here,
     *                          so callers must keep it <= 1024.
     */
    template<int dummy>
    void wino_stream_ceil(
    		hls::stream< ap_uint<IN_WIDTH*36> > & top_stream_in,
    		hls::stream< ap_uint<IN_WIDTH*36> > & bottom_stream_out,
    		hls::stream< ap_uint<W_WIDTH*36> > &left_stream_in,
    		hls::stream< ap_uint<W_WIDTH*36> > &right_stream_out,
    		ap_uint<OUT_WIDTH> out_buffer[16][1024],
    		ap_uint<1> wino_flag,
    		int number)
    {
    	#pragma HLS array_partition variable=out_buffer dim=1 complete
    
    	ap_int<W_WIDTH> G1[36];
    	#pragma HLS array_partition variable=G1
    	ap_int<IN_WIDTH> in[36];
    	#pragma HLS array_partition variable=in
    
    	for(int counter=0;counter<number;counter++)
    	{
    		#pragma HLS pipeline
    		ap_uint<IN_WIDTH*36> stream_in_temp;
    		ap_uint<W_WIDTH*36> stream_weight_temp;
    
    		// Systolic pass-through: forward each word before using it locally.
    		top_stream_in>>stream_in_temp;
    		bottom_stream_out<<stream_in_temp;
    		left_stream_in>>stream_weight_temp;
    		right_stream_out<<stream_weight_temp;
    
    		// Unpack the 36 weights and 36 activations from the packed words.
    		for(int k=0;k<36;k++)
    		{
    		#pragma HLS unroll
    			G1[k].range(W_WIDTH-1,0)=stream_weight_temp.range(W_WIDTH-1+k*W_WIDTH,k*W_WIDTH);
    			in[k].range(IN_WIDTH-1,0)=stream_in_temp.range(IN_WIDTH-1+k*IN_WIDTH,k*IN_WIDTH);
    		}
    
    		// NOTE on shifts in every transform stage below: ap_int's operator<<
    		// returns the type of its left operand, so bits shifted beyond that
    		// width are dropped *before* the value is widened by the assignment.
    		// Each stage therefore widens its operands to the stage output width
    		// first and only then shifts.
    
    		// dB = d * B : transform each row of the activation tile.
    		ap_int<B_WIDTH_OUT>  dB[6][6];
    		#pragma HLS array_partition variable=dB complete
    		for(int i=0;i<6;i++)
    		{
    		#pragma HLS unroll
    			ap_int<B_WIDTH_OUT> d0=in[i*6+0];
    			ap_int<B_WIDTH_OUT> d1=in[i*6+1];
    			ap_int<B_WIDTH_OUT> d2=in[i*6+2];
    			ap_int<B_WIDTH_OUT> d3=in[i*6+3];
    			ap_int<B_WIDTH_OUT> d4=in[i*6+4];
    			ap_int<B_WIDTH_OUT> d5=in[i*6+5];
    
    			dB[i][0]=((d0-d2)<<2) - d2 + d4;      //  4*d0 - 5*d2 + d4
    			dB[i][1]=d3+d4-((d1+d2)<<2);          // -4*d1 - 4*d2 + d3 + d4
    			dB[i][2]=((d1-d2)<<2)+d4-d3;          //  4*d1 - 4*d2 - d3 + d4
    			dB[i][3]=((d3-d1)<<1)+d4-d2;          // -2*d1 - d2 + 2*d3 + d4
    			dB[i][4]=((d1-d3)<<1)+d4-d2;          //  2*d1 - d2 - 2*d3 + d4
    			dB[i][5]=((d1-d3)<<2)+d5-d3;          //  4*d1 - 5*d3 + d5
    		}
    
    		// BTdB = B^T * dB : same coefficients applied down each column.
    		ap_int<BT_WIDTH_OUT>  BTdB[6][6];
    		#pragma HLS array_partition variable=BTdB complete
    		for(int i=0;i<6;i++)
    		{
    		#pragma HLS unroll
    			ap_int<BT_WIDTH_OUT> c0=dB[0][i];
    			ap_int<BT_WIDTH_OUT> c1=dB[1][i];
    			ap_int<BT_WIDTH_OUT> c2=dB[2][i];
    			ap_int<BT_WIDTH_OUT> c3=dB[3][i];
    			ap_int<BT_WIDTH_OUT> c4=dB[4][i];
    			ap_int<BT_WIDTH_OUT> c5=dB[5][i];
    
    			BTdB[0][i]=((c0-c2)<<2) - c2 + c4;
    			BTdB[1][i]=c3+c4-((c1+c2)<<2);
    			BTdB[2][i]=((c1-c2)<<2)+c4-c3;
    			BTdB[3][i]=((c3-c1)<<1)+c4-c2;
    			BTdB[4][i]=((c1-c3)<<1)+c4-c2;
    			BTdB[5][i]=((c1-c3)<<2)+c5-c3;
    		}
    
    		// Element-wise product with the weight tile. wino_flag selects whether
    		// the transformed tile or the raw activations are multiplied.
    		ap_int<GTGG_WIDTH_OUT> UV[6][6];
    		#pragma HLS array_partition variable=UV complete
    		for(int i=0;i<6;i++){
    		#pragma HLS unroll
    			for(int j=0;j<6;j++){
    			#pragma HLS unroll
    				if(wino_flag)
    					UV[i][j]=BTdB[i][j]*G1[i*6+j];
    				else
    					UV[i][j]=in[i*6+j]*G1[i*6+j];
    			}
    		}
    
    		// UVA = UV * A : reduce each row from 6 to 4 entries.
    		ap_int<A_WIDTH_OUT>  UVA[6][4];
    		#pragma HLS array_partition variable=UVA complete
    		for(int i=0;i<6;i++)
    		{
    		#pragma HLS unroll
    			ap_int<A_WIDTH_OUT> m0=UV[i][0];
    			ap_int<A_WIDTH_OUT> m1=UV[i][1];
    			ap_int<A_WIDTH_OUT> m2=UV[i][2];
    			ap_int<A_WIDTH_OUT> m3=UV[i][3];
    			ap_int<A_WIDTH_OUT> m4=UV[i][4];
    			ap_int<A_WIDTH_OUT> m5=UV[i][5];
    
    			UVA[i][0]=m0+m1+m2+m3+m4;
    			UVA[i][1]=m1-m2+((m3-m4)<<1);
    			UVA[i][2]=m1+m2+((m3+m4)<<2);
    			UVA[i][3]=m1-m2+((m3-m4)<<3)+m5;
    		}
    
    		// ATUVA = A^T * UVA : reduce each column from 6 to 4 entries.
    		// Result is stored row-major: ATUVA[4*row + col].
    		ap_int<AT_WIDTH_OUT> ATUVA[16];
    		#pragma HLS array_partition variable=ATUVA complete
    		for(int i=0;i<4;i++)
    		{
    		#pragma HLS unroll
    			ap_int<AT_WIDTH_OUT> r0=UVA[0][i];
    			ap_int<AT_WIDTH_OUT> r1=UVA[1][i];
    			ap_int<AT_WIDTH_OUT> r2=UVA[2][i];
    			ap_int<AT_WIDTH_OUT> r3=UVA[3][i];
    			ap_int<AT_WIDTH_OUT> r4=UVA[4][i];
    			ap_int<AT_WIDTH_OUT> r5=UVA[5][i];
    
    			ATUVA[0+i]=r0+r1+r2+r3+r4;
    			ATUVA[4+i]=r1-r2+((r3-r4)<<1);
    			ATUVA[8+i]=r1+r2+((r3+r4)<<2);
    			ATUVA[12+i]=r1-r2+((r3-r4)<<3)+r5;
    		}
    
    		// Accumulate the 4x4 tile into slot `counter` of each output bank.
    		for(int i=0;i<16;i++)
    		{
    		#pragma HLS unroll
    			out_buffer[i][counter]=out_buffer[i][counter]+ATUVA[i];
    		}
    	}
    }
    
    
    /**
     * One 4x4 processing cell of a streamed multiply-accumulate array.
     *
     * On every loop iteration the cell
     *   1. pops one packed word of 16 activations from top_stream_in and one
     *      packed word of 16 weights from left_stream_in,
     *   2. pushes both words unchanged to bottom_stream_out / right_stream_out,
     *   3. accumulates into slot `counter` of out_buffer:
     *        wino_flag == 1 : a 2x2 tile into out_buffer[0..3]
     *        wino_flag == 0 : the 16 raw products in[k]*G1[k] into out_buffer[0..15]
     *
     * Element k of a packed word occupies bits [k*WIDTH + WIDTH-1 : k*WIDTH];
     * element (row, col) of the 4x4 tile is element row*4+col.
     *
     * The transform coefficients match the B^T and A^T matrices of Winograd
     * F(2x2, 3x3).
     *
     * NOTE(review): the second input-transform stage writes BTdB[i][k] from
     * column i of dB, i.e. the result is stored transposed relative to the 6x6
     * kernel in this file (which writes BTdB[k][i]). That is kept as-is here;
     * confirm it matches the weight layout produced upstream.
     *
     * @param top_stream_in     packed activations in (16 x IN_WIDTH bits).
     * @param bottom_stream_out pass-through copy of every activation word.
     * @param left_stream_in    packed weights in (16 x W_WIDTH bits).
     * @param right_stream_out  pass-through copy of every weight word.
     * @param out_buffer        accumulators, see step 3 above.
     * @param wino_flag         selects transformed (1) or direct (0) mode.
     * @param number            iteration count; it is used directly as the second
     *                          index of out_buffer and is not range-checked here,
     *                          so callers must keep it <= 1024.
     */
    template<int dummy>
    void wino_stream_ceil_4x4(
    		hls::stream< ap_uint<IN_WIDTH*16> > & top_stream_in,
    		hls::stream< ap_uint<IN_WIDTH*16> > & bottom_stream_out,
    		hls::stream< ap_uint<W_WIDTH*16> > &left_stream_in,
    		hls::stream< ap_uint<W_WIDTH*16> > &right_stream_out,
    		ap_uint<OUT_WIDTH> out_buffer[16][1024],
    		ap_uint<1> wino_flag,
    		int number)
    {
    	#pragma HLS array_partition variable=out_buffer dim=1 complete
    
    	ap_int<W_WIDTH> G1[16];
    	#pragma HLS array_partition variable=G1
    	ap_int<IN_WIDTH> in[16];
    	#pragma HLS array_partition variable=in
    
    	for(int counter=0;counter<number;counter++)
    	{
    		#pragma HLS pipeline
    		ap_uint<IN_WIDTH*16> stream_in_temp;
    		ap_uint<W_WIDTH*16> stream_weight_temp;
    
    		// Systolic pass-through: forward each word before using it locally.
    		top_stream_in>>stream_in_temp;
    		bottom_stream_out<<stream_in_temp;
    		left_stream_in>>stream_weight_temp;
    		right_stream_out<<stream_weight_temp;
    
    		// Unpack the 16 weights and 16 activations. The loop has 16 trips, so
    		// it is fully unrolled (the former "factor=36" was left over from the
    		// 6x6 kernel).
    		for(int k=0;k<16;k++)
    		{
    		#pragma HLS unroll
    			G1[k].range(W_WIDTH-1,0)=stream_weight_temp.range(W_WIDTH-1+k*W_WIDTH,k*W_WIDTH);
    			in[k].range(IN_WIDTH-1,0)=stream_in_temp.range(IN_WIDTH-1+k*IN_WIDTH,k*IN_WIDTH);
    		}
    
    		// dB = d * B : transform each row of the activation tile.
    		ap_int<B_WIDTH_OUT>  dB[4][4];
    		#pragma HLS array_partition variable=dB complete
    		for(int i=0;i<4;i++)
    		{
    		#pragma HLS unroll
    			dB[i][0]=in[i*4]-in[i*4+2];
    			dB[i][1]=in[i*4+1]+in[i*4+2];
    			dB[i][2]=-in[i*4+1]+in[i*4+2];
    			dB[i][3]=in[i*4+1]-in[i*4+3];
    		}
    
    		// Second transform stage (wino mode) or plain copy of the activations
    		// (direct mode).
    		ap_int<BT_WIDTH_OUT>  BTdB[4][4];
    		#pragma HLS array_partition variable=BTdB complete
    		if(wino_flag)
    		{
    			for(int i=0;i<4;i++)
    			{
    			#pragma HLS unroll
    				BTdB[i][0]=dB[0][i]-dB[2][i];
    				BTdB[i][1]=dB[1][i]+dB[2][i];
    				BTdB[i][2]=-dB[1][i]+dB[2][i];
    				BTdB[i][3]=dB[1][i]-dB[3][i];
    			}
    		}
    		else
    		{
    			for(int i=0;i<4;i++)
    			{
    			#pragma HLS unroll
    				BTdB[i][0]=in[i*4+0];
    				BTdB[i][1]=in[i*4+1];
    				BTdB[i][2]=in[i*4+2];
    				BTdB[i][3]=in[i*4+3];
    			}
    		}
    
    		// Element-wise product with the weight tile.
    		ap_int<GTGG_WIDTH_OUT> UV[4][4];
    		#pragma HLS array_partition variable=UV complete
    		for(int i=0;i<4;i++){
    		#pragma HLS unroll
    			for(int j=0;j<4;j++){
    			#pragma HLS unroll
    				if(wino_flag)
    					UV[i][j]=BTdB[i][j]*G1[i*4+j];
    				else
    					// `in` is a flat 16-entry array: index it as row*4+col.
    					// (in[i][j] would select bit j of activation i.)
    					UV[i][j]=in[i*4+j]*G1[i*4+j];
    			}
    		}
    
    		// UVA = UV * A : reduce each row from 4 to 2 entries.
    		ap_int<A_WIDTH_OUT>  UVA[4][2];
    		#pragma HLS array_partition variable=UVA complete
    		for(int i=0;i<4;i++)
    		{
    		#pragma HLS unroll
    			UVA[i][0]=UV[i][0]+UV[i][1]+UV[i][2];
    			UVA[i][1]=UV[i][1]-UV[i][2]-UV[i][3];
    		}
    
    		// ATUVA = A^T * UVA : 2x2 result, row-major (ATUVA[2*row + col]).
    		// Only consumed in wino mode; direct mode accumulates UV itself.
    		ap_int<AT_WIDTH_OUT> ATUVA[4];
    		#pragma HLS array_partition variable=ATUVA complete
    		for(int i=0;i<2;i++)
    		{
    		#pragma HLS unroll
    			ATUVA[0+i]=UVA[0][i]+UVA[1][i]+UVA[2][i];
    			ATUVA[2+i]=UVA[1][i]-UVA[2][i]-UVA[3][i];
    		}
    
    		if(wino_flag)
    		{
    			// Wino mode: accumulate the 2x2 tile into banks 0..3.
    			for(int i=0;i<4;i++)
    			{
    			#pragma HLS unroll
    				out_buffer[i][counter]=out_buffer[i][counter]+ATUVA[i];
    			}
    		}
    		else
    		{
    			// Direct mode: accumulate all 16 products, one per bank.
    			for(int i=0;i<4;i++)
    			{
    			#pragma HLS unroll
    				for(int j=0;j<4;j++)
    				{
    				#pragma HLS unroll
    					out_buffer[i*4+j][counter]=out_buffer[i*4+j][counter]+UV[i][j];
    				}
    			}
    		}
    	}
    }