Skip to content
Snippets Groups Projects
Commit ddd7a4d2 authored by Hashim Sharif's avatar Hashim Sharif
Browse files

Modifying readTrainedWeights calls in Alexnet to read global constants

parent 75166d37
No related branches found
No related tags found
No related merge requests found
......@@ -334,6 +334,8 @@ struct ret_t {
size_t bytes;
};
typedef struct __attribute__((__packed__)) {
void *input;
size_t input_bytes;
......@@ -393,46 +395,53 @@ int main(int argc, char *argv[]) {
std::string input_path = dir_prefix + std::string("test_input.bin");
std::string labels_path = dir_prefix + std::string("test_labels.bin");
uint8_t *labels = readLabels(labels_path.c_str(), 5000);
std::string conv2d_1_w_path = dir_prefix + std::string("conv2d_1_w.bin");
//--std::string conv2d_1_w_path = dir_prefix + std::string("conv2d_1_w.bin");
//--void *conv2d_1_w =
//-- readTrainedWeights(conv2d_1_w_path.c_str(), 0, 64, 3, 11, 11);
void *conv2d_1_w =
readTrainedWeights(conv2d_1_w_path.c_str(), 0, 64, 3, 11, 11);
std::string conv2d_1_b_path = dir_prefix + std::string("conv2d_1_b.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_1_w.bin", 0, 64, 3, 11, 11);
//---std::string conv2d_1_b_path = dir_prefix + std::string("conv2d_1_b.bin");
void *conv2d_1_b =
readTrainedWeights(conv2d_1_b_path.c_str(), 0, 1, 64, 1, 1);
std::string conv2d_2_w_path = dir_prefix + std::string("conv2d_2_w.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_1_b.bin", 0, 1, 64, 1, 1);
//std::string conv2d_2_w_path = dir_prefix + std::string("conv2d_2_w.bin");
void *conv2d_2_w =
readTrainedWeights(conv2d_2_w_path.c_str(), 0, 192, 64, 5, 5);
std::string conv2d_2_b_path = dir_prefix + std::string("conv2d_2_b.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_2_w.bin", 0, 192, 64, 5, 5);
//std::string conv2d_2_b_path = dir_prefix + std::string("conv2d_2_b.bin");
void *conv2d_2_b =
readTrainedWeights(conv2d_2_b_path.c_str(), 0, 1, 192, 1, 1);
std::string conv2d_3_w_path = dir_prefix + std::string("conv2d_3_w.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_2_b.bin", 0, 1, 192, 1, 1);
//std::string conv2d_3_w_path = dir_prefix + std::string("conv2d_3_w.bin");
void *conv2d_3_w =
readTrainedWeights(conv2d_3_w_path.c_str(), 0, 384, 192, 3, 3);
std::string conv2d_3_b_path = dir_prefix + std::string("conv2d_3_b.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_3_w.bin", 0, 384, 192, 3, 3);
//std::string conv2d_3_b_path = dir_prefix + std::string("conv2d_3_b.bin");
void *conv2d_3_b =
readTrainedWeights(conv2d_3_b_path.c_str(), 0, 1, 384, 1, 1);
std::string conv2d_4_w_path = dir_prefix + std::string("conv2d_4_w.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_3_b.bin", 0, 1, 384, 1, 1);
//std::string conv2d_4_w_path = dir_prefix + std::string("conv2d_4_w.bin");
void *conv2d_4_w =
readTrainedWeights(conv2d_4_w_path.c_str(), 0, 256, 384, 3, 3);
std::string conv2d_4_b_path = dir_prefix + std::string("conv2d_4_b.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_4_w.bin", 0, 256, 384, 3, 3);
//std::string conv2d_4_b_path = dir_prefix + std::string("conv2d_4_b.bin");
void *conv2d_4_b =
readTrainedWeights(conv2d_4_b_path.c_str(), 0, 1, 256, 1, 1);
std::string conv2d_5_w_path = dir_prefix + std::string("conv2d_5_w.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_4_b.bin", 0, 1, 256, 1, 1);
///std::string conv2d_5_w_path = dir_prefix + std::string("conv2d_5_w.bin");
void *conv2d_5_w =
readTrainedWeights(conv2d_5_w_path.c_str(), 0, 256, 256, 3, 3);
std::string conv2d_5_b_path = dir_prefix + std::string("conv2d_5_b.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_5_w.bin", 0, 256, 256, 3, 3);
///std::string conv2d_5_b_path = dir_prefix + std::string("conv2d_5_b.bin");
void *conv2d_5_b =
readTrainedWeights(conv2d_5_b_path.c_str(), 0, 1, 256, 1, 1);
std::string dense_1_w_path = dir_prefix + std::string("dense_1_w.bin");
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/conv2d_5_b.bin", 0, 1, 256, 1, 1);
//std::string dense_1_w_path = dir_prefix + std::string("dense_1_w.bin");
void *dense_1_w =
readTrainedWeights(dense_1_w_path.c_str(), 0, 1, 1, 4096, 10);
std::string dense_1_b_path = dir_prefix + std::string("dense_1_b.bin");
void *dense_1_b = readTrainedWeights(dense_1_b_path.c_str(), 0, 1, 10, 1, 1);
readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/dense_1_w.bin", 0, 1, 1, 4096, 10);
///std::string dense_1_b_path = dir_prefix + std::string("dense_1_b.bin");
void *dense_1_b = readTrainedWeights(MODEL_PARAMS_DIR_STR "/alexnet_cifar10/dense_1_b.bin", 0, 1, 10, 1, 1);
RootIn *args = static_cast<RootIn *>(malloc(sizeof(RootIn)));
void *input = create4DTensor(0, nchw, batch_size, 3, 32, 32);
args->input = input;
args->input_bytes = 0;
//void *input = create4DTensor(0, nchw, batch_size, 3, 32, 32);
//args->input = input;
//args->input_bytes = 0;
args->conv2d_1_w = conv2d_1_w;
args->conv2d_1_w_bytes = 0;
args->conv2d_1_b = conv2d_1_b;
......@@ -459,6 +468,7 @@ int main(int argc, char *argv[]) {
args->dense_1_b_bytes = 0;
__hpvm__init();
if (config_path != "") {
llvm_hpvm_initializeRuntimeController(config_path.c_str());
}
......@@ -482,3 +492,5 @@ int main(int argc, char *argv[]) {
__hpvm__cleanup();
return 0;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment