#define AP_INT_MAX_W 1152
#include <ap_int.h>
#include <hls_stream.h>

#include "wino_macro.h"
#include "wino_IO.cpp"
#include "wino_cell.hpp"





/*
 * Integer helpers for non-negative operands.
 *   CEIL_DIV(num, den) : num / den rounded up to the next integer.
 *   ALIGN(num, den)    : num rounded up to the next multiple of den.
 * Both macros evaluate their arguments more than once, so only pass
 * side-effect-free expressions.
 */
#define CEIL_DIV(num, den)  ( ( (num) + (den) - 1 ) / (den) )
#define ALIGN(num, den)     ( CEIL_DIV(num, den) * (den) )

/**
 * Top-level function of the Winograd systolic convolution design.
 *
 * Derives the tiling and weight-loading parameters from the layer shape, then
 * walks over the output rows one row tile at a time.  For every row tile it
 * calls load_input_rowtile_from_ddr(), wino_systolic() and
 * write_output_to_DDR() in that order.
 *
 * @param input_DDR0..3    128-bit wide DDR ports forwarded to load_input_rowtile_from_ddr().
 * @param weight_DDR0..3   128-bit wide DDR ports forwarded to wino_systolic().
 * @param output_DDR0..3   128-bit wide DDR ports forwarded to write_output_to_DDR().
 * @param inHeight, inWidth, inpDepth     input shape; each value is narrowed to 16 bits.
 * @param outHeight, outWidth, outDepth   output shape; each value is narrowed to 16 bits.
 * @param kernelWindowSize kernel size; narrowed to 8 bits (3 and 5 are remapped, see below).
 * @param padSize          padding; narrowed to 8 bits.
 */
void wino_systolic_top(
    ap_uint<128> *input_DDR0,
    ap_uint<128> *input_DDR1,
    ap_uint<128> *input_DDR2,
    ap_uint<128> *input_DDR3,
    ap_uint<128> *weight_DDR0,
    ap_uint<128> *weight_DDR1,
    ap_uint<128> *weight_DDR2,
    ap_uint<128> *weight_DDR3,
    ap_uint<128> *output_DDR0,
    ap_uint<128> *output_DDR1,
    ap_uint<128> *output_DDR2,
    ap_uint<128> *output_DDR3,
    int inHeight,
    int inWidth,
    int inpDepth,
    int outHeight,
    int outWidth,
    int outDepth,
    int kernelWindowSize,
    int padSize)
{
    /******** this version assumes only 5x5 and 3x3 kernels are computed (temporary design) **********/

    // The parameter pre-processing below may later be moved outside the IP;
    // it is kept here for now.

    // ---- layer shape, narrowed to the fixed-width types the sub-modules take ----
    ap_uint<16> input_height  = inHeight;
    ap_uint<16> input_width   = inWidth;
    ap_uint<16> input_depth   = inpDepth;
    ap_uint<16> output_height = outHeight;
    ap_uint<16> output_width  = outWidth;
    ap_uint<16> output_depth  = outDepth;
    ap_uint<8>  pad_size      = padSize;

    // 3x3 and 5x5 kernels are handed to wino_systolic() as window size 1;
    // every other size is passed through unchanged.
    ap_uint<8> kernel_window_size;
    if (kernelWindowSize == 3 || kernelWindowSize == 5) {
        kernel_window_size = 1;
    } else {
        kernel_window_size = kernelWindowSize;
    }

    // Number of output rows advanced per iteration of the row-tile loop:
    // 4 for a 3x3 kernel, 2 otherwise.
    ap_uint<8> wino_output_tile_size = (kernelWindowSize == 3) ? 4 : 2;

    // ---- input geometry rounded up to the granularities used downstream ----
    ap_uint<16> input_width_align_8    = ALIGN(input_width, 8);
    ap_uint<16> input_width_align_16   = ALIGN(input_width, 16);
    ap_uint<16> input_width_ceildiv_16 = CEIL_DIV(input_width, 16);
    ap_uint<16> input_depth_align8     = ALIGN(input_depth, 8);

    // ---- output depth partitioning ----
    int output_depth_align_systolic_height   = ALIGN(output_depth, WEIGHT_FEED_NUMBER_PER_PORT);
    int output_depth_ceildiv_systolic_height = CEIL_DIV(output_depth, WEIGHT_FEED_NUMBER_PER_PORT*4);
    int outdepth_ceil_div8                   = CEIL_DIV(output_depth, 8);

    // ---- weight loading schedule ----
    int max_weight_outdepth_feed_size = WEIGHT_BUFFER_DEPTH/8;

    // Feed size is the required size capped by what the weight buffer allows.
    ap_uint<16> weight_outdepth_feed_size =
        max_weight_outdepth_feed_size < output_depth_ceildiv_systolic_height
            ? max_weight_outdepth_feed_size
            : output_depth_ceildiv_systolic_height;

    ap_uint<16> weight_indepth_load_number  = input_depth_align8 / 8;
    ap_uint<16> weight_outdepth_load_number =
        CEIL_DIV(output_depth_align_systolic_height, max_weight_outdepth_feed_size);

    ap_uint<16> weight_total_load_number = weight_indepth_load_number * weight_outdepth_load_number;
    ap_uint<16> weight_total_feed_size   = weight_outdepth_feed_size * 8;
    ap_uint<16> ddr_load_length_per_feed = weight_total_feed_size/2*5;
    ap_uint<16> ddr_load_length          = ddr_load_length_per_feed * WEIGHT_FEED_NUMBER_PER_PORT;

    ap_uint<16> row_repeat_times = CEIL_DIV(output_width, wino_output_tile_size*2);

    // ---- local buffers ----
    ap_uint<16> input_buffer[8][16][INPUT_BUFFER_DEPTH];
    ap_uint<36> output_buffer0[16][16][OUTPUT_BUFFER_DEPTH];

    // Zero the whole output buffer.  The innermost bound follows the declared
    // depth so the loop can neither overrun nor under-fill the array.
    for (int i = 0; i < 16; i++) {
        for (int j = 0; j < 16; j++) {
            for (int k = 0; k < OUTPUT_BUFFER_DEPTH; k++) {
                output_buffer0[i][j][k] = 0;
            }
        }
    }

    #if DEBUG_FILE_PRINT
        clear_buffer_content<INPUT_BUFFER_DEPTH>(input_buffer);
    #endif

    // Initial load for output row 0, issued once before the row-tile loop.
    // NOTE(review): the trailing 1 (vs. 0 inside the loop) is presumably a
    // "first load" flag -- confirm against load_input_rowtile_from_ddr().
    load_input_rowtile_from_ddr(
        input_DDR0,
        input_DDR1,
        input_DDR2,
        input_DDR3,
        input_buffer,
        input_height,
        input_width,
        input_width_align_8,
        input_width_align_16,
        input_depth,
        output_height,
        0,
        pad_size,
        1);

    #if DEBUG_FILE_PRINT
        attach_input_buffer_content<INPUT_BUFFER_DEPTH>(input_buffer,0,"input_buffer_content.txt");
    #endif

    for (ap_uint<16> start_output_row = 0;
         start_output_row < output_height;
         start_output_row += wino_output_tile_size)
    {
        // Load request for the row tile that follows the current one.
        load_input_rowtile_from_ddr(
            input_DDR0,
            input_DDR1,
            input_DDR2,
            input_DDR3,
            input_buffer,
            input_height,
            input_width,
            input_width_align_8,
            input_width_align_16,
            input_depth,
            output_height,
            start_output_row + wino_output_tile_size,
            pad_size,
            0);

        #if DEBUG_FILE_PRINT
            attach_input_buffer_content<INPUT_BUFFER_DEPTH>(input_buffer,start_output_row,"input_buffer_content.txt");
        #endif

        wino_systolic(
            input_buffer,
            output_buffer0,
            weight_DDR0,
            weight_DDR1,
            weight_DDR2,
            weight_DDR3,
            input_height,
            input_width,
            input_depth,
            input_width_ceildiv_16,
            input_depth_align8,
            output_height,
            output_width,
            output_depth,
            kernel_window_size,
            pad_size,
            weight_indepth_load_number,
            weight_outdepth_load_number,
            weight_outdepth_feed_size,
            start_output_row,
            weight_total_load_number,
            weight_total_feed_size,
            ddr_load_length,
            ddr_load_length_per_feed,
            row_repeat_times,
            (start_output_row==0),                                       // true on the first row tile
            (start_output_row+wino_output_tile_size >= output_height));  // true on the last row tile

        #if DEBUG_FILE_PRINT
            attach_output_buffer_content<0>(output_buffer0,"output_buffer_content.txt");
        #endif

        printf("Get into the design\n");
        fflush(stdout);

        // The per-argument names below are taken from the reference prototype
        // of write_output_to_DDR() that was previously kept here as a comment.
        // NOTE(review): end_row is hard-coded to start_output_row+4 although
        // the loop advances by wino_output_tile_size, which is 2 for non-3x3
        // kernels -- confirm this is what write_output_to_DDR() expects.
        write_output_to_DDR(
            output_DDR0,
            output_DDR1,
            output_DDR2,
            output_DDR3,
            output_buffer0,
            outdepth_ceil_div8,      // outdepth_ceildiv8
            start_output_row,        // start_out_row_idx
            start_output_row+4,      // end_row
            output_height,           // out_height_aligned4
            output_width,            // out_width
            wino_output_tile_size,   // wino_output_tile_size
            row_repeat_times,        // row_tile_number
            1,                       // wino_flag
            0);                      // final_right_shift

        printf("Get outof the design\n");
        fflush(stdout);
    }
}





#include "wino_macro.h"