#ifndef _WINO_TRANSFORM_CPP_ #define _WINO_TRANSFORM_CPP_ #include "wino_macro.h" #include <ap_int.h> #define DB6x6_1(in,out,ridx,bidx) \ out[ridx][0][bidx]=((in[ridx][0][bidx]-in[ridx][2][bidx])*4-in[ridx][2][bidx]+in[ridx][4][bidx])>>DB_QUANT_BIT;\ out[ridx][1][bidx]=(in[ridx][3][bidx]+in[ridx][4][bidx]-(in[ridx][1][bidx]+in[ridx][2][bidx])*4)>>DB_QUANT_BIT;\ out[ridx][2][bidx]=((in[ridx][1][bidx]-in[ridx][2][bidx])*4+in[ridx][4][bidx]-in[ridx][3][bidx])>>DB_QUANT_BIT;\ out[ridx][3][bidx]=((in[ridx][3][bidx]-in[ridx][1][bidx])*2+in[ridx][4][bidx]-in[ridx][2][bidx])>>DB_QUANT_BIT;\ out[ridx][4][bidx]=((in[ridx][1][bidx]-in[ridx][3][bidx])*2+in[ridx][4][bidx]-in[ridx][2][bidx])>>DB_QUANT_BIT;\ out[ridx][5][bidx]=((in[ridx][1][bidx]-in[ridx][3][bidx])*4+in[ridx][5][bidx]-in[ridx][3][bidx])>>DB_QUANT_BIT; #define BTB6x6_1(in,out,cidx,bidx) \ out[0][cidx][bidx]=((in[0][cidx][bidx]-in[2][cidx][bidx])*4-in[2][cidx][bidx]+in[4][cidx][bidx])>>BTB_QUANT_BIT;\ out[1][cidx][bidx]=(in[3][cidx][bidx]+in[4][cidx][bidx]-(in[1][cidx][bidx]+in[2][cidx][bidx])*4)>>BTB_QUANT_BIT;\ out[2][cidx][bidx]=((in[1][cidx][bidx]-in[2][cidx][bidx])*4+in[4][cidx][bidx]-in[3][cidx][bidx])>>BTB_QUANT_BIT;\ out[3][cidx][bidx]=((in[3][cidx][bidx]-in[1][cidx][bidx])*2+in[4][cidx][bidx]-in[2][cidx][bidx])>>BTB_QUANT_BIT;\ out[4][cidx][bidx]=((in[1][cidx][bidx]-in[3][cidx][bidx])*2+in[4][cidx][bidx]-in[2][cidx][bidx])>>BTB_QUANT_BIT;\ out[5][cidx][bidx]=((in[1][cidx][bidx]-in[3][cidx][bidx])*4+in[5][cidx][bidx]-in[3][cidx][bidx])>>BTB_QUANT_BIT; #define DB4x4_1(in,out,ridx,bidx) \ out[ridx][0][bidx]=(in[ridx][0][bidx]-in[ridx][2][bidx])>>DB_QUANT_BIT;\ out[ridx][1][bidx]=(in[ridx][1][bidx]+in[ridx][2][bidx])>>DB_QUANT_BIT;\ out[ridx][2][bidx]=(in[ridx][2][bidx]-in[ridx][1][bidx])>>DB_QUANT_BIT;\ out[ridx][3][bidx]=(in[ridx][1][bidx]-in[ridx][3][bidx])>>DB_QUANT_BIT; #define BTB4x4_1(in,out,cidx,bidx) \ out[0][cidx][bidx]=(in[0][cidx][bidx]-in[2][cidx][bidx])>>BTB_QUANT_BIT;\ out[1][cidx][bidx]=(in[1][cidx][bidx]+in[2][cidx][bidx])>>BTB_QUANT_BIT;\ out[2][cidx][bidx]=(in[2][cidx][bidx]-in[1][cidx][bidx])>>BTB_QUANT_BIT;\ out[3][cidx][bidx]=(in[1][cidx][bidx]-in[3][cidx][bidx])>>BTB_QUANT_BIT; void UVA_row( ap_int<UVA_WIDTH> out[WINO_DOMAIN_SIZE][WINO_OUT_SIZE][BATCH_SIZE], ap_int<UV_WIDTH> in[WINO_DOMAIN_SIZE][WINO_DOMAIN_SIZE][BATCH_SIZE], int ridx, int bidx, ap_uint<1> flag) { #pragma HLS inline #if WINO_DOMAIN_SIZE == 4 out[ridx][0][bidx]=(in[ridx][0][bidx]+in[ridx][1][bidx]+in[ridx][2][bidx])>>UVA_QUANT_BIT; out[ridx][1][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]-in[ridx][3][bidx])>>UVA_QUANT_BIT; #elif WINO3x3_EN && WINO5x5_EN out[ridx][0][bidx]=(in[ridx][0][bidx]+in[ridx][1][bidx]+in[ridx][2][bidx]+in[ridx][3][bidx]+in[ridx][4][bidx])>>UVA_QUANT_BIT; out[ridx][1][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]+(in[ridx][3][bidx]-in[ridx][4][bidx])*2+in[ridx][5][bidx]*flag)>>UVA_QUANT_BIT; out[ridx][2][bidx]=(in[ridx][1][bidx]+in[ridx][2][bidx]+(in[ridx][3][bidx]+in[ridx][4][bidx])*4)>>UVA_QUANT_BIT; out[ridx][3][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]+(in[ridx][3][bidx]-in[ridx][4][bidx])*8+in[ridx][5][bidx])>>UVA_QUANT_BIT; #elif WINO3x3_EN out[ridx][0][bidx]=(in[ridx][0][bidx]+in[ridx][1][bidx]+in[ridx][2][bidx]+in[ridx][3][bidx]+in[ridx][4][bidx])>>UVA_QUANT_BIT; out[ridx][1][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]+(in[ridx][3][bidx]-in[ridx][4][bidx])*2)>>UVA_QUANT_BIT; out[ridx][2][bidx]=(in[ridx][1][bidx]+in[ridx][2][bidx]+(in[ridx][3][bidx]+in[ridx][4][bidx])*4)>>UVA_QUANT_BIT; out[ridx][3][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]+(in[ridx][3][bidx]-in[ridx][4][bidx])*8+in[ridx][5][bidx])>>UVA_QUANT_BIT; #else out[ridx][0][bidx]=(in[ridx][0][bidx]+in[ridx][1][bidx]+in[ridx][2][bidx]+in[ridx][3][bidx]+in[ridx][4][bidx])>>UVA_QUANT_BIT; out[ridx][1][bidx]=(in[ridx][1][bidx]-in[ridx][2][bidx]+(in[ridx][3][bidx]-in[ridx][4][bidx])*2+in[ridx][4][bidx])>>UVA_QUANT_BIT; #endif } void ATA_col( ap_int<ATA_WIDTH> out[WINO_OUT_SIZE][WINO_OUT_SIZE][BATCH_SIZE], ap_int<UVA_WIDTH> in[WINO_DOMAIN_SIZE][WINO_OUT_SIZE][BATCH_SIZE], int cidx, int bidx, ap_uint<1> flag) { #pragma HLS inline #if WINO_DOMAIN_SIZE == 4 out[0][cidx][bidx]=(in[0][cidx][bidx]+in[1][cidx][bidx]+in[2][cidx][bidx])>>ATA_QUANT_BIT; out[1][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]-in[3][cidx][bidx])>>ATA_QUANT_BIT; #elif WINO3x3_EN && WINO5x5_EN out[0][cidx][bidx]=(in[0][cidx][bidx]+in[1][cidx][bidx]+in[2][cidx][bidx]+in[3][cidx][bidx]+in[4][cidx][bidx])>>ATA_QUANT_BIT; out[1][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]+(in[3][cidx][bidx]-in[4][cidx][bidx])*2+in[5][cidx][bidx]*flag)>>ATA_QUANT_BIT; out[2][cidx][bidx]=(in[1][cidx][bidx]+in[2][cidx][bidx]+(in[3][cidx][bidx]+in[4][cidx][bidx])*4)>>ATA_QUANT_BIT; out[3][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]+(in[3][cidx][bidx]-in[4][cidx][bidx])*8+in[5][cidx][bidx])>>ATA_QUANT_BIT; #elif WINO3x3_EN out[0][cidx][bidx]=(in[0][cidx][bidx]+in[1][cidx][bidx]+in[2][cidx][bidx]+in[3][cidx][bidx]+in[4][cidx][bidx])>>ATA_QUANT_BIT; out[1][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]+(in[3][cidx][bidx]-in[4][cidx][bidx])*2)>>ATA_QUANT_BIT; out[2][cidx][bidx]=(in[1][cidx][bidx]+in[2][cidx][bidx]+(in[3][cidx][bidx]+in[4][cidx][bidx])*4)>>ATA_QUANT_BIT; out[3][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]+(in[3][cidx][bidx]-in[4][cidx][bidx])*8+in[5][cidx][bidx])>>ATA_QUANT_BIT; #else out[0][cidx][bidx]=(in[0][cidx][bidx]+in[1][cidx][bidx]+in[2][cidx][bidx]+in[3][cidx][bidx]+in[4][cidx][bidx])>>ATA_QUANT_BIT; out[1][cidx][bidx]=(in[1][cidx][bidx]-in[2][cidx][bidx]+(in[3][cidx][bidx]-in[4][cidx][bidx])*2+in[4][cidx][bidx])>>ATA_QUANT_BIT; #endif } #endif