diff --git a/hpvm/docs/KerasFrontend.md b/hpvm/docs/KerasFrontend.md new file mode 100644 index 0000000000000000000000000000000000000000..3225b112ad4f1e8c03b69af1b330a2dbced24ab1 --- /dev/null +++ b/hpvm/docs/KerasFrontend.md @@ -0,0 +1,191 @@ +# Keras Frontend + +Install Keras Frontend after moving to directory `/hpvm/hpvm/projects/keras` + +## Requirements + +* python == 3.6.x +* pip >= 18 + +If your system uses a different Python version, we recommend using the conda environment `keras_python36.yml`. Install this using: + +``` +conda env create -f keras_python36.yml --name keras_python36 +``` + +Activate the conda environment before installing the pip package (below) using: + +``` +conda activate keras_python36 +``` + +**NOTE:** This step must be performed each time (for each shell process) the frontend is to be used. + + +## Installing the Keras Frontend Package + +At the root of this project (`/projects/keras/`) install the Keras frontend pip package as: + +``` +pip3 install -e ./ +``` + +**NOTE:** If you are using the conda environment, activate it prior to this step. + +## Suppported Operations + +List of supported operations and limitations are documented [here](../projects/keras/docs/Support.md) + + + +# Keras Benchmarks + +Run the Keras benchmarks under `hpvm/hpvm/test/dnn_benchmarks/keras` + +## Download CNN Model Files + +Prior to running the benchmarks, ensure you download the CNN model data (inputs and weights) if not done in automatic build script. + +``` +wget https://databank.illinois.edu/datafiles/o3izd/download -O model_params.tar.gz +tar -xf model_params.tar.gz +``` + +Move extracted `model_params` directory to `/test/dnn_benchmarks/model_params` (Benchmarks expect data at this location) + + +## Running Benchmaks + +List of benchmarks and the expected accuracies: + +| Benchmark | Accuracy | +| ----------- | ----------- | +| alexnet.py | 79.28 | +| alexnet2.py | 84.98 | +| alexnet_imagenet.py | 56.30 | +| lenet.py | 98.70 | +| mobilenet_cifar10.py | 84.42 | +| resnet18_cifar10.py | 89.56 | +| resnet50_imagenet.py | 75.10 | +| vgg16_cifar10.py | 89.96 | +| vgg16_cifar100.py | 66.50 | +| vgg16_imagenet.py | 69.46 | + + +### Synopsis + +``` +python3 ${BENCH_NAME}.py [hpvm_reload|keras_reload] [frontend] [compile] + +``` + + +**Command-line Parameters** + +`hpvm_reload` : Reloads HPVM weights (`.bin` binary format used by HPVM weights - present in `model_params` download directory) from directory path specified in the `reload_dir` parameter set in code - this is described in "Parameters to Change in Code" (below). + +`keras_reload`: Alternatively, reload weights in Keras `.h5` file format with path to file specified in `keras_model_file` described in "Parameters to Change in Code" (below). + +`frontend`: Invokes the HPVM frontend and dumps weights (in HPVM `.bin` format) in the output directory specified. The parameters that control where data and source files are dumped are specified by parameters `data_dir` and `src_dir`, respectively. These are described below. + +`compile`: Optional Parameter. When specified, it compiles the HPVM-C code generated by the frontend into an HPVM binary under the directory specified by `src_dir` (described below). If `src_dir` path exists, a unique directory (which appends a unique ID) is created. +The binary is built with the name `HPVM_binary`. + +**NOTE:** Before running `HPVM_binary` necessary to set CUDA and CUDNN paths with: + +``` +source ${PATH_TO_YOUR_HPVM_ROOT}/hpvm/set_paths.sh +``` + +**Parameters to Change in Code** + +The AlexNet source is commented with explanations on how to use the Keras frontend interface. AlexNet source is [here](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/alexnet.py). + +* `NAME`: Benchmark Name - Can be set to any desired value + +* `reload_dir`: Path to directory from where to reload weights in HPVM format. This directory is used to reload weights if `hpvm_reload` command-line option is used. + +* `keras_model_file`: Path to Keras .h5 model file to reload weigths from. Either of `reload_dir` or `keras_model_file` can be used. +`keras_model_file` is used when `keras_reload` commad-line parameter is used with the Benchmark script. + +* `data_dir`: Directory to dump weights specified specified in [constructor](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/Benchmark.py#L21) + +* `src_dir`: Directory to dump ApproxHPVM sources in HPVM-C (C with HPVM compiler intrinsics) specified in [constructor](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/Benchmark.py#L22) + +* `num_classes`: number of output classes - dependent on the dataset used. For CIFAR10, `num_classes` is 10, CIFAR100 has 100 classes, + for ImageNet, number of classes is 1000. + +* `batch_size`: This parameter controls the size of each batch that is processed in HPVM. The batch size should be kept as large as the GPU memory +can support. This parameter should be adapted according to the memory size of the deployed device. + + + +### Using the Frontend with Custom (New) Benchmarks + +Any new benchmarks must inherit from the commom parent `Benchmark` class +and override the virtual functions for building the model, training, +and data preprocessing. These methods are described below: + + +`def buildModel(self)`: +Constructs and returns a keras model + +`def data_preprocess(self)`: +returns X_train, y_train, X_test, y_test, X_tuner, and y_tuner data (in that order): +These are described here: + +* `X_train:` Training data (fp32) in NCHW format +* `y_train:` Training labels (int32) + +* `X_test:` Testing/Evaluation data in NCHW format +* `y_test:` Testing/Evaluation labels + +* `X_tuner:` Data to be used for autotuning +* `y_tuner:` Labels corresponding to tuning data + + +`def trainModel(self, model, X_train, y_train, X_test, y_test)`: +Trains the Keras model constructed in `buildModel` and is expected to return the +trained keras model - training parameters should be tuned here. + +### Directly using Keras Frontend API + +Alternate to extending the `Benchmark` class, users may directly invoke the Keras Frontend API. This can be done as: + +```python + +from keras_frontend.approxhpvm_translator import translate_to_approxhpvm + +# Construct and train your Keras Model (or load pre-trained weights) + +translate_to_approxhpvm(model, data_dir, src_dir, test_data, test_labels, tune_data, tune_labels, batch_size, num_classes) + +``` + +## Running HPVM Binary + +Run the `HPVM_binary` generated under the directory specified by `src_dir` (described above). Usage: + +``` +./HPVM_binary -t {test|tune} -c ${config_file_path} +``` + +`test|tune`: Runs with either tune (autotuning data) or test set (for evaluation) + +`config_file_path`: Path to an HPVM tensor configuration file (includes approximation settings) + +**NOTE:** The accuracy of the bennchmarks is dumped into a file named `final_accuracy` in the current working directory - this includes accuracy averaged across batches + +## Automated Tests + +`scripts/test_benchmarks.py` is an automated test script that evaluates the accuracy of each Benchmark in Keras and HPVM (after comilation using HPVM Compiler) and compares the accuracy of each binary to the known correct accuracy. Run from root of `/test/dnn_benchmarks/keras`: + +``` +python test_benchmarks.py +``` + + + + + + diff --git a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/hpvm-rt-controller.cpp b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/hpvm-rt-controller.cpp index 47323b1e652d3f7990d03c74bab6fefefeb04efa..bea66370ba073490fe7970014f1005f123e58988 100644 --- a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/hpvm-rt-controller.cpp +++ b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/hpvm-rt-controller.cpp @@ -432,6 +432,8 @@ NodeConfiguration *RuntimeController::getNodeConfiguration(const char *data) { } void RuntimeController::init(const char *Cstr) { + INFO("INIT RUNTIME CONTROLLER ==================\n"); + printf("INIT RUNTIME CONTROLLER ==================\n"); // We initialize the path to the profile info output file, // based on the path given for the configuration file PI->set_out_file_name("profile_info.txt"); @@ -836,7 +838,7 @@ void RuntimeController::readConfigurationFile(const char *str) { firstTensorID += NodeConf->getApproxChoices().size(); } else if (tokens[1] == "cpu") { - DEBUG("Found gpu configuration\n"); + INFO("---------Found cpu configuration\n"); // There must be at least one operation, with an approximation option CUSTOM_ASSERT((tokens.size() >= 5) && @@ -846,6 +848,10 @@ void RuntimeController::readConfigurationFile(const char *str) { InitialConfigurations.back().setup.insert( std::make_pair(tokens[0], NodeConf)); + InitialConfigurations.back().idConfigMap.insert( + std::make_pair(firstTensorID, NodeConf)); + INFO("*** firstTensorID = %d \n\n", firstTensorID); + INFO("FIXED CPU ID ISSUE\n"); unsigned idx = 2; while (idx < tokens.size()) { if (tokens[idx] == "add") { @@ -933,7 +939,7 @@ void RuntimeController::readConfigurationFile(const char *str) { } // TODO: other approximation options handled here } - + firstTensorID += NodeConf->getApproxChoices().size(); } else { DEBUG("Invalid Configuration File\n"); exit(1); diff --git a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/tensor_cpu_runtime.cc b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/tensor_cpu_runtime.cc index 8f0b2dba1147b6a674d760a22112f03f962bf89c..d3a037ea152c279c1528d8a4ec24d32f28df64f9 100644 --- a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/tensor_cpu_runtime.cc +++ b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/tensor_cpu_runtime.cc @@ -154,6 +154,7 @@ void *tensorRegularConvolutionCPU(void *input_ptr, void *filter_ptr, output_width * batch_size; float *host_data = (float *)malloc(conv_data_size); printf("host data: %p\n", host_data); + printf("conv_data_size: %d\n", conv_data_size); printf("number of batches: %d\n", batch_size); omp_set_num_threads(4); #pragma omp parallel for @@ -1066,7 +1067,7 @@ void *tensorSoftmaxCPU(void *input_ptr) { float *logits = (float *)input->host_data; int n = input->dims.dim_sizes[0]; int c = input->dims.dim_sizes[1]; - + omp_set_num_threads(4); #pragma omp parallel for for (int i = 0; i < n; i++) { @@ -1085,7 +1086,6 @@ void *tensorSoftmaxCPU(void *input_ptr) { logits[j] /= x; } } - return input; } @@ -1095,10 +1095,14 @@ void *tensorBatchNormCPU(void *input_ptr, void *gamma_ptr, void *beta_ptr, Tensor *input = (Tensor *)input_ptr; Tensor *gamma = (Tensor *)gamma_ptr; Tensor *beta = (Tensor *)beta_ptr; + Tensor *mean = (Tensor *)mean_ptr; + Tensor *variance = (Tensor *)variance_ptr; float *__restrict__ host_image = (float *)input->host_data; float *__restrict__ host_beta = (float *)beta->host_data; float *__restrict__ host_gamma = (float *)gamma->host_data; + float *__restrict__ host_mean = (float *)mean->host_data; + float *__restrict__ host_variance = (float *)variance->host_data; int batch_size = input->dims.dim_sizes[0]; int channels = input->dims.dim_sizes[1]; @@ -1110,32 +1114,19 @@ void *tensorBatchNormCPU(void *input_ptr, void *gamma_ptr, void *beta_ptr, #pragma omp parallel for for (int b = 0; b < batch_size; b++) { for (int ch = 0; ch < channels; ch++) { - float mean = 0; -#pragma omp simd reduction(+ : mean) - for (int i = 0; i < image_dim; i++) { - int index = b * channels * image_dim + ch * image_dim + i; - mean += host_image[index]; - } - mean = mean / channels; - - float variance = 0; -#pragma omp simd reduction(+ : variance) - for (int i = 0; i < image_dim; i++) { - int index = b * channels * image_dim + ch * image_dim + i; - float tmp = host_image[index] - mean; - variance += (tmp * tmp); - } - variance = variance / channels; - -#pragma omp simd + float mean = host_mean[ch]; + float sq_ep_var = sqrt(epsilon + host_variance[ch]); + float gamma = host_gamma[ch]; + float beta = host_beta[ch]; + #pragma omp simd for (int i = 0; i < image_dim; i++) { int index = b * channels * image_dim + ch * image_dim + i; host_image[index] = - host_beta[ch] + (host_gamma[ch] * ((host_image[index] - mean) / - sqrt(epsilon + variance))); + beta + (gamma * ((host_image[index] - mean) / sq_ep_var)); } } } + return input; } diff --git a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/wrapper_runtime.cu b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/wrapper_runtime.cu index 8c77234e2432bd5fe1cde144b031d42273140d42..b9f52c0c8dddb8e7a4aa37abec5ea0d9dfa7164b 100644 --- a/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/wrapper_runtime.cu +++ b/hpvm/projects/hpvm-tensor-rt/tensor_runtime/src/wrapper_runtime.cu @@ -39,10 +39,43 @@ #include "approxhpvm_runtime_utils.h" #include "approx_api.h" +#include "tensor_runtime.h" +#include "tensor_cpu_runtime.h" + extern "C" { /**** Wrapper Runtime API ***/ +// Initialization and clean routines for various supported devices + void llvm_libtensorhpvm_init(int gpuid) { + llvm_hpvm_initApproxhpvmRt(gpuid); + llvm_hpvm_initTensorRtCPU(); + } + + void llvm_libtensorhpvm_cleanup() { + llvm_hpvm_cleanupApproxhpvmRt(); + llvm_hpvm_cleanupTensorRtCPU(); + } + + void llvm_libtensorhpvm_request_tensor(const char* hpvm_node_id, void* tensor) { + + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + if (NodeConf->isGPUNodeConfiguration()) { + DEBUG("GPU Configuration detected at node %s: requesting tensor\n", hpvm_node_id); + hpvm_request_tensor(tensor, 1); // 1 for GPU + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("CPU Configuration detected at node %s: requesting tensor\n", hpvm_node_id); + hpvm_request_tensor(tensor, 0); // 0 for CPU + } else { + ERROR("Currently unsupported configuration\n"); + abort(); + } + } + + + + + void * wrapper_ConvLayer(const char *hpvm_node_id, void *input, void *filter, void *bias, int conv_pad_h, int conv_pad_w, int conv_stride_h, @@ -164,9 +197,129 @@ wrapper_ConvLayer(const char *hpvm_node_id, void *input, void *filter, pool_out = activation_out; } return pool_out; - } else { - ERROR("Unsupported Configuration"); - abort(); + } else if(NodeConf->isCPUNodeConfiguration()) { + DEBUG("CPU Configuration for ConvLayer\n"); + // Mapped to GPU - get a GPU node configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = CPUConf->getApproxChoices(); + + // Check for convolution as first operation + CUSTOM_ASSERT((ApproxChoices.size() >= 1) && + (ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::CONV) && + "Incorrect number/type of operations in provided Conv layer configuration"); + + void* conv_out = handleTensorConvApproximationTuples_CPU(ApproxChoices[0].second, + input, filter, conv_pad_h, conv_pad_w, + conv_stride_h, conv_stride_w); + void* add_out; + if (bias != NULL) { + // Check for add as second operation + CUSTOM_ASSERT((ApproxChoices.size() >= 2) && + (ApproxChoices[1].first == CPUNodeConfiguration::TENSOR_OP::ADD) && + "Incorrect number/type of operations in provided Conv layer configuration"); + add_out = handleTensorAddApproximationTuples_CPU(ApproxChoices[1].second, + conv_out, bias); + } else { + add_out = conv_out; + } + + void* activation_out; + switch (activation_id) { + case -1: + { // No activation + INFO("No activation Function\n"); + activation_out = add_out; + } + break; + case 0: + { // TanH activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::TANH) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = handleTensorTanhApproximationTuples_CPU(ApproxChoices[2].second, + add_out); + } + break; + case 1: + { // ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::RELU) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = handleTensorReluApproximationTuples_CPU(ApproxChoices[2].second, + add_out); + } + break; + case 2: + { // Clipped ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = + handleTensorClippedReluApproximationTuples_CPU(ApproxChoices[2].second, + add_out, out_min, out_max); + } + break; + default: + { + ERROR("Activation id %d NOT supported \n", activation_id); + } + break; + } + + void* pool_out; + + if (pool_size > 0) { + switch (pool_id) { + case 0: + { + // If we remove the asserts, we can have all cases handled by a single call + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MAX) && + "Expected POOL_MAX in provided Conv layer configuration"); + pool_out = + handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size, pool_size, 0, 0, + pool_size, pool_size); + } + break; + case 1: + { + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MEAN) && + "Expected POOL_MEAN in provided Conv layer configuration"); + pool_out = + handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size, pool_size, 0, 0, + pool_size, pool_size); + } + break; + case 2: + { + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MIN) && + "Expected POOL_MIN in provided Conv layer configuration"); + pool_out = + handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size, pool_size, 0, 0, + pool_size, pool_size); + } + break; + default: + { + ERROR("Pool id %d NOT supported \n", pool_id); + } + break; + } + } else { + pool_out = activation_out; + } + return pool_out; + } else { + ERROR("Unsupported Configuration"); + abort(); } return NULL; @@ -180,11 +333,12 @@ void *wrapper_ConvLayer2( // NOTE: out_min, out_max are only relevant for ClippedRelu float out_min, float out_max) { - INFO("*** Conv Layer \n"); + INFO("*** ------Conv Layer \n"); NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + INFO("HERE\n"); if (NodeConf->isGPUNodeConfiguration()) { - DEBUG("GPU Configuration for ConvLayer\n"); + INFO("GPU Configuration for ConvLayer\n"); // Mapped to GPU - get a GPU node configuration GPUNodeConfiguration *GPUConf = (GPUNodeConfiguration *)NodeConf; @@ -306,7 +460,136 @@ void *wrapper_ConvLayer2( pool_out = activation_out; } return pool_out; - } else { + } else if (NodeConf->isCPUNodeConfiguration()) { + INFO("CPU Configuration for ConvLayer\n"); + // Mapped to CPU - get a CPU node configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Check for convolution as first operation + CUSTOM_ASSERT((ApproxChoices.size() >= 1) && + (ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::CONV) && + "Incorrect number/type of operations in provided Conv layer configuration"); + + void* conv_out = handleTensorConvApproximationTuples_CPU(ApproxChoices[0].second, + input, filter, conv_pad_h, conv_pad_w, + conv_stride_h, conv_stride_w); + + void* add_out; + if (bias != NULL) { + // Check for add as second operation + CUSTOM_ASSERT((ApproxChoices.size() >= 2) && + (ApproxChoices[1].first == CPUNodeConfiguration::TENSOR_OP::ADD) && + "Incorrect number/type of operations in provided Conv layer configuration"); + add_out = handleTensorAddApproximationTuples_CPU(ApproxChoices[1].second, + conv_out, bias); + } else { + add_out = conv_out; + } + + void* activation_out; + switch (activation_id) { + case -1: + { // No activation + INFO("No activation Function\n"); + activation_out = add_out; + } + break; + case 0: + { // TanH activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::TANH) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = handleTensorTanhApproximationTuples_CPU(ApproxChoices[2].second, + add_out); + } + break; + case 1: + { // ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::RELU) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = handleTensorReluApproximationTuples_CPU(ApproxChoices[2].second, + add_out); + } + break; + case 2: + { // Clipped ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() >= 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU) && + "Incorrect number/type of operations in provided Conv layer configuration"); + activation_out = + handleTensorClippedReluApproximationTuples_CPU(ApproxChoices[2].second, + add_out, out_min, out_max); + } + break; + default: + { + ERROR("Activation id %d NOT supported \n", activation_id); + } + break; + } + + void* pool_out; + + if (pool_size_v > 0) { + switch (pool_id) { + case 0: + { + // If we remove the asserts, we can have all cases handled by a single call + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MAX) && + "Expected POOL_MAX in provided Conv layer configuration"); + + pool_out = handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size_v, pool_size_h, + pool_pad_v, pool_pad_h, + pool_stride_v, pool_stride_h); + } + break; + case 1: + { + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MEAN) && + "Expected POOL_MEAN in provided Conv layer configuration"); + + // FIXIT: POOL_MEAN still needs fixing + pool_out = + handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size_v, pool_size_h, + 0, 0, + pool_size_v, pool_size_h); + } + break; + case 2: + { + CUSTOM_ASSERT((ApproxChoices.back().first == CPUNodeConfiguration::TENSOR_OP::POOL_MIN) && + "Expected POOL_MIN in provided Conv layer configuration"); + // FIXIT: Pool_MEAN needs fixing + pool_out = + handleTensorPoolingApproximationTuples_CPU(ApproxChoices.back().second, + activation_out, pool_id, + pool_size_v, pool_size_h, 0, 0, + pool_size_v, pool_size_h); + } + break; + default: + { + ERROR("Pool id %d NOT supported \n", pool_id); + } + break; + } + } else { + pool_out = activation_out; + } + return pool_out; + + } + else { ERROR("Unsupported Configuration"); abort(); } @@ -386,7 +669,74 @@ wrapper_FCLayer(const char *hpvm_node_id, void *input, void *weights, } break; } return activation_out; - } else { + } else if (NodeConf->isCPUNodeConfiguration()){ + DEBUG("CPU Configuration for FCLayer\n"); + // Mapped to CPU - get a CPU node configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a FC wrapper operation + CUSTOM_ASSERT((ApproxChoices.size() == 2 || ApproxChoices.size() == 3) && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::MUL && + ApproxChoices[1].first == CPUNodeConfiguration::TENSOR_OP::ADD && + "Invalid configuration generated for FC layer wrapper operation"); + + void* gemm_out = handleTensorMulApproximationTuples_CPU(ApproxChoices[0].second, + input, weights); + void* add_out = handleTensorAddApproximationTuples_CPU(ApproxChoices[1].second, + gemm_out, bias); + + void* activation_out; + switch (activation_id) { + case -1: + { // No activation + CUSTOM_ASSERT((ApproxChoices.size() == 2) && + "Incorrect number of operations in provided FC layer configuration"); + INFO("No activation Function\n"); + activation_out = add_out; + } + break; + case 0: + { // TanH activation + CUSTOM_ASSERT((ApproxChoices.size() == 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::TANH) && + "Incorrect number/type of operations in provided FC layer configuration"); + activation_out = handleTensorTanhApproximationTuples_CPU(ApproxChoices[1].second, + add_out); + } + break; + case 1: + { // ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() == 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::RELU) && + "Incorrect number/type of operations in provided FC layer configuration"); + activation_out = handleTensorReluApproximationTuples_CPU(ApproxChoices[1].second, + add_out); + } + break; + case 2: + { // Clipped ReLU activation + CUSTOM_ASSERT((ApproxChoices.size() == 3) && + (ApproxChoices[2].first == CPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU) && + "Incorrect number/type of operations in provided FC layer configuration"); + activation_out = + handleTensorClippedReluApproximationTuples_CPU(ApproxChoices[1].second, + add_out, out_min, out_max); + } + break; + default: + { + ERROR("Activation id %d NOT supported \n", activation_id); + } + break; + } + return activation_out; + } + else { ERROR("Unsupported Configuration"); abort(); } @@ -398,66 +748,152 @@ void *wrapper_tensorRelu(const char *hpvm_node_id, void *input_ptr) { INFO("*** Relu Operation \n"); - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + + if (NodeConf->isGPUNodeConfiguration()) { + + // Only mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)NodeConf; - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = GPUConf->getApproxChoices(); + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = GPUConf->getApproxChoices(); - // Approximation choices must be for a relu operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::RELU && - "Invalid configuration generated for tensor relu wrapper operation"); + // Approximation choices must be for a relu operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::RELU && + "Invalid configuration generated for tensor relu wrapper operation"); + + return handleTensorReluApproximationTuples(ApproxChoices[0].second, + input_ptr); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("ReLU operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a relu operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::RELU && + "Invalid configuration generated for tensor relu wrapper operation"); + + return handleTensorReluApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + + return NULL; - return handleTensorReluApproximationTuples(ApproxChoices[0].second, - input_ptr); } void *wrapper_tensorClippedRelu(const char *hpvm_node_id, void *input_ptr, float out_min, float out_max) { - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); - - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = GPUConf->getApproxChoices(); - - // Approximation choices must be for a relu operation - CUSTOM_ASSERT(ApproxChoices.size() == 1 && - ApproxChoices[0].first == - GPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU && - "Invalid configuration generated for tensor clipped relu " - "wrapper operation"); - - return handleTensorClippedReluApproximationTuples( - ApproxChoices[0].second, input_ptr, out_min, out_max); + + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + if (NodeConf->isGPUNodeConfiguration()) { + + // mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)NodeConf; + + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = GPUConf->getApproxChoices(); + + // Approximation choices must be for a relu operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == + GPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU && + "Invalid configuration generated for tensor clipped relu " + "wrapper operation"); + + return handleTensorClippedReluApproximationTuples( + ApproxChoices[0].second, input_ptr, out_min, out_max); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("Clipped ReLU operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a clipped relu operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::CLIPPED_RELU && + "Invalid configuration generated for tensor clipped relu wrapper operation"); + + return handleTensorClippedReluApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr, out_min, out_max); + + + } else { + ERROR("Unsupported Configuration"); + abort(); + } + return NULL; + } void *wrapper_tensorTanh(const char *hpvm_node_id, void *input_ptr) { // return tensorTanh(input_ptr); - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = GPUConf->getApproxChoices(); + if (NodeConf->isGPUNodeConfiguration()) { + GPUNodeConfiguration *GPUConf = (GPUNodeConfiguration *)NodeConf; - // Approximation choices must be for a tanh operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::TANH && - "Invalid configuration generated for tensor tanh wrapper operation"); + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = GPUConf->getApproxChoices(); + + // Approximation choices must be for a tanh operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::TANH && + "Invalid configuration generated for tensor tanh wrapper operation"); + + return handleTensorTanhApproximationTuples(ApproxChoices[0].second, + input_ptr); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("TanH operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a tanh operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::TANH && + "Invalid configuration generated for tensor tanh wrapper operation"); + + return handleTensorTanhApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + + return NULL; - return handleTensorTanhApproximationTuples(ApproxChoices[0].second, - input_ptr); } void *wrapper_tensorBatchNorm(const char *hpvm_node_id, void *input_ptr, @@ -466,55 +902,111 @@ void *wrapper_tensorBatchNorm(const char *hpvm_node_id, void *input_ptr, INFO("*** BatchNorm Operation \n"); - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = + + if (NodeConf->isGPUNodeConfiguration()) { + // mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)NodeConf; + + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = + + GPUConf->getApproxChoices(); - GPUConf->getApproxChoices(); + // printf("*** BatchNorm \n ApproxChoice = %d \n BatchNorm = %d \n CONV = %d + // \n", ApproxChoices[0].first, + // GPUNodeConfiguration::TENSOR_OP::BATCHNORM, + // GPUNodeConfiguration::TENSOR_OP::CONV); - // printf("*** BatchNorm \n ApproxChoice = %d \n BatchNorm = %d \n CONV = %d - // \n", ApproxChoices[0].first, - // GPUNodeConfiguration::TENSOR_OP::BATCHNORM, - // GPUNodeConfiguration::TENSOR_OP::CONV); + // Approximation choices must be for a batchnorm operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::BATCHNORM && + "Invalid configuration generated for tensor batchnorm wrapper operation"); + + return handleTensorBatchNormApproximationTuples( + ApproxChoices[0].second, input_ptr, gamma_ptr, beta_ptr, mean_ptr, + variance_ptr, epsilon); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("BatchNorm operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a softmax operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::BATCHNORM && + "Invalid configuration generated for tensor batchnorm wrapper operation"); + + return handleTensorBatchNormApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr, gamma_ptr, beta_ptr, + mean_ptr, variance_ptr, epsilon); + } else { + ERROR("Unsupported Configuration"); + abort(); + } - // Approximation choices must be for a batchnorm operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::BATCHNORM && - "Invalid configuration generated for tensor batchnorm wrapper operation"); + return NULL; - return handleTensorBatchNormApproximationTuples( - ApproxChoices[0].second, input_ptr, gamma_ptr, beta_ptr, mean_ptr, - variance_ptr, epsilon); } void *wrapper_tensorAdd(const char *hpvm_node_id, void *input_ptr, void *bias_ptr) { - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + + if (NodeConf->isGPUNodeConfiguration()) { - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = + // mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)NodeConf; - GPUConf->getApproxChoices(); + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = - // Approximation choices must be for an add operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::ADD && - "Invalid configuration generated for tensor add wrapper operation"); + GPUConf->getApproxChoices(); - return handleTensorAddApproximationTuples(ApproxChoices[0].second, input_ptr, - bias_ptr); + // Approximation choices must be for an add operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::ADD && + "Invalid configuration generated for tensor add wrapper operation"); + + return handleTensorAddApproximationTuples(ApproxChoices[0].second, input_ptr, + bias_ptr); + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("Add operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for an add operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::ADD && + "Invalid configuration generated for tensor add wrapper operation"); + + return handleTensorAddApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr, bias_ptr); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + return NULL; } void *wrapper_tensorPooling(const char *hpvm_node_id, void *input_ptr, @@ -525,37 +1017,73 @@ void *wrapper_tensorPooling(const char *hpvm_node_id, void *input_ptr, INFO("*** TensorPooling Operation \n"); - // return tensorPooling(input_ptr, poolFunction, window_height, window_width, - // vertical_pad, horizontal_pad, vertical_stride, - // horizontal_stride); - - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); - - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = - - GPUConf->getApproxChoices(); - - // Approximation choices must be for a single operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - "Invalid configuration generated for tensor pool wrapper operation"); - enum GPUNodeConfiguration::TENSOR_OP top = ApproxChoices[0].first; - // Approximation choices must be for a pool operation - CUSTOM_ASSERT( - (top == GPUNodeConfiguration::TENSOR_OP::POOL_MAX || - top == GPUNodeConfiguration::TENSOR_OP::POOL_MEAN || - top == GPUNodeConfiguration::TENSOR_OP::POOL_MIN) && - "Invalid configuration generated for tensor pool wrapper operation"); - - return handleTensorPoolingApproximationTuples( - ApproxChoices[0].second, input_ptr, poolFunction, window_height, - window_width, vertical_pad, horizontal_pad, vertical_stride, - horizontal_stride); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + + if (NodeConf->isGPUNodeConfiguration()) { + + // return tensorPooling(input_ptr, poolFunction, window_height, window_width, + // vertical_pad, horizontal_pad, vertical_stride, + // horizontal_stride); + + // Only mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = (GPUNodeConfiguration *)NodeConf; + + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = + + GPUConf->getApproxChoices(); + + // Approximation choices must be for a single operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + "Invalid configuration generated for tensor pool wrapper operation"); + enum GPUNodeConfiguration::TENSOR_OP top = ApproxChoices[0].first; + // Approximation choices must be for a pool operation + CUSTOM_ASSERT( + (top == GPUNodeConfiguration::TENSOR_OP::POOL_MAX || + top == GPUNodeConfiguration::TENSOR_OP::POOL_MEAN || + top == GPUNodeConfiguration::TENSOR_OP::POOL_MIN) && + "Invalid configuration generated for tensor pool wrapper operation"); + + return handleTensorPoolingApproximationTuples( + ApproxChoices[0].second, input_ptr, poolFunction, window_height, + window_width, vertical_pad, horizontal_pad, vertical_stride, + horizontal_stride); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("Pool operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a single operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + "Invalid configuration generated for tensor pool wrapper operation"); + enum CPUNodeConfiguration::TENSOR_OP top = ApproxChoices[0].first; + // Approximation choices must be for a pool operation + CUSTOM_ASSERT((top == CPUNodeConfiguration::TENSOR_OP::POOL_MAX || + top == CPUNodeConfiguration::TENSOR_OP::POOL_MEAN || + top == CPUNodeConfiguration::TENSOR_OP::POOL_MIN) && + "Invalid configuration generated for tensor pool wrapper operation"); + + return handleTensorPoolingApproximationTuples_CPU(ApproxChoices[0].second, + input_ptr, poolFunction, + window_height, window_width, + vertical_pad, horizontal_pad, + vertical_stride, horizontal_stride); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + + return NULL; + } void *wrapper_tensorGroupConvolution(const char *hpvm_node_id, void *input, @@ -563,47 +1091,104 @@ void *wrapper_tensorGroupConvolution(const char *hpvm_node_id, void *input, int horizontal_pad, int vertical_stride, int horizontal_stride, int conv_mode, int conv_groups) { - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); - - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = GPUConf->getApproxChoices(); - - // Approximation choices must be for a group_conv operation - CUSTOM_ASSERT(ApproxChoices.size() == 1 && - ApproxChoices[0].first == - GPUNodeConfiguration::TENSOR_OP::GROUP_CONV && - "Invalid configuration generated for tensor group_conv wrapper " - "operation"); - - return handleTensorGroupConvApproximationTuples( - ApproxChoices[0].second, input, filter, vertical_pad, horizontal_pad, - vertical_stride, horizontal_stride, conv_mode, conv_groups); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + + if (NodeConf->isGPUNodeConfiguration()) { + + // Mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)NodeConf; + + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = GPUConf->getApproxChoices(); + + // Approximation choices must be for a group_conv operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == + GPUNodeConfiguration::TENSOR_OP::GROUP_CONV && + "Invalid configuration generated for tensor group_conv wrapper " + "operation"); + + return handleTensorGroupConvApproximationTuples( + ApproxChoices[0].second, input, filter, vertical_pad, horizontal_pad, + vertical_stride, horizontal_stride, conv_mode, conv_groups); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("Group Convolution operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a group_conv operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::GROUP_CONV && + "Invalid configuration generated for tensor group_conv wrapper operation"); + + return handleTensorGroupConvApproximationTuples_CPU(ApproxChoices[0].second, + input, filter, + vertical_pad, horizontal_pad, + vertical_stride, horizontal_stride, + conv_mode, conv_groups); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + return NULL; + } void *wrapper_tensorSoftmax(const char *hpvm_node_id, void *input_ptr) { // return tensorSoftmax(input_ptr); - // Only mapped to GPU - get a GPU configuration - GPUNodeConfiguration *GPUConf = - (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); + NodeConfiguration *NodeConf = RC->getNodeConfiguration(hpvm_node_id); + if (NodeConf->isGPUNodeConfiguration()) { + + // Mapped to GPU - get a GPU configuration + GPUNodeConfiguration *GPUConf = + (GPUNodeConfiguration *)RC->getNodeConfiguration(hpvm_node_id); - std::vector< - std::pair<GPUNodeConfiguration::TENSOR_OP, - std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> - &ApproxChoices = GPUConf->getApproxChoices(); + std::vector< + std::pair<GPUNodeConfiguration::TENSOR_OP, + std::vector<std::pair<GPUNodeConfiguration::APPROX, int>>>> + &ApproxChoices = GPUConf->getApproxChoices(); - // Approximation choices must be for a softmax operation - CUSTOM_ASSERT( - ApproxChoices.size() == 1 && - ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::SOFTMAX && - "Invalid configuration generated for tensor softmax wrapper operation"); + // Approximation choices must be for a softmax operation + CUSTOM_ASSERT( + ApproxChoices.size() == 1 && + ApproxChoices[0].first == GPUNodeConfiguration::TENSOR_OP::SOFTMAX && + "Invalid configuration generated for tensor softmax wrapper operation"); + + return handleTensorSoftmaxApproximationTuples(ApproxChoices[0].second, + input_ptr); + + } else if (NodeConf->isCPUNodeConfiguration()) { + DEBUG("SoftMax operation: CPU Configuration\n"); + // Mapped to CPU - get a CPU configuration + CPUNodeConfiguration *CPUConf = (CPUNodeConfiguration *)NodeConf; + + std::vector< std::pair< CPUNodeConfiguration::TENSOR_OP, + std::vector< std::pair<CPUNodeConfiguration::APPROX, + int> > > > &ApproxChoices = + CPUConf->getApproxChoices(); + + // Approximation choices must be for a softmax operation + CUSTOM_ASSERT(ApproxChoices.size() == 1 && + ApproxChoices[0].first == CPUNodeConfiguration::TENSOR_OP::SOFTMAX && + "Invalid configuration generated for tensor softmax wrapper operation"); + + return handleTensorSoftmaxApproximationTuples_CPU(ApproxChoices[0].second, input_ptr); + } else { + ERROR("Unsupported Configuration"); + abort(); + } + return NULL; - return handleTensorSoftmaxApproximationTuples(ApproxChoices[0].second, - input_ptr); } void *tensor_set_node_id(unsigned int node_id) { diff --git a/hpvm/projects/keras/README.md b/hpvm/projects/keras/README.md index ef31139ee149cb1cd4218603a50c5e687f86158a..0e28bfc9516de4ed975ce116af36872dc525ecfe 100644 --- a/hpvm/projects/keras/README.md +++ b/hpvm/projects/keras/README.md @@ -1,38 +1,4 @@ -# Keras Frontend - -## Requirements - -* python == 3.6.x -* pip >= 18 - -If your system uses a different Python version, we recommend using the conda environment `keras_python36.yml`. Install this using: - -```bash -conda env create -f keras_python36.yml --name keras_python36 -``` - -Activate the conda environment before installing the pip package (below) using: - -```bash -conda activate keras_python36 -``` - -**NOTE:** This step must be performed each time (for each shell process) the frontend is to be used. -## Installing the Keras Frontend Package - -At the root of this project (`/projects/keras/`) install the Keras frontend pip package as: - -```bash -pip3 install -e ./ -``` - -**NOTE:** If you are using the conda environment, activate it prior to this step. - -## Keras Benchmarks - -Keras benchmarks can be found [here](https://gitlab.engr.illinois.edu/llvm/hpvm/-/tree/approx_hpvm_reorg_keras/hpvm/test/dnn_benchmarks/keras). - -## Suppported Operations +# Keras Frontend -List of supported operations and limitations detailed in https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/docs/Support.md +Instructions on installing and using the Keras Frontend can be found [**here**](../../docs/KerasFrontend.md) diff --git a/hpvm/projects/keras/docs/Support.md b/hpvm/projects/keras/docs/Support.md index e5e7b1a1a2125940cd0749e9c957c43bf2205aa3..b568d3d640204fd90c977e63e24dc36dc6d92336 100644 --- a/hpvm/projects/keras/docs/Support.md +++ b/hpvm/projects/keras/docs/Support.md @@ -1,4 +1,3 @@ - ## Supported Keras Operators The Keras frontend supports `Sequential()` Keras models. @@ -27,14 +26,29 @@ The list of supported operations is as follows: * Softmax operator should be the last operation in the CNN pipeline * Softmax operation must be a separate operator (not specified as activation to another type of Keras operator). Example of what works: -``` +```python Activation ("softmax") ``` Example of what is NOT supported: -``` +```python Dense(num_classes, activation="softmax") ``` +* For convolutions with stride > 1 `same` convolution is NOT supported. Explicitly add `ZeroPadding2D` layer before `Conv2D` or `DepthwiseConv2D` operators. Example of what does NOT work: + +```python +Conv2D(num_filters, kernel_size = (3,3), strides = (2,2), padding = `same`) +``` + +Example of what works instead: + +```python +# NOTE: Amount of padding varies with kernel sizes and strides +ZeroPadding2D(padding=(1,1), data_format = `channels_first`) # only support NCHW +Conv2D(num_filters, kernel_size = (3,3), strides = (2,2), padding = `valid`) +``` + + diff --git a/hpvm/test/dnn_benchmarks/keras/README.md b/hpvm/test/dnn_benchmarks/keras/README.md index f80ac8a387ecd3a537473a9797eaec190f3c9964..c17472eb7d30d306711e99cf95b8bc5ef8f84b7f 100644 --- a/hpvm/test/dnn_benchmarks/keras/README.md +++ b/hpvm/test/dnn_benchmarks/keras/README.md @@ -1,192 +1,3 @@ -# Keras Frontend - -Install Keras Frontend after moving to directory `/hpvm/hpvm/projects/keras` - -## Requirements - -* python == 3.6.x -* pip >= 18 - -If your system uses a different Python version, we recommend using the conda environment `keras_python36.yml`. Install this using: - -``` -conda env create -f keras_python36.yml --name keras_python36 -``` - -Activate the conda environment before installing the pip package (below) using: - -``` -conda activate keras_python36 -``` - -**NOTE:** This step must be performed each time (for each shell process) the frontend is to be used. - - -## Installing the Keras Frontend Package - -At the root of this project (`/projects/keras/`) install the Keras frontend pip package as: - -``` -pip3 install -e ./ -``` - -**NOTE:** If you are using the conda environment, activate it prior to this step. - -## Suppported Operations - -List of supported operations and limitations detailed in https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/docs/Support.md - - - - - # Keras Benchmarks -Run the Keras benchmarks under `hpvm/hpvm/test/dnn_benchmarks/keras` - -## Download CNN Model Files - -Prior to running the benchmarks, ensure you download the CNN model data (inputs and weights) if not done in automatic build script. - -``` -wget https://databank.illinois.edu/datafiles/o3izd/download -O model_params.tar.gz -tar -xf model_params.tar.gz -``` - -Move extracted `model_params` directory to `/test/dnn_benchmarks/model_params` (Benchmarks expect data at this location) - - -## Running Benchmaks - -List of benchmarks and the expected accuracies: - -| Benchmark | Accuracy | -| ----------- | ----------- | -| alexnet.py | 79.28 | -| alexnet2.py | 84.98 | -| alexnet_imagenet.py | 56.30 | -| lenet.py | 98.70 | -| mobilenet_cifar10.py | 84.42 | -| resnet18_cifar10.py | 89.56 | -| resnet50_imagenet.py | 75.10 | -| vgg16_cifar10.py | 89.96 | -| vgg16_cifar100.py | 66.50 | -| vgg16_imagenet.py | 69.46 | - - -### Synopsis - -``` -python3 ${BENCH_NAME}.py [hpvm_reload|keras_reload] [frontend] [compile] - -``` - - -**Command-line Parameters** - -`hpvm_reload` : Reloads HPVM weights ('.bin' binary format used in `model_params` found here: https://gitlab.engr.illinois.edu/llvm/hpvm/-/tree/approx_hpvm_reorg_keras/hpvm/test/dnn_benchmarks/model_params) from directory path specified in the `reload_dir` parameter set in code - this is described in "Parameters to Change in Code" (below). - -`keras_reload`: Alternatively, reload weights in Keras `.h5` file format with path to file specified in `keras_model_file` described in "Parameters to Change in Code" (below). - -`frontend`: Invokes the HPVM frontend and dumps weights (in HPVM `.bin` format) in the output directory specified. The parameters that control where data and source files are dumped are specified by parameters `data_dir` and `src_dir`, respectively. These are described below. - -`compile`: Optional Parameter. When specified, it compiles the HPVM-C code generated by the frontend into an HPVM binary under the directory specified by `src_dir` (described below). If `src_dir` path exists, a unique directory (which appends a unique ID) is created. -The binary is built with the name `HPVM_binary`. - -**NOTE:** Before running `HPVM_binary` necessary to set CUDA and CUDNN paths with: - -``` -source ${PATH_TO_YOUR_HPVM_ROOT}/hpvm/set_paths.sh -``` - -**Parameters to Change in Code** - -The AlexNet source is commented with explanations on how to use the Keras frontend interface. AlexNet source is [here](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/alexnet.py). - -* `NAME`: Benchmark Name - Can be set to any desired value - -* `reload_dir`: Path to directory from where to reload weights in HPVM format. This directory is used to reload weights if `hpvm_reload` command-line option is used. - -* `keras_model_file`: Path to Keras .h5 model file to reload weigths from. Either of `reload_dir` or `keras_model_file` can be used. -`keras_model_file` is used when `keras_reload` commad-line parameter is used with the Benchmark script. - -* `data_dir`: Directory to dump weights specified specified in [constructor](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/Benchmark.py#L21) - -* `src_dir`: Directory to dump ApproxHPVM sources in HPVM-C (C with HPVM compiler intrinsics) specified in [constructor](https://gitlab.engr.illinois.edu/llvm/hpvm/-/blob/approx_hpvm_reorg_keras/hpvm/projects/keras/src/Benchmark.py#L22) - -* `num_classes`: number of output classes - dependent on the dataset used. For CIFAR10, `num_classes` is 10, CIFAR100 has 100 classes, - for ImageNet, number of classes is 1000. - -* `batch_size`: This parameter controls the size of each batch that is processed in HPVM. The batch size should be kept as large as the GPU memory -can support. This parameter should be adapted according to the memory size of the deployed device. - - - -### Using the Frontend with Custom (New) Benchmarks - -Any new benchmarks must inherit from the commom parent `Benchmark` class -and override the virtual functions for building the model, training, -and data preprocessing. These methods are described below: - - -`def buildModel(self)`: -Constructs and returns a keras model - -`def data_preprocess(self)`: -returns X_train, y_train, X_test, y_test, X_tuner, and y_tuner data (in that order): -These are described here: - -* `X_train:` Training data (fp32) in NCHW format -* `y_train:` Training labels (int32) - -* `X_test:` Testing/Evaluation data in NCHW format -* `y_test:` Testing/Evaluation labels - -* `X_tuner:` Data to be used for autotuning -* `y_tuner:` Labels corresponding to tuning data - - -`def trainModel(self, model, X_train, y_train, X_test, y_test)`: -Trains the Keras model constructed in `buildModel` and is expected to return the -trained keras model - training parameters should be tuned here. - -### Directly using Keras Frontend API - -Alternate to extending the `Benchmark` class, users may directly invoke the Keras Frontend API. This can be done as: - -```python - -from keras_frontend.approxhpvm_translator import translate_to_approxhpvm - -# Construct and train your Keras Model (or load pre-trained weights) - -translate_to_approxhpvm(model, data_dir, src_dir, test_data, test_labels, tune_data, tune_labels, batch_size, num_classes) - -``` - -## Running HPVM Binary - -Run the `HPVM_binary` generated under the directory specified by `src_dir` (described above). Usage: - -``` -./HPVM_binary -t {test|tune} -c ${config_file_path} -``` - -`test|tune`: Runs with either tune (autotuning data) or test set (for evaluation) - -`config_file_path`: Path to an HPVM tensor configuration file (includes approximation settings) - - -## Automated Tests - -`scripts/test_benchmarks.py` is an automated test script that evaluates the accuracy of each Benchmark in Keras and HPVM (after comilation using HPVM Compiler) and compares the accuracy of each binary to the known correct accuracy. Run from root of `/test/dnn_benchmarks/keras`: - -``` -python test_benchmarks.py -``` - - - - - - +Instructions on using the Keras benchmarks and the Keras Frontend can be found [**here**](../../../docs/KerasFrontend.md)