Skip to content
Snippets Groups Projects
Commit ee06d981 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

Supports selecting hpvm target name in frontend

parent 813ca8da
No related branches found
No related tags found
No related merge requests found
......@@ -52,7 +52,9 @@ class CodeGen:
for weight in weights:
name = cls.make_c_identifier(weight.name)
file_path = f"{weight.new_name}_path.bin"
ret.append({"name": name, "shape": weight.output_shape, "filename": file_path})
ret.append(
{"name": name, "shape": weight.output_shape, "filename": file_path}
)
return ret
@staticmethod
......@@ -67,6 +69,12 @@ class HpvmCodeGen(CodeGen):
# Variable indicator is always int for hpvm gen
variables: Dict[DFGNode, Tuple[int, bool]]
def __init__(self, dfg: DFG, prefix: PathLike, input_size: int, target: str):
super().__init__(dfg, prefix, input_size)
if target not in ("tensor", "cudnn"):
raise ValueError(f"Unsupported target {target}")
self.target = target
def _emit_hpvm_node_edges(self, input_vars: List[DFGNode]) -> List[dict]:
ret = []
it = 0
......@@ -135,5 +143,6 @@ class HpvmCodeGen(CodeGen):
root_output_idx=output_var_idx,
weights=weights,
prefix=self.prefix,
target=self.target,
)
)
......@@ -41,7 +41,7 @@ class ModelExporter:
tune_dataset: DatasetTy,
test_dataset: DatasetTy,
output_dir: PathLike,
hpvmc: bool = True,
target: str = "hpvm_tensor",
opset: Optional[int] = None,
):
from onnxsim import simplify
......@@ -62,8 +62,14 @@ class ModelExporter:
self.weight_dir = self.output_dir / self.weight_dir_name
self.weight_dir.mkdir(exist_ok=True)
flavor = HpvmCodeGen if hpvmc else TensorCodeGen
self.codegen = flavor(self.dfg, self.weight_dir, self.dataset_size)
if target == "hpvm_tensor":
self.codegen = HpvmCodeGen(self.dfg, self.weight_dir, self.dataset_size, "tensor")
elif target == "hpvm_cudnn":
self.codegen = HpvmCodeGen(self.dfg, self.weight_dir, self.dataset_size, "cudnn")
elif target == "tensor":
self.codegen = TensorCodeGen(self.dfg, self.weight_dir, self.dataset_size)
else:
raise ValueError(f"Target {target} not recognized")
def export_source_code(self, output: PathLike, batch_size: Optional[int] = None):
self.codegen.compile(output, batch_size)
......
......@@ -8,7 +8,7 @@ void var_{{node.idx}}_node(
{%- for n in range(node.input_size) -%}
void *t{{n}}, size_t bytes_t{{n}}{{", " if not loop.last}}
{%- endfor %}) {
__hpvm__hint(hpvm::CUDNN_TARGET);
__hpvm__hint(hpvm::{{target.upper()}}_TARGET);
__hpvm__attributes({{node.input_size}}, {% for n in range(node.input_size) -%}
t{{n}}{{", " if not loop.last}}
{%- endfor %}, 0);
......
......@@ -29,24 +29,28 @@ for model_cls, nch, img_size, batch_size, pathname in benchmarks:
print(f"Generating {pathname} to {codegen_dir}")
if codegen_dir.exists():
shutil.rmtree(codegen_dir)
prefix = self_folder / "../model_params" / pathname
params = self_folder / "../model_params" / pathname
dataset_shape = 5000, nch, img_size, img_size
bin_tuneset = BinDataset(
prefix / "tune_input.bin", prefix / "tune_labels.bin", dataset_shape
params / "tune_input.bin", params / "tune_labels.bin", dataset_shape
)
bin_testset = BinDataset(
prefix / "test_input.bin", prefix / "test_labels.bin", dataset_shape
params / "test_input.bin", params / "test_labels.bin", dataset_shape
)
model: Module = model_cls()
checkpoint = self_folder / "../model_params" / f"{pathname}.pth.tar"
model.load_state_dict(torch.load(checkpoint.as_posix()))
exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir, True)
exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir)
exporter.export_all(batch_size=batch_size)
conf_file = self_folder / "../hpvm-c/benchmarks" / pathname / "data/tuner_confs.txt"
build_dir = codegen_dir / "build"
target_binary = build_dir / pathname
run([
"approxhpvm.py", str(codegen_dir / ModelExporter.source_file_name), str(target_binary),
"-d", str(build_dir),
"-t", "cudnn"
"-t", "tensor", "--conf-file", str(conf_file)
], check=True)
run([str(target_binary), "test"], check=True)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment