#ifndef _WINO_BUFFER_HPP_
#define _WINO_BUFFER_HPP_
#include "wino_macro.h"
#include "../software/param.h"
#include <ap_int.h>
#include <hls_stream.h>


// input_feed_underconstruction
//
// Work-in-progress input feeder. The output stream parameters are still
// commented out, so apart from the DEBUG_FILE_PRINT dump nothing leaves this
// function yet. For every step of a flattened, II=1 pipelined loop it:
//   1. computes a read address for every (row bank, column bank) of
//      input_buffer from the current head column and depth offsets,
//   2. reads one value per bank,
//   3. gathers WINO_DOMAIN_SIZE rows and, for each of the WINO_WIDTH tiles,
//      WINO_DOMAIN_SIZE columns into input_plane_tile, writing 0 for rows
//      outside [0, inheight) and columns outside [0, inwidth),
//   4. advances the column / depth / minitile counters.
// The whole sequence runs once per weightbuffer_load_outdepth_number.
//
// Notes on parameters:
//   weightbuffer_load_indepth_number : currently unused.
//   start_row_idx                    : only used by the debug dump.
//   loop_*_reset_cycle               : precomputed by the caller; per the
//     original commented formulas (confirm against the caller):
//     loop_wino_tile_col_reset_cycle = wino_tile_number_in_outwidth *
//         weightbuffer_outdepth_minitile_number * INDEPTH_MINITILE_SIZE
//     loop_outdepth_minitile_baseidx_reset_cycle =
//         weightbuffer_outdepth_minitile_number * INDEPTH_MINITILE_SIZE
void input_feed_underconstruction(
	ap_uint<16> input_buffer[INBUFFER_HEIGHT][INBUFFER_WIDTH][INPUT_BUFFER_DEPTH],
	// hls::stream< ap_uint<16*BATCH_SIZE*36> > &input_tile_stream1, 
	// hls::stream< ap_uint<16*BATCH_SIZE*36> > &input_tile_stream2,
	ap_uint<16> inheight,
	ap_uint<16> inwidth,
	ap_uint<16> pad_size,
	ap_uint<16> weightbuffer_load_indepth_number,
	ap_uint<16> weightbuffer_load_outdepth_number,
	ap_uint<16> wino_output_tile_size,
	ap_uint<32> input_buffer_feeding_loop_bound,
	ap_uint<16> loop_wino_tile_col_reset_cycle,
	ap_uint<16> loop_outdepth_minitile_baseidx_reset_cycle,
	ap_uint<10> buffer_address_mid_minitile_depth_step,
	ap_uint<16> wino_out_size_by_wino_width,
	ap_uint<16> start_row_idx,
	ap_int<16> start_row_idx_minus_pad_size
	#if DEBUG_FILE_PRINT
	,ConvDesc_t conv_desc
	#endif
)
{
	//------------------------------------------------------------------
	// Row selection. Depends only on start_row_idx_minus_pad_size, so it
	// is computed once per call.
	//------------------------------------------------------------------
	ap_uint<1> row_legal_flag[WINO_DOMAIN_SIZE];
	#pragma HLS array_partition variable=row_legal_flag complete
	ap_uint<1> row_address_offset[INBUFFER_HEIGHT];
	#pragma HLS array_partition variable=row_address_offset complete
	ap_uint<INBUFFER_HEIGHT_BITWIDTH> row_bank_idx[WINO_DOMAIN_SIZE];
	#pragma HLS array_partition variable=row_bank_idx complete

	// Low bits of the first (padded) row = the bank that window row 0 maps to.
	ap_uint<INBUFFER_HEIGHT_BITWIDTH> row_breakpoint = start_row_idx_minus_pad_size.range(INBUFFER_HEIGHT_BITWIDTH-1,0);

	// Banks below the breakpoint use the inverted value of address bit
	// INBUFFER_HEIGHT_BITWIDTH. Presumably the window wraps around the bank
	// set into the next row group; confirm with the buffer loader.
	// Bounded by the array size (was a hard-coded 8) so every element is
	// written and none is out of range.
	for(int i=0;i<INBUFFER_HEIGHT;i++)
	{
	#pragma HLS unroll
		if(i< row_breakpoint)
			row_address_offset[i] = ~start_row_idx_minus_pad_size[INBUFFER_HEIGHT_BITWIDTH];
		else
			row_address_offset[i] = start_row_idx_minus_pad_size[INBUFFER_HEIGHT_BITWIDTH];
	}

	// For each window row: the bank holding it (truncated to
	// INBUFFER_HEIGHT_BITWIDTH bits) and whether it lies in [0, inheight).
	// Bounded by the array size (was a hard-coded 6).
	for(int i=0;i<WINO_DOMAIN_SIZE;i++)
	{
	#pragma HLS unroll
		row_bank_idx[i] = start_row_idx_minus_pad_size+i;
		row_legal_flag[i] = ( start_row_idx_minus_pad_size+i >=0 && start_row_idx_minus_pad_size+i < inheight);
	}

	//------------------------------------------------------------------
	// Column setup.
	//------------------------------------------------------------------
	// input_col_idx[t][k]: input column read for column k of tile t. It may be
	// negative or >= inwidth, which marks padding.
	// NOTE(review): ap_int<10> only covers -512..511; confirm inwidth fits.
	ap_int<10> input_col_idx[WINO_WIDTH][WINO_DOMAIN_SIZE];
	#pragma HLS array_partition variable=input_col_idx dim=1 complete
	#pragma HLS array_partition variable=input_col_idx dim=2 complete

	// Column offset of tile t relative to tile 0.
	ap_uint<16> wino_col_offset_constant[WINO_WIDTH];
	#pragma HLS array_partition variable=wino_col_offset_constant complete

	for(int i=0;i<WINO_WIDTH;i++)
	{
		#pragma HLS unroll
		wino_col_offset_constant[i]=wino_output_tile_size*i;
	}

	// State of the flattened loop.
	ap_uint<INDEPTH_MINITILE_SIZE_BITWIDTH> loop_indepth_minitile_idx=0;	// wraps by its bit width
	ap_uint<16> loop_wino_tile_col_cnt=1;
	ap_uint<16> loop_outdepth_minitile_baseidx_cnt =1;
	ap_uint<INBUFFER_MID_ADDR_BITWIDTH> buffer_address_mid_minitile_depth_offset=0;

	// Input column of the first pixel of tile 0 (negative while in the left padding).
	ap_int<16> input_head_col_idx=wino_col_offset_constant[0]-pad_size;

	for(int wino_array_col=0;wino_array_col<WINO_WIDTH;wino_array_col++)
	{
		#pragma HLS unroll
		for(int i=0;i<WINO_DOMAIN_SIZE;i++)
		{
			#pragma HLS unroll
			input_col_idx[wino_array_col][i]=i-pad_size+wino_col_offset_constant[wino_array_col];
		}
	}

	for(ap_uint<16> outdepth_buffertile_idx=0;outdepth_buffertile_idx<weightbuffer_load_outdepth_number;outdepth_buffertile_idx++)
	{
		buffer_address_mid_minitile_depth_offset = 0;

		for(int counter=0;counter<input_buffer_feeding_loop_bound;counter++ )
		{
			#pragma HLS pipeline II =1
			// it is a flattened loop which does following
			// for(ap_uint<16> indepth_buffertile_baseidx=0;indepth_buffertile_baseidx<weightbuffer_load_indepth_number;indepth_buffertile_baseidx++)
			// for( int indepth_minitile_baseidx=0;indepth_minitile_baseidx<weightbuffer_indepth_minitile_number; indepth_minitile_baseidx ++)
			// for(int wino_tile_col_idx =1; wino_tile_col_idx < wino_tile_number_in_outwidth+1 ; wino_tile_col_idx++)
			// for(int outdepth_minitile_baseidx=0;outdepth_minitile_baseidx<weightbuffer_outdepth_minitile_number; outdepth_minitile_baseidx ++)
			// for(ap_uint<3> indepth_minitile_idx=0; indepth_minitile_idx< INDEPTH_MINITILE_SIZE; indepth_minitile_idx++)

			// Column padding flags for every tile / column.
			ap_uint<1> col_legal_flag[WINO_WIDTH][WINO_DOMAIN_SIZE];
			#pragma HLS array_partition variable=col_legal_flag complete dim=0

			for(int wino_array_col=0;wino_array_col<WINO_WIDTH;wino_array_col++)
			{
			#pragma HLS unroll
				for(int i=0;i<WINO_DOMAIN_SIZE;i++)
				{
					#pragma HLS unroll
					col_legal_flag[wino_array_col][i]= ( input_col_idx[wino_array_col][i] >=0 && input_col_idx[wino_array_col][i] < inwidth);
				}
			}

			// Column-bank address: banks at or after col_breakpoint read the
			// head column's address; banks before it read one address further.
			ap_uint<INBUFFER_MID_ADDR_BITWIDTH> col_pix_address_offset[INBUFFER_WIDTH];
			#pragma HLS array_partition variable=col_pix_address_offset complete

			ap_uint<INBUFFER_WIDTH_BITWIDTH> col_breakpoint=input_head_col_idx.range(INBUFFER_WIDTH_BITWIDTH-1,0);

			ap_uint<INBUFFER_MID_ADDR_BITWIDTH> input_head_col_address_offset;
			input_head_col_address_offset= input_head_col_idx.range(INBUFFER_WIDTH_BITWIDTH+INBUFFER_MID_ADDR_BITWIDTH-1,INBUFFER_WIDTH_BITWIDTH) 
			+ buffer_address_mid_minitile_depth_offset;

			for(int i=0;i<INBUFFER_WIDTH;i++)
			{
				#pragma HLS unroll
				if(i>=col_breakpoint)
					col_pix_address_offset[i] = input_head_col_address_offset;
				else
					col_pix_address_offset[i] = input_head_col_address_offset+1;
			}

			// Full read address per bank: {row offset bit, column/depth offset, minitile index}.
			ap_uint<INPUT_BUFFER_DEPTH_BITWIDTH> buffer_address[INBUFFER_HEIGHT][INBUFFER_WIDTH];
			#pragma HLS array_partition variable=buffer_address complete dim=0

			for(int i=0;i<INBUFFER_HEIGHT; i++)
			{
				#pragma HLS unroll
				for(int j=0;j<INBUFFER_WIDTH;j++)
				{
					#pragma HLS unroll
					buffer_address[i][j]=(row_address_offset[i],col_pix_address_offset[j],loop_indepth_minitile_idx);
				}
			}

			// One read per bank.
			ap_uint<16> input_buffer_val[INBUFFER_HEIGHT][INBUFFER_WIDTH];
			#pragma HLS array_partition variable=input_buffer_val complete dim=0

			for(int i=0;i<INBUFFER_HEIGHT; i++)
			{
				#pragma HLS unroll
				for(int j=0;j<INBUFFER_WIDTH;j++)
				{
					#pragma HLS unroll
					input_buffer_val[i][j]=input_buffer[i][j][buffer_address[i][j]];
				}
			}

			// Select the window rows from the row banks, zeroing padded rows.
			ap_uint<16> input_plane_tile_row[WINO_DOMAIN_SIZE][INBUFFER_WIDTH];
			#pragma HLS array_partition variable=input_plane_tile_row dim=1 complete
			#pragma HLS array_partition variable=input_plane_tile_row dim=2 complete

			for(int j=0;j<INBUFFER_WIDTH;j++)
			{
			#pragma HLS unroll
				for(int i=0;i<WINO_DOMAIN_SIZE;i++)
				{
				#pragma HLS unroll
					if(row_legal_flag[i])
					{
						input_plane_tile_row[i][j]=input_buffer_val[row_bank_idx[i]][j];
					}
					else
					{
						input_plane_tile_row[i][j]=0;
					}
				}
			}

			// Select each tile's columns from the column banks, zeroing padded columns.
			ap_uint<16> input_plane_tile[WINO_WIDTH][WINO_DOMAIN_SIZE][WINO_DOMAIN_SIZE];
			#pragma HLS array_partition variable=input_plane_tile complete dim=0
			for(int i=0;i<WINO_WIDTH;i++)
			{
			#pragma HLS unroll
				for(int j=0;j<WINO_DOMAIN_SIZE;j++)
				{
				#pragma HLS unroll
					for(int k=0;k<WINO_DOMAIN_SIZE;k++)
					{
					#pragma HLS unroll
						if(col_legal_flag[i][k])
							input_plane_tile[i][j][k]=input_plane_tile_row[j][ (ap_uint<INBUFFER_WIDTH_BITWIDTH>) input_col_idx[i][k].range(INBUFFER_WIDTH_BITWIDTH-1,0) ];
						else
							input_plane_tile[i][j][k]=0;
					}
				}
			}

			#if DEBUG_FILE_PRINT
				int indepth = buffer_address_mid_minitile_depth_offset/buffer_address_mid_minitile_depth_step*INDEPTH_MINITILE_SIZE
								+loop_indepth_minitile_idx;
				attach_streaming_content<WINO_WIDTH>(input_plane_tile, start_row_idx, input_head_col_idx+pad_size, indepth, "instream.txt");
			#endif

			//----------------------------------------------------------
			// Counter update. Both wrap conditions are evaluated before any
			// counter changes.
			//----------------------------------------------------------
			bool col_cycle_done = (loop_wino_tile_col_cnt == loop_wino_tile_col_reset_cycle);
			bool outdepth_cycle_done = (loop_outdepth_minitile_baseidx_cnt==loop_outdepth_minitile_baseidx_reset_cycle);

			if(col_cycle_done)
			{
				// Finished all tile columns: next depth minitile, back to column 0.
				buffer_address_mid_minitile_depth_offset += buffer_address_mid_minitile_depth_step;
				input_head_col_idx=wino_col_offset_constant[0]-pad_size;
				for(int wino_array_col=0;wino_array_col<WINO_WIDTH;wino_array_col++)
				{
					#pragma HLS unroll
					for(int i=0;i<WINO_DOMAIN_SIZE;i++)
					{
						#pragma HLS unroll
						input_col_idx[wino_array_col][i]=i-pad_size+wino_col_offset_constant[wino_array_col];
					}
				}
				loop_wino_tile_col_cnt=1;
			}
			else
			{
				if(outdepth_cycle_done)
				{
					// Finished all out-depth minitiles: move to the next group of tile columns.
					input_head_col_idx+=wino_out_size_by_wino_width;
					for(int wino_array_col=0;wino_array_col<WINO_WIDTH;wino_array_col++)
					{
						#pragma HLS unroll
						for(int i=0;i<WINO_DOMAIN_SIZE;i++)
						{
							#pragma HLS unroll
							input_col_idx[wino_array_col][i]+=wino_out_size_by_wino_width;
						}
					}
				}
				loop_wino_tile_col_cnt++;
			}

			if(outdepth_cycle_done)
			{
				loop_outdepth_minitile_baseidx_cnt=1;
			}
			else
			{
				loop_outdepth_minitile_baseidx_cnt++;
			}
			loop_indepth_minitile_idx++;
		}
	}
}



//template<int dummy>
// load_weight_ddr_one_port
//
// Copies weightDDR_port_burst_length 128-bit words, starting at
// weight_DDR + ddr_address_offset, into the half of weight_buff selected by
// pingpong (the top address bit). Returns immediately when skip_flag is set.
//
// Each 128-bit DDR word is split into four 32-bit words. Sub-word k goes to
// weight_buff[buffer_idx][4*counter + k][address], where counter cycles
// through 0 .. WEIGHTDDR_INDEPTH_MINITILE_128BIT_STEP-1. The in-half address
// advances after each full counter cycle. After every
// weightDDR_buffer_burst_length DDR words, the address returns to 0 and the
// next feed buffer (buffer_idx) is filled.
//
// conv_desc (debug builds only) is not used.
void load_weight_ddr_one_port(
	ap_uint<128>* weight_DDR,
	ap_uint<32> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4][WEIGHT_BUFFER_DEPTH],
	ap_uint<16> weightDDR_buffer_burst_length,
	ap_uint<16> weightDDR_port_burst_length,
	ap_uint<32> ddr_address_offset,
	ap_uint<1> pingpong,
	ap_uint<1> skip_flag
	#if DEBUG_FILE_PRINT
	,ConvDesc_t conv_desc
	#endif
	)
{
	#pragma HLS array_partition variable = weight_buff dim=1 complete
	#pragma HLS array_partition variable = weight_buff dim=2 complete

	if(skip_flag)
		return;
	printf("load pingpong %d %d\n",(int) pingpong,(int) ddr_address_offset);
	// printf("DDR_offset %d, i%d o%d\n",(int) ddr_address_offset,(int) conv_desc.weightbuffer_load_indepth_number,(int) conv_desc.weightbuffer_load_outdepth_number);

	ap_uint<128>* offseted_weight_DDR=weight_DDR+ddr_address_offset;

	// Which group of four 32-bit destination slots the current DDR word fills.
	ap_uint<WEIGHTDDR_INDEPTH_MINITILE_128BIT_STEP_BITWIDTH> counter=0;

	// 1-based count of DDR words written into the current feed buffer.
	ap_uint<16> port_load_cnt=1;

	// Address inside the selected half; the pingpong bit is prepended below.
	ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH-1> buffer_address_offset=0;

	// A zero-width ap_uint is not allowed, so a single feed buffer
	// (bitwidth 0) gets a dummy 1-bit index. It is initialised, and the write
	// below is guarded by the SAME condition so the index is never used then.
	// Previously this index was uninitialised, and the write was guarded by
	// WEIGHT_FEED_NUMBER_PER_PORT == 0, which let a garbage/toggling 1-bit
	// index address a 1-entry dimension.
	#if WEIGHT_FEED_NUMBER_PER_PORT_BITWIDTH == 0
	ap_uint<1> buffer_idx=0;
	#else
	ap_uint<WEIGHT_FEED_NUMBER_PER_PORT_BITWIDTH> buffer_idx=0;
	#endif

	for(int address = 0; address<weightDDR_port_burst_length; address++)
	{
		#pragma HLS pipeline
		ap_uint<128> temp128 = offseted_weight_DDR[address];
		ap_uint<32> temp32[4];
		#pragma HLS array_partition  variable = temp32 complete

		for(int i=0;i<4;i++)
		{
			#pragma HLS unroll
			temp32[i]=temp128.range(i*32+31,i*32);
		}

		// Width derived from the buffer depth (was a hard-coded 10 bits),
		// consistent with weight_streamer, so the pingpong bit is never truncated.
		ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH> buffer_address = (pingpong,buffer_address_offset);

		for(int i=0;i<WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4;i++)
		{
			#pragma HLS unroll
			if( i/4==counter)
			{
				#if WEIGHT_FEED_NUMBER_PER_PORT_BITWIDTH == 0
				weight_buff[0][i][buffer_address]=temp32[i%4];
				#else
				weight_buff[buffer_idx][i][buffer_address]=temp32[i%4];
				#endif
			}
		}

		bool buffer_full = (port_load_cnt==weightDDR_buffer_burst_length);
		bool counter_wrap = (counter== WEIGHTDDR_INDEPTH_MINITILE_128BIT_STEP-1);

		if(buffer_full)
		{
			// Current feed buffer complete: restart its address and move to the next one.
			buffer_address_offset=0;
			buffer_idx++;
			port_load_cnt=1;
		}
		else
		{
			if(counter_wrap)
				buffer_address_offset++;
			port_load_cnt++;
		}

		if(counter_wrap)
		{
			counter=0;
		}
		else
		{
			counter++;
		}
	}
}


// weight_streamer
//
// Streams the half of weight_buff selected by pingpong. For each of
// loop_weight_feed_bound iterations, and for every feed buffer_idx, it:
//   - reads the WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4 32-bit words
//     at address (pingpong, indepth_minitile_addr_offset + outdepth_minitile_addr_offset),
//   - packs them (word j at bits [32*j+31 : 32*j]) into one wide word,
//   - writes that word to weight_stream[buffer_idx].
//
// Address sequence:
//   outdepth_minitile_addr_offset cycles 0 .. loop_outdepth_minitile_baseidx_reset_cycle_minus1;
//   indepth_minitile_addr_offset advances by (reset_cycle_minus1 + 1) every
//   loop_start_output_basecol_reset_cycle iterations.
//
// The loop constants are precomputed by the caller. Per the original
// commented formulas:
//   loop_outdepth_minitile_baseidx_reset_cycle_minus1 = weightbuffer_outdepth_minitile_number - 1
//   loop_start_output_basecol_reset_cycle = weightbuffer_outdepth_minitile_number * wino_tile_number_in_outwidth
//   loop_weight_feed_bound = weightbuffer_indepth_minitile_number * weightbuffer_outdepth_minitile_number * wino_tile_number_in_outwidth
void weight_streamer(
	ap_uint<32> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4][WEIGHT_BUFFER_DEPTH],
	hls::stream<ap_uint<8*INDEPTH_MINITILE_SIZE*WINO_DOMAIN_SIZE_SQUARE> >  weight_stream[WEIGHT_FEED_NUMBER_PER_PORT],
	ap_uint<16> loop_outdepth_minitile_baseidx_reset_cycle_minus1,
	ap_uint<16> loop_start_output_basecol_reset_cycle,
	ap_uint<32> loop_weight_feed_bound,
	ap_uint<1> pingpong
	#if DEBUG_FILE_PRINT
	,ConvDesc_t conv_desc
	#endif
)
{
	#pragma HLS array_partition variable = weight_buff dim=1 complete
	#pragma HLS array_partition variable = weight_buff dim=2 complete
	printf("stream pingpong %d\n",(int) pingpong);

	// In-depth address step = number of out-depth minitiles
	// (== conv_desc.weightbuffer_outdepth_minitile_number). This step is
	// derived from a real argument because conv_desc only exists when
	// DEBUG_FILE_PRINT is set; reading conv_desc here broke non-debug builds.
	ap_uint<17> indepth_minitile_addr_step = loop_outdepth_minitile_baseidx_reset_cycle_minus1 + 1;

	ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH -1> outdepth_minitile_addr_offset=0;
	ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH -1> indepth_minitile_addr_offset=0;

	ap_uint<16> loop_start_output_basecol_cnt=1;

	for(ap_uint<32> cycle=0;cycle < loop_weight_feed_bound; cycle++)
	{
		#pragma HLS pipeline
		// for(int indepth_minitile_baseidx=0;indepth_minitile_baseidx<conv_desc.weightbuffer_load_indepth_step; indepth_minitile_baseidx += INDEPTH_MINITILE_SIZE)
		// for(int start_output_basecol =0; start_output_basecol < conv_desc.outwidth; start_output_basecol+=conv_desc.wino_output_tile_size*WINO_WIDTH)
		// for(int outdepth_minitile_baseidx=0;outdepth_minitile_baseidx<conv_desc.weightbuffer_load_outdepth_step; outdepth_minitile_baseidx += OUTDEPTH_MINITILE_SIZE)
		ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH-1> weight_buffer_address_right=indepth_minitile_addr_offset+outdepth_minitile_addr_offset;
		ap_uint<WEIGHT_BUFFER_DEPTH_BITWIDTH> weight_buffer_address = (pingpong,weight_buffer_address_right);

		// Gather, pack and emit one wide word per feed buffer.
		for(int buffer_idx =0; buffer_idx< WEIGHT_FEED_NUMBER_PER_PORT; buffer_idx++)
		{
			#pragma HLS unroll
			ap_uint<8*INDEPTH_MINITILE_SIZE*WINO_DOMAIN_SIZE_SQUARE> packed_word = 0;
			for(int j=0;j<WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4;j++)
			{
				#pragma HLS unroll
				packed_word.range(j*32+31,j*32)=weight_buff[buffer_idx][j][weight_buffer_address];
			}
			weight_stream[buffer_idx]<<packed_word;
		}

		if(loop_start_output_basecol_cnt==loop_start_output_basecol_reset_cycle)
		{
			indepth_minitile_addr_offset+=indepth_minitile_addr_step;
			loop_start_output_basecol_cnt=1;
		}
		else
		{
			loop_start_output_basecol_cnt++;
		}

		if(outdepth_minitile_addr_offset==loop_outdepth_minitile_baseidx_reset_cycle_minus1)
		{
			outdepth_minitile_addr_offset=0;
		}
		else
		{
			outdepth_minitile_addr_offset++;
		}
	}
}



// weight_feed_one_port
//
// Double-buffered weight feeder for one DDR port. weight_buff, DDR_offset,
// DDR_load_cnt and pingpong are function-local statics. Each template
// instantiation (dummy) therefore has its own copy, which persists across calls.
//
// first_flag: reset that state and preload the burst at DDR offset 0 into
//             half 0 of weight_buff.
// Then, for each of the weightDDR_burst_number iterations:
//   1. advance to the next burst offset (wrap to 0 after the last one),
//   2. load that burst into the half not being streamed (~pingpong);
//      skipped when last_flag is set and the offset wrapped to 0,
//   3. stream the current half (pingpong) via weight_streamer,
//   4. swap halves.
// Unless last_flag is set, after the final iteration the offset-0 burst sits
// in the half that pingpong points to, ready for the next call. That is why
// pingpong must only be reset on first_flag.
template<int dummy>
void weight_feed_one_port(
	ap_uint<128>* weight_DDR0,
	hls::stream<ap_uint<8*INDEPTH_MINITILE_SIZE*WINO_DOMAIN_SIZE_SQUARE> >  weight_stream[WEIGHT_FEED_NUMBER_PER_PORT],
	ap_uint<16> weightDDR_burst_number,
	ap_uint<16> weightDDR_buffer_burst_length,
	ap_uint<16> weightDDR_port_burst_length,
	ap_uint<16> loop_outdepth_minitile_baseidx_reset_cycle_minus1,
	ap_uint<16> loop_start_output_basecol_reset_cycle,
	ap_uint<32> loop_weight_feed_bound,
	ap_uint<1> first_flag,
	ap_uint<1> last_flag
	#if DEBUG_FILE_PRINT
	,ConvDesc_t conv_desc	
	#endif
)
{
	printf("\ninto weight feed %d %d %d\n",(int) first_flag, (int) last_flag, (int) weightDDR_burst_number  );
	static ap_uint<32> weight_buff[WEIGHT_FEED_NUMBER_PER_PORT][WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4][WEIGHT_BUFFER_DEPTH];
	// 32 bits to match load_weight_ddr_one_port's ddr_address_offset. A 16-bit
	// offset wrapped once (burst_number-1)*port_burst_length exceeded 65535.
	static ap_uint<32> DDR_offset;
	static ap_uint<16> DDR_load_cnt;
	static ap_uint<1> pingpong;

	// All three resets belong to first_flag. Without braces, DDR_load_cnt and
	// pingpong were zeroed on every call, so with an odd burst count the
	// streamer read the half that had not been preloaded.
	if(first_flag)
	{
		DDR_offset=0;
		DDR_load_cnt=0;
		pingpong = 0;
	}

	// Preload burst 0 into half 0 (pingpong is 0 here) on the first call only.
	load_weight_ddr_one_port(
	weight_DDR0,
	weight_buff,
	weightDDR_buffer_burst_length,
	weightDDR_port_burst_length,
	0,
	0,
	~first_flag
	#if DEBUG_FILE_PRINT
	,conv_desc
	#endif
	);

	for(ap_uint<16> cnt=0;cnt< weightDDR_burst_number ;cnt++)
	{
		// Offset of the burst to prefetch, wrapping back to 0 after the last one.
		if(DDR_load_cnt == weightDDR_burst_number-1)
		{
			DDR_load_cnt = 0;
			DDR_offset = 0;
		}
		else
		{
			DDR_load_cnt+=1;
			DDR_offset+=weightDDR_port_burst_length;
		}

		// Prefetch into the idle half. On the last layer, skip the wrapped
		// (offset 0) prefetch since no further call will consume it.
		load_weight_ddr_one_port(
		weight_DDR0,
		weight_buff,
		weightDDR_buffer_burst_length,
		weightDDR_port_burst_length,
		DDR_offset,
		~pingpong,
		last_flag & (DDR_load_cnt==0) 
		#if DEBUG_FILE_PRINT
		,conv_desc
		#endif
		);
		for(int i=0;i<WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4;i++)
		{
			printf("[%4x]",(int)weight_buff[0][i][0]);
		}
		printf("\n");
		for(int i=0;i<WINO_DOMAIN_SIZE_SQUARE*INDEPTH_MINITILE_SIZE/4;i++)
		{
			printf("[%4x]",(int) weight_buff[0][i][WEIGHT_BUFFER_DEPTH/2]);
		}
		printf("\n");

		// Stream the half that was filled previously.
		weight_streamer(
			weight_buff,
			weight_stream,
			loop_outdepth_minitile_baseidx_reset_cycle_minus1,
			loop_start_output_basecol_reset_cycle,
			loop_weight_feed_bound,
			pingpong
			#if DEBUG_FILE_PRINT
			,conv_desc
			#endif
		);

		pingpong = ~pingpong;
	}
}




#endif