#include <ap_int.h>
#include <hls_stream.h>





// Element widths (bits). IN_WIDTH / W_WIDTH are the widths of the signed
// activation / weight elements packed into the cells' input stream words;
// OUT_WIDTH is the width of each out_buffer accumulator word.
#define IN_WIDTH 8
#define W_WIDTH 8
#define OUT_WIDTH 32

// Widths of the intermediate pipeline stages. The *_OUT values size the
// per-stage local arrays of the cells below:
//   B_WIDTH_OUT    -> dB    (row pass of the input transform)
//   BT_WIDTH_OUT   -> BTdB  (column pass of the input transform)
//   GTGG_WIDTH_OUT -> UV    (element-wise products)
//   A_WIDTH_OUT    -> UVA   (row pass of the output transform)
//   AT_WIDTH_OUT   -> ATUVA (final output tile)
// B_WIDTH_IN, BT_WIDTH_IN, GTGG_WIDTH_IN, GTGG_WIDTH_W, A_WIDTH_IN and
// AT_WIDTH_IN are not referenced by the two cells in this section.
#define B_WIDTH_IN 8
#define B_WIDTH_OUT 12
#define BT_WIDTH_IN 12
#define BT_WIDTH_OUT 16

#define GTGG_WIDTH_IN 16
#define GTGG_WIDTH_W 16
#define GTGG_WIDTH_OUT 24

#define A_WIDTH_IN 24
#define A_WIDTH_OUT 28

#define AT_WIDTH_IN 28
#define AT_WIDTH_OUT 32



/*
 * One Winograd F(4x4,3x3) processing cell: 6x6 input tile -> 4x4 output tile.
 *
 * Each of the `number` iterations pops one packed 6x6 input tile from
 * top_stream_in and one packed 6x6 weight tile from left_stream_in, pushes
 * both words unchanged to bottom_stream_out / right_stream_out (presumably
 * feeding neighbouring cells), and adds a 4x4 result tile to
 * out_buffer[0..15][counter]. Element k of a stream word occupies bits
 * [WIDTH*k + WIDTH-1 : WIDTH*k] and is interpreted as signed; the tile is
 * row-major (element k = row k/6, column k%6).
 *
 *   wino_flag == 1 : Y = A^T [ G1 (.) (B^T d B) ] A
 *                    G1 is multiplied in as-is, so it presumably already holds
 *                    the transformed kernel G g G^T in row-major order --
 *                    confirm with the host/weight-preparation code.
 *   wino_flag == 0 : Y = A^T [ G1 (.) d ] A  (input transform skipped).
 *                    NOTE(review): the output transform is still applied in
 *                    this mode, whereas the 4x4 cell outputs the raw products;
 *                    confirm this is intended.
 *
 * Y[r][c] is added to out_buffer[r*4+c][counter]; the store truncates to
 * OUT_WIDTH bits, so accumulation wraps modulo 2^OUT_WIDTH.
 * `number` must not exceed 1024 (second dimension of out_buffer).
 *
 * Every "<< n" below is applied to operands already widened to the width of
 * the stage being produced: an ap_int left shift keeps the width of its left
 * operand, so e.g. (in[0]-in[2])<<2 on the raw 9-bit difference would wrap.
 */
template<int dummy>
void wino_stream_ceil(
		hls::stream< ap_uint<IN_WIDTH*36> > & top_stream_in,
		hls::stream< ap_uint<IN_WIDTH*36> > & bottom_stream_out,
		hls::stream< ap_uint<W_WIDTH*36> > &left_stream_in,
		hls::stream< ap_uint<W_WIDTH*36> > &right_stream_out,
		ap_uint<OUT_WIDTH> out_buffer[16][1024],
		ap_uint<1> wino_flag,
		int number)
{
	// One bank per output-tile element so all 16 accumulations happen in parallel.
	#pragma HLS array_partition variable=out_buffer dim=1 complete

	ap_int<W_WIDTH> G1[36];
	#pragma HLS array_partition variable=G1 complete
	ap_int<IN_WIDTH> in[36];
	#pragma HLS array_partition variable=in complete

	for(int counter=0;counter<number;counter++)
	{
		#pragma HLS pipeline
		ap_uint<IN_WIDTH*36> stream_in_temp;
		ap_uint<W_WIDTH*36> stream_weight_temp;
		top_stream_in>>stream_in_temp;
		bottom_stream_out<<stream_in_temp;
		left_stream_in>>stream_weight_temp;
		right_stream_out<<stream_weight_temp;

		// Unpack the 36 weight and input elements.
		for(int k=0;k<36;k++)
		{
			#pragma HLS unroll
			G1[k].range(W_WIDTH-1,0)=stream_weight_temp.range(W_WIDTH-1+k*W_WIDTH,k*W_WIDTH);
			in[k].range(IN_WIDTH-1,0)=stream_in_temp.range(IN_WIDTH-1+k*IN_WIDTH,k*IN_WIDTH);
		}

		// Row pass of the input transform: dB = d * B (B^T applied to each row).
		ap_int<B_WIDTH_OUT> dB[6][6];
		#pragma HLS array_partition variable=dB complete dim=0
		for(int i=0;i<6;i++)
		{
			#pragma HLS unroll
			const ap_int<B_WIDTH_OUT> x0=in[i*6+0];
			const ap_int<B_WIDTH_OUT> x1=in[i*6+1];
			const ap_int<B_WIDTH_OUT> x2=in[i*6+2];
			const ap_int<B_WIDTH_OUT> x3=in[i*6+3];
			const ap_int<B_WIDTH_OUT> x4=in[i*6+4];
			const ap_int<B_WIDTH_OUT> x5=in[i*6+5];
			dB[i][0]=((x0-x2)<<2)-x2+x4;
			dB[i][1]=x3+x4-((x1+x2)<<2);
			dB[i][2]=((x1-x2)<<2)+x4-x3;
			dB[i][3]=((x3-x1)<<1)+x4-x2;
			dB[i][4]=((x1-x3)<<1)+x4-x2;
			dB[i][5]=((x1-x3)<<2)+x5-x3;
		}

		// Column pass: BTdB = B^T * dB, stored row-major as BTdB[row][col].
		ap_int<BT_WIDTH_OUT> BTdB[6][6];
		#pragma HLS array_partition variable=BTdB complete dim=0
		for(int i=0;i<6;i++)
		{
			#pragma HLS unroll
			const ap_int<BT_WIDTH_OUT> y0=dB[0][i];
			const ap_int<BT_WIDTH_OUT> y1=dB[1][i];
			const ap_int<BT_WIDTH_OUT> y2=dB[2][i];
			const ap_int<BT_WIDTH_OUT> y3=dB[3][i];
			const ap_int<BT_WIDTH_OUT> y4=dB[4][i];
			const ap_int<BT_WIDTH_OUT> y5=dB[5][i];
			BTdB[0][i]=((y0-y2)<<2)-y2+y4;
			BTdB[1][i]=y3+y4-((y1+y2)<<2);
			BTdB[2][i]=((y1-y2)<<2)+y4-y3;
			BTdB[3][i]=((y3-y1)<<1)+y4-y2;
			BTdB[4][i]=((y1-y3)<<1)+y4-y2;
			BTdB[5][i]=((y1-y3)<<2)+y5-y3;
		}

		// Element-wise products; the raw tile replaces B^T d B when wino_flag is 0.
		ap_int<GTGG_WIDTH_OUT> UV[6][6];
		#pragma HLS array_partition variable=UV complete dim=0
		for(int i=0;i<6;i++)
		{
			#pragma HLS unroll
			for(int j=0;j<6;j++)
			{
				#pragma HLS unroll
				if(wino_flag)
					UV[i][j]=BTdB[i][j]*G1[i*6+j];
				else
					UV[i][j]=in[i*6+j]*G1[i*6+j];
			}
		}

		// Row pass of the output transform: UVA = UV * A (A^T applied to each row).
		ap_int<A_WIDTH_OUT> UVA[6][4];
		#pragma HLS array_partition variable=UVA complete dim=0
		for(int i=0;i<6;i++)
		{
			#pragma HLS unroll
			const ap_int<A_WIDTH_OUT> m0=UV[i][0];
			const ap_int<A_WIDTH_OUT> m1=UV[i][1];
			const ap_int<A_WIDTH_OUT> m2=UV[i][2];
			const ap_int<A_WIDTH_OUT> m3=UV[i][3];
			const ap_int<A_WIDTH_OUT> m4=UV[i][4];
			const ap_int<A_WIDTH_OUT> m5=UV[i][5];
			UVA[i][0]=m0+m1+m2+m3+m4;
			UVA[i][1]=m1-m2+((m3-m4)<<1);
			UVA[i][2]=m1+m2+((m3+m4)<<2);
			UVA[i][3]=m1-m2+((m3-m4)<<3)+m5;
		}

		// Column pass: ATUVA[r*4+c] = (A^T * UVA)[r][c].
		ap_int<AT_WIDTH_OUT> ATUVA[16];
		#pragma HLS array_partition variable=ATUVA complete
		for(int i=0;i<4;i++)
		{
			#pragma HLS unroll
			const ap_int<AT_WIDTH_OUT> t0=UVA[0][i];
			const ap_int<AT_WIDTH_OUT> t1=UVA[1][i];
			const ap_int<AT_WIDTH_OUT> t2=UVA[2][i];
			const ap_int<AT_WIDTH_OUT> t3=UVA[3][i];
			const ap_int<AT_WIDTH_OUT> t4=UVA[4][i];
			const ap_int<AT_WIDTH_OUT> t5=UVA[5][i];
			ATUVA[0+i]=t0+t1+t2+t3+t4;
			ATUVA[4+i]=t1-t2+((t3-t4)<<1);
			ATUVA[8+i]=t1+t2+((t3+t4)<<2);
			ATUVA[12+i]=t1-t2+((t3-t4)<<3)+t5;
		}

		// Accumulate the 4x4 tile (wraps modulo 2^OUT_WIDTH on store).
		for(int i=0;i<16;i++)
		{
			#pragma HLS unroll
			out_buffer[i][counter]=out_buffer[i][counter]+ATUVA[i];
		}
	}
}


/*
 * One Winograd F(2x2,3x3) processing cell (4x4 input tile -> 2x2 output tile)
 * with a direct element-wise fallback mode.
 *
 * Each of the `number` iterations pops one packed 4x4 input tile from
 * top_stream_in and one packed 4x4 weight tile from left_stream_in, pushes
 * both words unchanged to bottom_stream_out / right_stream_out (presumably
 * feeding neighbouring cells), and accumulates into out_buffer[*][counter].
 * Element k of a stream word occupies bits [WIDTH*k + WIDTH-1 : WIDTH*k] and
 * is interpreted as signed; the tile is row-major (row k/4, column k%4).
 *
 *   wino_flag == 1 : Y = A^T [ G1 (.) (B^T d B) ] A  (2x2), and Y[r][c] is
 *                    added to out_buffer[r*2+c][counter]; banks 4..15 are
 *                    left untouched. G1 is multiplied in as-is, so it
 *                    presumably holds the transformed kernel G g G^T in
 *                    row-major order -- confirm with the weight-preparation code.
 *   wino_flag == 0 : in[k]*G1[k] is added to out_buffer[k][counter], k=0..15.
 *
 * Stores truncate to OUT_WIDTH bits, so accumulation wraps modulo
 * 2^OUT_WIDTH. `number` must not exceed 1024 (second dimension of out_buffer).
 */
template<int dummy>
void wino_stream_ceil_4x4(
		hls::stream< ap_uint<IN_WIDTH*16> > & top_stream_in,
		hls::stream< ap_uint<IN_WIDTH*16> > & bottom_stream_out,
		hls::stream< ap_uint<W_WIDTH*16> > &left_stream_in,
		hls::stream< ap_uint<W_WIDTH*16> > &right_stream_out,
		ap_uint<OUT_WIDTH> out_buffer[16][1024],
		ap_uint<1> wino_flag,
		int number)
{
	// One bank per tile element so all accumulations happen in parallel.
	#pragma HLS array_partition variable=out_buffer dim=1 complete

	ap_int<W_WIDTH> G1[16];
	#pragma HLS array_partition variable=G1 complete
	ap_int<IN_WIDTH> in[16];
	#pragma HLS array_partition variable=in complete

	for(int counter=0;counter<number;counter++)
	{
		#pragma HLS pipeline
		ap_uint<IN_WIDTH*16> stream_in_temp;
		ap_uint<W_WIDTH*16> stream_weight_temp;
		top_stream_in>>stream_in_temp;
		bottom_stream_out<<stream_in_temp;
		left_stream_in>>stream_weight_temp;
		right_stream_out<<stream_weight_temp;

		// Unpack the 16 weight and input elements.
		for(int k=0;k<16;k++)
		{
			#pragma HLS unroll
			G1[k].range(W_WIDTH-1,0)=stream_weight_temp.range(W_WIDTH-1+k*W_WIDTH,k*W_WIDTH);
			in[k].range(IN_WIDTH-1,0)=stream_in_temp.range(IN_WIDTH-1+k*IN_WIDTH,k*IN_WIDTH);
		}

		// Row pass of the input transform: dB = d * B (B^T applied to each row).
		ap_int<B_WIDTH_OUT> dB[4][4];
		#pragma HLS array_partition variable=dB complete dim=0
		for(int i=0;i<4;i++)
		{
			#pragma HLS unroll
			dB[i][0]=in[i*4+0]-in[i*4+2];
			dB[i][1]=in[i*4+1]+in[i*4+2];
			dB[i][2]=-in[i*4+1]+in[i*4+2];
			dB[i][3]=in[i*4+1]-in[i*4+3];
		}

		// Multiplier operand: B^T d B in Winograd mode, the raw tile otherwise.
		// Both are stored row-major as BTdB[row][col] so that BTdB[i][j] pairs
		// with G1[i*4+j] (same orientation as the 6x6 cell's BTdB).
		ap_int<BT_WIDTH_OUT> BTdB[4][4];
		#pragma HLS array_partition variable=BTdB complete dim=0
		if(wino_flag)
		{
			for(int i=0;i<4;i++)
			{
				#pragma HLS unroll
				// Column pass: column i of B^T * dB.
				BTdB[0][i]=dB[0][i]-dB[2][i];
				BTdB[1][i]=dB[1][i]+dB[2][i];
				BTdB[2][i]=-dB[1][i]+dB[2][i];
				BTdB[3][i]=dB[1][i]-dB[3][i];
			}
		}
		else
		{
			for(int i=0;i<4;i++)
			{
				#pragma HLS unroll
				BTdB[i][0]=in[i*4+0];
				BTdB[i][1]=in[i*4+1];
				BTdB[i][2]=in[i*4+2];
				BTdB[i][3]=in[i*4+3];
			}
		}

		// Element-wise products; BTdB already holds in[] when wino_flag is 0,
		// so the same multiplier array serves both modes.
		ap_int<GTGG_WIDTH_OUT> UV[4][4];
		#pragma HLS array_partition variable=UV complete dim=0
		for(int i=0;i<4;i++)
		{
			#pragma HLS unroll
			for(int j=0;j<4;j++)
			{
				#pragma HLS unroll
				UV[i][j]=BTdB[i][j]*G1[i*4+j];
			}
		}

		// Row pass of the output transform: UVA = UV * A (A^T applied to each row).
		ap_int<A_WIDTH_OUT> UVA[4][2];
		#pragma HLS array_partition variable=UVA complete dim=0
		for(int i=0;i<4;i++)
		{
			#pragma HLS unroll
			UVA[i][0]=UV[i][0]+UV[i][1]+UV[i][2];
			UVA[i][1]=UV[i][1]-UV[i][2]-UV[i][3];
		}

		// Column pass: ATUVA[r*2+c] = (A^T * UVA)[r][c]. Only used in Winograd mode.
		ap_int<AT_WIDTH_OUT> ATUVA[4];
		#pragma HLS array_partition variable=ATUVA complete
		for(int i=0;i<2;i++)
		{
			#pragma HLS unroll
			ATUVA[0+i]=UVA[0][i]+UVA[1][i]+UVA[2][i];
			ATUVA[2+i]=UVA[1][i]-UVA[2][i]-UVA[3][i];
		}

		if(wino_flag)
		{
			// 2x2 Winograd output tile -> banks 0..3.
			for(int i=0;i<4;i++)
			{
				#pragma HLS unroll
				out_buffer[i][counter]=out_buffer[i][counter]+ATUVA[i];
			}
		}
		else
		{
			// Direct mode: one product per bank.
			for(int i=0;i<4;i++)
			{
				#pragma HLS unroll
				for(int j=0;j<4;j++)
				{
					#pragma HLS unroll
					out_buffer[i*4+j][counter]=out_buffer[i*4+j][counter]+UV[i][j];
				}
			}
		}
	}
}