From ee06d981afa0675c8637ab3edf60b1250b09c27c Mon Sep 17 00:00:00 2001 From: Yifan Zhao <yifanz16@illinois.edu> Date: Tue, 2 Feb 2021 01:06:30 -0600 Subject: [PATCH] Supports selecting hpvm target name in frontend --- .../projects/torch2hpvm/torch2hpvm/codegen_hpvm.py | 11 ++++++++++- hpvm/projects/torch2hpvm/torch2hpvm/compile.py | 12 +++++++++--- .../torch2hpvm/torch2hpvm/template_hpvm.cpp.in | 2 +- hpvm/test/dnn_benchmarks/pytorch/test_frontend.py | 14 +++++++++----- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py index 439c5af5d7..cdba5f327f 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py @@ -52,7 +52,9 @@ class CodeGen: for weight in weights: name = cls.make_c_identifier(weight.name) file_path = f"{weight.new_name}_path.bin" - ret.append({"name": name, "shape": weight.output_shape, "filename": file_path}) + ret.append( + {"name": name, "shape": weight.output_shape, "filename": file_path} + ) return ret @staticmethod @@ -67,6 +69,12 @@ class HpvmCodeGen(CodeGen): # Variable indicator is always int for hpvm gen variables: Dict[DFGNode, Tuple[int, bool]] + def __init__(self, dfg: DFG, prefix: PathLike, input_size: int, target: str): + super().__init__(dfg, prefix, input_size) + if target not in ("tensor", "cudnn"): + raise ValueError(f"Unsupported target {target}") + self.target = target + def _emit_hpvm_node_edges(self, input_vars: List[DFGNode]) -> List[dict]: ret = [] it = 0 @@ -135,5 +143,6 @@ class HpvmCodeGen(CodeGen): root_output_idx=output_var_idx, weights=weights, prefix=self.prefix, + target=self.target, ) ) diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py index a94c16e6d4..cc2a670dad 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py +++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py @@ -41,7 +41,7 @@ class ModelExporter: tune_dataset: DatasetTy, test_dataset: DatasetTy, output_dir: PathLike, - hpvmc: bool = True, + target: str = "hpvm_tensor", opset: Optional[int] = None, ): from onnxsim import simplify @@ -62,8 +62,14 @@ class ModelExporter: self.weight_dir = self.output_dir / self.weight_dir_name self.weight_dir.mkdir(exist_ok=True) - flavor = HpvmCodeGen if hpvmc else TensorCodeGen - self.codegen = flavor(self.dfg, self.weight_dir, self.dataset_size) + if target == "hpvm_tensor": + self.codegen = HpvmCodeGen(self.dfg, self.weight_dir, self.dataset_size, "tensor") + elif target == "hpvm_cudnn": + self.codegen = HpvmCodeGen(self.dfg, self.weight_dir, self.dataset_size, "cudnn") + elif target == "tensor": + self.codegen = TensorCodeGen(self.dfg, self.weight_dir, self.dataset_size) + else: + raise ValueError(f"Target {target} not recognized") def export_source_code(self, output: PathLike, batch_size: Optional[int] = None): self.codegen.compile(output, batch_size) diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in index d7d62e4216..d7fd6c8884 100644 --- a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in +++ b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm.cpp.in @@ -8,7 +8,7 @@ void var_{{node.idx}}_node( {%- for n in range(node.input_size) -%} void *t{{n}}, size_t bytes_t{{n}}{{", " if not loop.last}} {%- endfor %}) { - __hpvm__hint(hpvm::CUDNN_TARGET); + __hpvm__hint(hpvm::{{target.upper()}}_TARGET); __hpvm__attributes({{node.input_size}}, {% for n in range(node.input_size) -%} t{{n}}{{", " if not loop.last}} {%- endfor %}, 0); diff --git a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py index 96de3db2ed..7395136eb5 100644 --- a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py +++ b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py @@ -29,24 +29,28 @@ for model_cls, nch, img_size, batch_size, pathname in benchmarks: print(f"Generating {pathname} to {codegen_dir}") if codegen_dir.exists(): shutil.rmtree(codegen_dir) - prefix = self_folder / "../model_params" / pathname + + params = self_folder / "../model_params" / pathname dataset_shape = 5000, nch, img_size, img_size bin_tuneset = BinDataset( - prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape + params / "tune_input.bin", params / "tune_labels.bin", dataset_shape ) bin_testset = BinDataset( - prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape + params / "test_input.bin", params / "test_labels.bin", dataset_shape ) model: Module = model_cls() checkpoint = self_folder / "../model_params" / f"{pathname}.pth.tar" model.load_state_dict(torch.load(checkpoint.as_posix())) - exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir, True) + + exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir) exporter.export_all(batch_size=batch_size) + + conf_file = self_folder / "../hpvm-c/benchmarks" / pathname / "data/tuner_confs.txt" build_dir = codegen_dir / "build" target_binary = build_dir / pathname run([ "approxhpvm.py", str(codegen_dir / ModelExporter.source_file_name), str(target_binary), "-d", str(build_dir), - "-t", "cudnn" + "-t", "tensor", "--conf-file", str(conf_file) ], check=True) run([str(target_binary), "test"], check=True) -- GitLab