Skip to content
Snippets Groups Projects
Commit b35d34c3 authored by Yifan Zhao's avatar Yifan Zhao
Browse files

PyTorch interface now works correctly (continuing 797378f3)

parent 9d987de1
No related branches found
No related tags found
No related merge requests found
from .compile import ModelExporter, BinDataset from .compile import ModelExporter, BinDataset
from .__main__ import main
import os
from pathlib import Path
from .compile import compile_onnx_model
def parse_args():
    """Build and parse the ONNX-to-HPVM-C command line.

    Returns an ``argparse.Namespace`` in which the raw ``--compile-mode``
    choice has been folded into a boolean ``hpvmc`` attribute (True for the
    HPVM C interface, False for the tensor runtime); the ``compile_mode``
    attribute itself is removed.
    """
    import argparse

    ap = argparse.ArgumentParser(description="ONNX to HPVM-C")
    # Required positionals: the model, where to emit, and how many samples.
    ap.add_argument("onnx_file", type=Path, help="Path to input ONNX file")
    ap.add_argument(
        "output_dir",
        type=Path,
        help="Output folder where source file and weight files are generated",
    )
    ap.add_argument("dataset_size", type=int, help="Size of input dataset")
    # Optional tuning knobs.
    ap.add_argument(
        "-p",
        "--prefix",
        type=str,
        help="Prefix in generated code; will be attached before name of weight/input files. "
        "Defaults to output_dir.",
    )
    ap.add_argument(
        "-b",
        "--batch-size",
        type=int,
        help="Batch size to be used in the generated code. "
        "Defaults to input size (i.e., not using batch).",
    )
    ap.add_argument("--opset", type=int, help="ONNX opset version (enforced)")
    ap.add_argument(
        "-c",
        "--compile-mode",
        type=str,
        choices=["tensor", "hpvmc"],
        default="hpvmc",
        help="""Output mode.
    tensor: HPVM Tensor Runtime;
    hpvmc: HPVM C Interface. Default value is hpvmc.""",
    )
    # Replace the string-valued compile_mode with the boolean flag callers use.
    opts = vars(ap.parse_args())
    opts["hpvmc"] = opts.pop("compile_mode") == "hpvmc"
    return argparse.Namespace(**opts)
def main():
    """CLI entry point: parse the command line and compile the ONNX model."""
    compile_onnx_model(**vars(parse_args()))
...@@ -29,14 +29,16 @@ DatasetTy = Union[BinDataset, Dataset] ...@@ -29,14 +29,16 @@ DatasetTy = Union[BinDataset, Dataset]
class ModelExporter: class ModelExporter:
tuneset_name = "tune_input.bin", "tune_labels.bin" tuneset_name = "tune_input.bin", "tune_labels.bin"
testset_name = "test_input.bin", "test_labels.bin" testset_name = "test_input.bin", "test_labels.bin"
weight_dir_name = "weights"
source_file_name = "hpvm_c.cpp"
def __init__( def __init__(
self, self,
model: ModelTy, model: ModelTy,
tune_dataset: DatasetTy, tune_dataset: DatasetTy,
test_dataset: DatasetTy, test_dataset: DatasetTy,
weight_dir: PathLike, output_dir: PathLike,
hpvmc: bool, hpvmc: bool = True,
opset: Optional[int] = None, opset: Optional[int] = None,
): ):
self.tune_dataset, self.test_dataset = tune_dataset, test_dataset self.tune_dataset, self.test_dataset = tune_dataset, test_dataset
...@@ -48,10 +50,13 @@ class ModelExporter: ...@@ -48,10 +50,13 @@ class ModelExporter:
self.onnx_model = onnx.shape_inference.infer_shapes(onnx_model) self.onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
self.dfg = DFG(self.onnx_model.graph) self.dfg = DFG(self.onnx_model.graph)
self.weight_dir = Path(weight_dir) self.output_dir = Path(output_dir)
os.makedirs(weight_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self.weight_dir = self.output_dir / self.weight_dir_name
self.weight_dir.mkdir(exist_ok=True)
flavor = HpvmCodeGen if hpvmc else TensorCodeGen flavor = HpvmCodeGen if hpvmc else TensorCodeGen
self.codegen = flavor(self.dfg, weight_dir, self.dataset_size) self.codegen = flavor(self.dfg, output_dir, self.dataset_size)
def export_source_code(self, output: PathLike, batch_size: Optional[int] = None):
    """Emit the generated source file to *output*.

    Delegates to the code generator chosen at construction time (HPVM-C or
    tensor-runtime flavor); *batch_size* is forwarded unchanged.
    """
    self.codegen.compile(output, batch_size)
...@@ -65,25 +70,28 @@ class ModelExporter: ...@@ -65,25 +70,28 @@ class ModelExporter:
input_, labels = self.testset_name input_, labels = self.testset_name
self._dump_dataset(self.test_dataset, self.weight_dir / input_, self.weight_dir / labels) self._dump_dataset(self.test_dataset, self.weight_dir / input_, self.weight_dir / labels)
def export_all(self, output: Optional[PathLike] = None, batch_size: Optional[int] = None):
    """Export source code, weights, and both datasets in one call.

    Args:
        output: Destination for the generated source file. Fixed the implicit
            Optional annotation (PEP 484): None falls back to
            ``output_dir / source_file_name``.
        batch_size: Forwarded to the code generator; None means no batching.
    """
    default_codefile = self.output_dir / self.source_file_name
    self.export_source_code(output or default_codefile, batch_size)
    self.export_weights()
    self.export_datasets()
@staticmethod
def _dump_dataset(dataset: DatasetTy, input_filename: Path, labels_filename: Path):
    """Write *dataset* to raw binary files (inputs and labels).

    For a BinDataset the data already lives in binary files, so the target
    paths are (re)created as symlinks to them; any stale links are removed
    first. Otherwise the dataset is iterated through a DataLoader and the
    stacked arrays are dumped with ``ndarray.tofile``.
    """
    import numpy as np
    from torch.utils.data import DataLoader

    if isinstance(dataset, BinDataset):
        # Remove stale links so symlink_to() doesn't fail on re-export.
        input_filename.unlink(missing_ok=True)
        labels_filename.unlink(missing_ok=True)
        Path(input_filename).symlink_to(dataset.input_file)
        Path(labels_filename).symlink_to(dataset.labels_file)
        return
    inputs, labels = zip(*iter(DataLoader(dataset)))
    inputs = np.stack(inputs, axis=0)
    labels = np.stack(labels, axis=0)
    inputs.tofile(input_filename)
    # BUG FIX: was `inputs.tofile(labels_filename)`, which overwrote the
    # labels file with a second copy of the input data.
    labels.tofile(labels_filename)
@classmethod @classmethod
def _check_datasets(cls, tune_dataset: DatasetTy, test_dataset: DatasetTy) -> Tuple[int, int, int, int]: def _check_datasets(cls, tune_dataset: DatasetTy, test_dataset: DatasetTy) -> Tuple[int, int, int, int]:
......
from pathlib import Path

from torch2hpvm import BinDataset, ModelExporter
from dnn import AlexNet

# Pre-extracted CIFAR-10 binary splits shipped with the model parameters.
param_dir = Path(__file__).parent / "../model_params/alexnet_cifar10"
dataset_shape = 5000, 3, 32, 32
bin_tuneset, bin_testset = (
    BinDataset(
        param_dir / f"{split}_input.bin", param_dir / f"{split}_labels.bin", dataset_shape
    )
    for split in ("tune", "test")
)
# Export HPVM-C source, weights, and datasets for AlexNet in one shot.
ModelExporter(AlexNet(), bin_tuneset, bin_testset, "/tmp/alexnet", True).export_all()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment