diff --git a/hpvm/projects/onnx_frontend/main.py b/hpvm/projects/onnx_frontend/main.py index 07b5bb724915480d4f6bc04110f808da0254e0bb..184877082ba79e953575be2579ac0079ed4d9854 100644 --- a/hpvm/projects/onnx_frontend/main.py +++ b/hpvm/projects/onnx_frontend/main.py @@ -41,11 +41,11 @@ def compile( model = check_version(model, opset) dfg = DFG(model.graph) if hpvmc: - hpvmCodeGen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix) - hpvmCodeGen.compile() + hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix) + hpvm_code_gen.compile() else: - TensorCodeGen = TensorCodeGen(dfg, output_dir, input_size) - TensorCodeGen.compile() + tensor_code_gen = TensorCodeGen(dfg, output_dir, input_size, batch_size, prefix) + tensor_code_gen.compile() dfg.dump_weights(output_dir) @@ -66,7 +66,7 @@ def parse_args(): "-p", "--prefix", type=str, - help="Prefix in generated code; will be attached before name of weight/input files." + help="Prefix in generated code; will be attached before name of weight/input files. " "Defaults to output_dir.", ) parser.add_argument(