#include "wino_macro.h"
#include <ap_int.h>
#include <hls_stream.h>


#include "wino_IO.cpp"
#include "wino_cell.cpp"


#include "../software/param.h"




// Runs one pass of the Winograd pipeline for the row tile currently held in
// input_buffer, by calling the feed / transform / compute stages in sequence
// and connecting them through locally declared hls::stream channels.
//
// Parameters (as used in this function):
//   weight_DDR0..3   - one pointer per weight_feed_one_port<0..3> call.
//   input_buffer     - passed to input_feed_underconstruction.
//   out_buffer       - passed to wino_stream_block.
//   start_output_row, start_row_idx_minus_pad_size
//                    - forwarded unchanged to input_feed_underconstruction.
//   first_flag, last_flag
//                    - forwarded unchanged to every weight_feed_one_port call.
//   conv_desc        - layer descriptor; individual fields are fanned out to
//                      the stages below. When DEBUG_CONV_DESC is non-zero the
//                      whole descriptor is additionally passed to
//                      weight_feed_one_port and wino_stream_block.
//   ap_clk_div2      - forwarded as the last argument of wino_stream_block.
//
// NOTE(review): no "#pragma HLS dataflow" is visible in this function; whether
// the stages are expected to run concurrently in hardware cannot be determined
// from this file alone - confirm against the build scripts / callee pragmas.
void wino_systolic_kernel(
    ap_uint<128> *weight_DDR0,
    ap_uint<128> *weight_DDR1,
    ap_uint<128> *weight_DDR2,
    ap_uint<128> *weight_DDR3,
    ap_uint<16> input_buffer[INBUFFER_HEIGHT][INBUFFER_WIDTH][INPUT_BUFFER_DEPTH],
    ap_uint<OUT_WIDTH*2> out_buffer[WINO_OUT_SIZE][WINO_OUT_SIZE][OUTDEPTH_MINITILE_SIZE][WINO_WIDTH][OUTPUT_BUFFER_DEPTH],
    ap_uint<16> start_output_row,
    ap_int<16> start_row_idx_minus_pad_size,
    ap_uint<1> first_flag,
    ap_uint<1> last_flag,
    ConvDesc_t conv_desc,
    ap_uint<1> ap_clk_div2
)
{
    // Channels between stages: one raw-tile stream and one transformed-tile
    // stream per index in [0, WINO_WIDTH), plus a 4 x WEIGHT_FEED_NUMBER_PER_PORT
    // grid of weight streams (first index = weight DDR port).
    // NOTE(review): the transformed stream width uses the literal 36 where the
    // other two use WINO_DOMAIN_SIZE_SQUARE - presumably the same value; confirm.
    hls::stream< ap_uint<8*BATCH_SIZE*WINO_DOMAIN_SIZE_SQUARE> > input_tile_stream[WINO_WIDTH];
    hls::stream< ap_uint<BTB_WIDTH*BATCH_SIZE*36> > input_tile_transformed_stream[WINO_WIDTH];
    hls::stream<ap_uint<8*INDEPTH_MINITILE_SIZE*WINO_DOMAIN_SIZE_SQUARE> >  weight_stream[4][WEIGHT_FEED_NUMBER_PER_PORT];

    // Input feed stage: receives the on-chip input buffer, the whole
    // input_tile_stream array and the loop/addressing fields of conv_desc.
    input_feed_underconstruction(
        input_buffer,
        input_tile_stream,
        // hls::stream< ap_uint<16*BATCH_SIZE*36> > &input_tile_stream1, 
        // hls::stream< ap_uint<16*BATCH_SIZE*36> > &input_tile_stream2,
        conv_desc.inheight,
        conv_desc.inwidth,
        conv_desc.pad_size,
        conv_desc.weightbuffer_load_indepth_number,
        conv_desc.weightbuffer_load_outdepth_number,
        conv_desc.wino_output_tile_size,
        conv_desc.input_buffer_feeding_loop_bound,
        conv_desc.loop_wino_tile_row_reset_cycle,
        conv_desc.loop_wino_tile_col_reset_cycle,
        conv_desc.loop_outdepth_minitile_baseidx_reset_cycle,
        conv_desc.buffer_address_mid_minitile_depth_step,
        conv_desc.wino_out_size_by_wino_width,
        conv_desc.row_address_bitnumber_flag,
        start_output_row,
        start_row_idx_minus_pad_size
    );

    // Input transform stage: one call per stream index, pairing
    // input_tile_stream[i] with input_tile_transformed_stream[i]; the index is
    // also handed to the callee as its last argument.
    for(int i=0;i<WINO_WIDTH;i++)
    {
        input_transform(
            input_tile_stream[i],
            input_tile_transformed_stream[i],
            conv_desc.input_transform_feeding_loop_bound,
            i
        );
    }

    // Weight feed stage: four instantiations (template argument 0..3), each
    // given its own DDR pointer and its own row of weight_stream. All four
    // receive the same conv_desc fields and the same first/last flags.
    weight_feed_one_port<0>(
        weight_DDR0,
        weight_stream[0],
        conv_desc.weightDDR_burst_number,
        conv_desc.weightDDR_buffer_burst_length,
        conv_desc.weightDDR_port_burst_length,
        conv_desc.loop_outdepth_minitile_baseidx_reset_cycle_minus1,
        conv_desc.loop_start_output_baserowcol_reset_cycle,
        conv_desc.loop_weight_feed_bound,
        conv_desc.weightbuffer_outdepth_minitile_number,
        first_flag,
        last_flag
        #if DEBUG_CONV_DESC
        ,conv_desc	
        #endif
    );

    weight_feed_one_port<1>(
        weight_DDR1,
        weight_stream[1],
        conv_desc.weightDDR_burst_number,
        conv_desc.weightDDR_buffer_burst_length,
        conv_desc.weightDDR_port_burst_length,
        conv_desc.loop_outdepth_minitile_baseidx_reset_cycle_minus1,
        conv_desc.loop_start_output_baserowcol_reset_cycle,
        conv_desc.loop_weight_feed_bound,
        conv_desc.weightbuffer_outdepth_minitile_number,
        first_flag,
        last_flag
        #if DEBUG_CONV_DESC
        ,conv_desc	
        #endif
    );

    weight_feed_one_port<2>(
        weight_DDR2,
        weight_stream[2],
        conv_desc.weightDDR_burst_number,
        conv_desc.weightDDR_buffer_burst_length,
        conv_desc.weightDDR_port_burst_length,
        conv_desc.loop_outdepth_minitile_baseidx_reset_cycle_minus1,
        conv_desc.loop_start_output_baserowcol_reset_cycle,
        conv_desc.loop_weight_feed_bound,
        conv_desc.weightbuffer_outdepth_minitile_number,
        first_flag,
        last_flag
        #if DEBUG_CONV_DESC
        ,conv_desc	
        #endif
    );
    
    weight_feed_one_port<3>(
        weight_DDR3,
        weight_stream[3],
        conv_desc.weightDDR_burst_number,
        conv_desc.weightDDR_buffer_burst_length,
        conv_desc.weightDDR_port_burst_length,
        conv_desc.loop_outdepth_minitile_baseidx_reset_cycle_minus1,
        conv_desc.loop_start_output_baserowcol_reset_cycle,
        conv_desc.loop_weight_feed_bound,
        conv_desc.weightbuffer_outdepth_minitile_number,
        first_flag,
        last_flag
        #if DEBUG_CONV_DESC
        ,conv_desc	
        #endif
    );



    // Disabled debug hook: would dump every weight stream to
    // "weightstream<N>.txt" when DEBUG_FILE_PRINT is set.
    // #if DEBUG_FILE_PRINT
    // for(int i=0;i<4;i++)
    // {
    //     for(int j=0;j<WEIGHT_FEED_NUMBER_PER_PORT;j++)
    //     {
    //         int outdepth_minitile_idx=i*WEIGHT_FEED_NUMBER_PER_PORT+j;
    //         char filename[100];
    //         sprintf(filename,"weightstream%d.txt",outdepth_minitile_idx);

    //         attach_weight_stream_content<INDEPTH_MINITILE_SIZE,WINO_DOMAIN_SIZE,WINO_DOMAIN_SIZE_SQUARE>(weight_stream[i][j],filename);
    //     }
    // }
    // #endif

    // Compute stage: receives the transformed input streams, the full weight
    // stream grid and the output buffer, plus its loop bounds / address steps
    // from conv_desc. ap_clk_div2 is always the final argument, after the
    // optional DEBUG_CONV_DESC descriptor.
    wino_stream_block(
        input_tile_transformed_stream,
        weight_stream,
        out_buffer,
		conv_desc.weightbuffer_outdepth_minitile_number,
		conv_desc.total_input_stream_tile,
		conv_desc.loop_omini_base_reset_cycle,
		conv_desc.loop_wino_tile_rowcol_self_reset_cycle_min1,
		conv_desc.loop_iload_reset_cycle,
		conv_desc.loop_wino_cell_bound,
		conv_desc.outbuffer_oload_increment_step,
		conv_desc.outbuffer_omini_increment_step,
		conv_desc.wino5x5_flag
        #if DEBUG_CONV_DESC
        ,conv_desc	
        #endif
        ,ap_clk_div2
    );

}


// Top-level HLS entry point: iterates over the output rows of one convolution
// layer in steps of conv_desc.out_rowstep and, for each step, loads an input
// row tile, runs wino_systolic_kernel and writes the result back to DDR.
//
// Parameters:
//   input_DDR0..3   - m_axi ports (depth 65535) handed to
//                     load_input_rowtile_from_ddr.
//   weight_DDR0..3  - m_axi ports (depth 65535) handed to wino_systolic_kernel.
//   output_DDR0..3  - m_axi ports (depth 65535) handed to write_output_to_DDR.
//   conv_desc       - layer descriptor; supplies the loop bound (outheight),
//                     the row step (out_rowstep), the padding and all values
//                     forwarded to the load / compute / store stages.
//   ap_clk_div2     - forwarded unchanged to wino_systolic_kernel.
void wino_systolic_top(
    ap_uint<128> *input_DDR0,
    ap_uint<128> *input_DDR1,
    ap_uint<128> *input_DDR2,
    ap_uint<128> *input_DDR3,
    ap_uint<128> *weight_DDR0,
    ap_uint<128> *weight_DDR1,
    ap_uint<128> *weight_DDR2,
    ap_uint<128> *weight_DDR3,
    ap_uint<ODDR_WIDTH*BATCH_SIZE*8> *output_DDR0,
    ap_uint<ODDR_WIDTH*BATCH_SIZE*8> *output_DDR1,
    ap_uint<ODDR_WIDTH*BATCH_SIZE*8> *output_DDR2,
    ap_uint<ODDR_WIDTH*BATCH_SIZE*8> *output_DDR3,
    ConvDesc_t conv_desc,
    ap_uint<1> ap_clk_div2
    )
{
    #pragma HLS interface m_axi port= input_DDR3 depth=65535
    #pragma HLS interface m_axi port= input_DDR2 depth=65535
    #pragma HLS interface m_axi port= input_DDR1 depth=65535
    #pragma HLS interface m_axi port= input_DDR0 depth=65535
    #pragma HLS interface m_axi port= output_DDR3 depth=65535
    #pragma HLS interface m_axi port= output_DDR2 depth=65535
    #pragma HLS interface m_axi port= output_DDR1 depth=65535
    #pragma HLS interface m_axi port= output_DDR0 depth=65535
    #pragma HLS interface m_axi port= weight_DDR3 depth=65535
    #pragma HLS interface m_axi port= weight_DDR2 depth=65535
    #pragma HLS interface m_axi port= weight_DDR1 depth=65535
    #pragma HLS interface m_axi port= weight_DDR0 depth=65535

    // On-chip input buffer, fully partitioned along its first two dimensions.
    ap_uint<16> input_buffer[INBUFFER_HEIGHT][INBUFFER_WIDTH][INPUT_BUFFER_DEPTH];
    #pragma HLS array_partition variable=input_buffer complete dim=1 
    #pragma HLS array_partition variable=input_buffer complete dim=2 

    // On-chip output buffers, fully partitioned along their first two
    // dimensions. Only output_buffer0 is used by the compute/store path in
    // this function; output_buffer1 is declared (and cleared under
    // DEBUG_FILE_PRINT) but not otherwise referenced here.
    ap_uint<OUT_WIDTH*2> output_buffer0[WINO_OUT_SIZE][WINO_OUT_SIZE][OUTDEPTH_MINITILE_SIZE][WINO_WIDTH][OUTPUT_BUFFER_DEPTH];
    #pragma HLS array_partition variable=output_buffer0 complete dim=1 
    #pragma HLS array_partition variable=output_buffer0 complete dim=2 
    ap_uint<OUT_WIDTH*2> output_buffer1[WINO_OUT_SIZE][WINO_OUT_SIZE][OUTDEPTH_MINITILE_SIZE][WINO_WIDTH][OUTPUT_BUFFER_DEPTH];
    #pragma HLS array_partition variable=output_buffer1 complete dim=1 
    #pragma HLS array_partition variable=output_buffer1 complete dim=2 

    #if DEBUG_FILE_PRINT
        clear_buffer_content<INBUFFER_HEIGHT,INBUFFER_WIDTH, INPUT_BUFFER_DEPTH>(input_buffer);
        clear_output_buffer_content_uniformed_hw<OUT_WIDTH,BATCH_SIZE,WINO_HEIGHT,WINO_WIDTH,WINO_OUT_SIZE,OUTPUT_BUFFER_DEPTH>(output_buffer0);
        clear_output_buffer_content_uniformed_hw<OUT_WIDTH,BATCH_SIZE,WINO_HEIGHT,WINO_WIDTH,WINO_OUT_SIZE,OUTPUT_BUFFER_DEPTH>(output_buffer1);
    #endif

    printf("conv_desc.row_address_bitnumber_flag %d\n", conv_desc.row_address_bitnumber_flag);

    // Initial load for output row 0, issued once before the row loop.
    // NOTE(review): the trailing argument is 1 here and 0 for the in-loop
    // load - presumably a "first load" flag; confirm against the callee.
    load_input_rowtile_from_ddr(
        input_DDR0,
        input_DDR1,
        input_DDR2,
        input_DDR3,
        input_buffer,
        conv_desc.inheight,
        conv_desc.inwidth,
        conv_desc.stride,
        conv_desc.pad_size,
        conv_desc.inwidth_align8,
        conv_desc.indepth_align8,
        conv_desc.group_indepth_x_inwidth_align8_by8,
        conv_desc.group_indepth_offset_x_inwidth_align8_by8,
        conv_desc.inwidth_ceildiv_inbufferwidth,
        conv_desc.buffer_address_mid_increment_step,
        conv_desc.input_load_burst_length,
        conv_desc.row_address_bitnumber_flag,
        conv_desc.out_rowstep,
        0,
        1);

    #if DEBUG_FILE_PRINT
        attach_input_buffer_content_uniformed<INBUFFER_HEIGHT,INBUFFER_WIDTH, INPUT_BUFFER_DEPTH>(input_buffer,0,"input_buffer_content.txt");
    #endif

    for (ap_uint<16> start_output_row = 0;
         start_output_row < conv_desc.outheight;
         start_output_row += conv_desc.out_rowstep)
    {
        // Signed on purpose: the difference is negative whenever pad_size
        // exceeds start_output_row (e.g. the first iteration with non-zero
        // padding), and wino_systolic_kernel takes this value as ap_int<16>.
        ap_int<16> start_row_idx_minus_pad_size = start_output_row - conv_desc.pad_size;

        // Load issued with the row position the NEXT iteration will use
        // (start_output_row + out_rowstep).
        load_input_rowtile_from_ddr(
            input_DDR0,
            input_DDR1,
            input_DDR2,
            input_DDR3,
            input_buffer,
            conv_desc.inheight,
            conv_desc.inwidth,
            conv_desc.stride,
            conv_desc.pad_size,
            conv_desc.inwidth_align8,
            conv_desc.indepth_align8,
            conv_desc.group_indepth_x_inwidth_align8_by8,
            conv_desc.group_indepth_offset_x_inwidth_align8_by8,
            conv_desc.inwidth_ceildiv_inbufferwidth,
            conv_desc.buffer_address_mid_increment_step,
            conv_desc.input_load_burst_length,
            conv_desc.row_address_bitnumber_flag,
            conv_desc.out_rowstep,
            start_output_row + conv_desc.out_rowstep,
            0);
        #if DEBUG_FILE_PRINT
            attach_input_buffer_content_uniformed<INBUFFER_HEIGHT,INBUFFER_WIDTH, INPUT_BUFFER_DEPTH>(input_buffer,0,"input_buffer_content.txt");
        #endif

        // Compute stage. first_flag is set on the iteration that starts at
        // row 0.
        // NOTE(review): last_flag compares against wino_output_tile_size with
        // a strict '>' while the loop advances by out_rowstep - confirm this
        // really identifies the final iteration for every layer shape.
        wino_systolic_kernel(
            weight_DDR0,
            weight_DDR1,
            weight_DDR2,
            weight_DDR3,
            input_buffer,
            output_buffer0,
            start_output_row,
            start_row_idx_minus_pad_size,
            start_output_row==0,
            start_output_row+conv_desc.wino_output_tile_size > conv_desc.outheight,
            conv_desc,
            ap_clk_div2
        );
        #if DEBUG_FILE_PRINT
        char outfilename[100];
        sprintf(outfilename,"outbuffer.txt");
        attach_output_buffer_content_uniformed_hw<OUT_WIDTH,BATCH_SIZE,WINO_HEIGHT,WINO_WIDTH,WINO_OUT_SIZE,OUTPUT_BUFFER_DEPTH>(
            output_buffer0,0,outfilename);
        #endif

        // Store stage: hands output_buffer0 and the four output DDR ports to
        // write_output_to_DDR for the current start_output_row.
        write_output_to_DDR(
            output_DDR0,
            output_DDR1,
            output_DDR2,
            output_DDR3,
            output_buffer0,
            conv_desc.outheight,
            conv_desc.outwidth_align8,
            conv_desc.wino_output_tile_size,
            conv_desc.wino_tile_number_in_outwidth,
            conv_desc.wino_tile_number_in_out_rowstep,
            conv_desc.wino_col_pix_upper_bound,
            conv_desc.wino_tile_number_rowcol,
            conv_desc.output_burst_length,
            conv_desc.out_ddr_increment_step,
            start_output_row,
            start_output_row==0
            #if DEBUG_CONV_DESC
            ,conv_desc
            #endif
        );
    }
}