diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
index 7a5ea0cb0f36c119755a9728f59582c9f798ab85..6f6b71eae0deda9176c3dcb32c76c99bccbf5f07 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/codegen_hpvm.py
@@ -69,14 +69,22 @@ class HpvmCodeGen(CodeGen):
     # Variable indicator is always int for hpvm gen
     variables: Dict[DFGNode, Tuple[int, bool]]
 
-    def __init__(self, dfg: DFG, prefix: PathLike, input_size: int, target: str, inspectable: bool):
+    def __init__(
+        self,
+        dfg: DFG,
+        prefix: PathLike,
+        input_size: int,
+        target: str,
+        inspectable: Optional[dict],
+    ):
         super().__init__(dfg, prefix, input_size)
         if target not in ("tensor", "cudnn"):
             raise ValueError(f"Unsupported target {target}")
         self.target = target
         self.template = template_env.get_template(
-            INSPECT_TEMPLATE_FILE if inspectable else PLAIN_TEMPLATE_FILE
+            PLAIN_TEMPLATE_FILE if inspectable is None else INSPECT_TEMPLATE_FILE
         )
+        self.inspect_vars = inspectable or {}
 
     def _emit_hpvm_node_edges(self, input_vars: List[DFGNode]) -> List[dict]:
         ret = []
@@ -147,5 +155,6 @@ class HpvmCodeGen(CodeGen):
                     weights=weights,
                     prefix=self.prefix,
                     target=self.target,
+                    **self.inspect_vars,
                 )
             )
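Note on the new `inspectable` parameter: it replaces the old boolean with an optional dict of extra template variables, so `None` selects the plain template while any dict (whose entries are forwarded into the render call via `**self.inspect_vars`) selects the inspect template. A minimal standalone sketch of this convention, using illustrative names rather than the real HPVM classes:

```python
from typing import Optional

PLAIN_TEMPLATE_FILE = "template_hpvm.cpp.in"  # stand-in names
INSPECT_TEMPLATE_FILE = "template_hpvm_inspect.cpp.in"

def pick_template(inspectable: Optional[dict]) -> str:
    # None -> plain codegen; any dict (even empty) -> inspectable codegen.
    return PLAIN_TEMPLATE_FILE if inspectable is None else INSPECT_TEMPLATE_FILE

def render_vars(base: dict, inspectable: Optional[dict]) -> dict:
    # Mirrors `**self.inspect_vars` in HpvmCodeGen's render call.
    return {**base, **(inspectable or {})}

assert pick_template(None) == PLAIN_TEMPLATE_FILE
assert render_vars({"target": "tensor"}, {"fifo_path": "/tmp/f"}) == {
    "target": "tensor",
    "fifo_path": "/tmp/f",
}
```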
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
index 25b3fff2920320bab5f260b57991fc88563035d5..52b6a5bc21dbbe8df16af6972df370321013aca5 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
@@ -34,6 +34,8 @@ class ModelExporter:
     weight_dir_name = "weights"
     source_file_name = "hpvm_c.cpp"
     metadata_file_name = "ops.json"
+    config_file_name = "tuner_confs.txt"
+    fifo_file_name = "hpvm_fifo"
 
     def __init__(
         self,
@@ -43,32 +45,43 @@ class ModelExporter:
         output_dir: PathLike,
         target: str = "hpvm_tensor",
         opset: Optional[int] = None,
+        config_file: Optional[PathLike] = None,
     ):
-        from onnxsim import simplify
-
         self.tune_dataset, self.test_dataset = tune_dataset, test_dataset
         self.dataset_shape = self._check_datasets(tune_dataset, test_dataset)
         self.dataset_size = self.dataset_shape[0]
-        onnx_model = self._load_model(model, self.dataset_shape)
-        if opset is not None:
-            onnx_model = check_onnx_version(onnx_model, opset)
-        onnx_model, check = simplify(onnx_model)
-        assert check, "Simplified ONNX model could not be validated"
-        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
-
+        onnx_model = self._load_model(model, self.dataset_shape, opset)
         self.dfg = DFG(onnx_model.graph)
-        self.output_dir = Path(output_dir)
+
+        output_dir = Path(output_dir).absolute()
         os.makedirs(output_dir, exist_ok=True)
-        self.weight_dir = self.output_dir / self.weight_dir_name
+        self.weight_dir = output_dir / self.weight_dir_name
         self.weight_dir.mkdir(exist_ok=True)
+        self.codefile = output_dir / self.source_file_name
+        self.metafile = output_dir / self.metadata_file_name
 
         args3 = self.dfg, self.weight_dir, self.dataset_size
+        self.compile_args = None
         if target == "hpvm_tensor":
-            self.codegen = HpvmCodeGen(*args3, "tensor", False)
+            if config_file is None:
+                raise ValueError(
+                    f"Config file must be given and exist under hpvm_tensor mode"
+                )
+            self.compile_args = ["-t", "tensor", "--conf-file", str(config_file)]
+            self.codegen = HpvmCodeGen(*args3, "tensor", None)
         elif target == "hpvm_tensor_inspect":
-            self.codegen = HpvmCodeGen(*args3, "tensor", True)
+            if config_file is None:
+                config_file = output_dir / self.config_file_name
+            else:
+                config_file = Path(config_file).absolute()
+            fifo_file = output_dir / self.fifo_file_name
+            self.compile_args = ["-t", "tensor", "--conf-file", str(config_file)]
+            inspect_mode_args = {"conf_path": config_file, "fifo_path": fifo_file}
+            self.codegen = HpvmCodeGen(*args3, "tensor", inspect_mode_args)
         elif target == "hpvm_cudnn":
-            self.codegen = HpvmCodeGen(*args3, "cudnn", False)
+            self.compile_target = "cudnn"
+            self.compile_args = ["-t", "cudnn"]
+            self.codegen = HpvmCodeGen(*args3, "cudnn", None)
         elif target == "tensor":
             self.codegen = TensorCodeGen(*args3)
         else:
@@ -76,9 +89,11 @@ class ModelExporter:
 
     def export_source_code(self, output: PathLike, batch_size: Optional[int] = None):
         self.codegen.compile(output, batch_size)
+        return self
 
     def export_weights(self):
         self.dfg.dump_weights(self.weight_dir)
+        return self
 
     def export_datasets(self):
         input_, labels = self.tuneset_name
@@ -89,6 +104,7 @@ class ModelExporter:
         self._dump_dataset(
             self.test_dataset, self.weight_dir / input_, self.weight_dir / labels
         )
+        return self
 
     def export_metadata(
         self, output: PathLike, approx_knobs_file: PathLike = def_approx_knobs_file
@@ -139,19 +155,38 @@ class ModelExporter:
                     "op_knobs": op_knobs,
                     "baseline_knob": baseline_knob,
                     "tune_args": "tune",
-                    "test_args": "test"
+                    "test_args": "test",
                 },
                 f,
                 indent=2,
             )
-
-    def export_all(self, output: PathLike = None, batch_size: Optional[int] = None):
-        default_codefile = self.output_dir / self.source_file_name
-        self.export_source_code(output or default_codefile, batch_size)
-        default_metafile = self.output_dir / self.metadata_file_name
-        self.export_metadata(default_metafile)
+        return self
+
+    def compile(self, output_binary: PathLike, working_dir: Optional[PathLike] = None):
+        from subprocess import run
+
+        if self.compile_args is None:
+            raise ValueError("This codegen target does not support compilation")
+        args = [
+            "approxhpvm.py",
+            str(self.codefile),
+            str(output_binary),
+            *self.compile_args,
+        ]
+        if working_dir is not None:
+            args.extend(["-d", str(working_dir)])
+        run(args, check=True)
+        return self
+
+    def generate(
+        self, output_code_file: Optional[PathLike] = None, batch_size: Optional[int] = None
+    ):
+        self.codefile = (
+            self.codefile if output_code_file is None else Path(output_code_file)
+        )
+        self.export_source_code(self.codefile, batch_size)
+        self.export_metadata(self.metafile)
         self.export_weights()
         self.export_datasets()
+        return self
 
     @staticmethod
     def _dump_dataset(dataset: DatasetTy, input_filename: Path, labels_filename: Path):
@@ -229,7 +264,11 @@ class ModelExporter:
         return dataset.shape
 
     @staticmethod
-    def _load_model(model: ModelTy, dataset_shape: Sequence[int]) -> onnx.ModelProto:
+    def _load_model(
+        model: ModelTy, dataset_shape: Sequence[int], opset: Optional[int]
+    ) -> onnx.ModelProto:
+        from onnxsim import simplify
+
         if isinstance(model, Module):
             # Export to ONNX and load back.
             sample_input_shape = 1, *dataset_shape[1:]
@@ -237,10 +276,16 @@ class ModelExporter:
             with NamedTemporaryFile("w+b") as tmp:
                 torch_to_onnx(model, (sample_input,), tmp)
                 tmp.seek(0)
-                return onnx.load_model(tmp)
-        if isinstance(model, onnx.ModelProto):
-            return model
-        return onnx.load(Path(model).as_posix())
+                onnx_model = onnx.load_model(tmp)
+        elif isinstance(model, onnx.ModelProto):
+            onnx_model = model
+        else:
+            raise ValueError(f"Cannot accept model of type {type(model)}")
+        if opset is not None:
+            onnx_model = check_onnx_version(onnx_model, opset)
+        onnx_model, check = simplify(onnx_model)
+        assert check, "Simplified ONNX model could not be validated"
+        return onnx.shape_inference.infer_shapes(onnx_model)
 
 
 def check_onnx_version(model, new_version):
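The exporter methods above now return `self`, turning the old `export_all` into a chainable pipeline, and `compile` shells out to `approxhpvm.py` with the arguments assembled per target. A hypothetical stand-in class sketching just these two mechanics (not the real `ModelExporter`):

```python
from pathlib import Path
from subprocess import run
from typing import Optional

class DemoExporter:
    """Sketch of the return-self chaining and the approxhpvm.py invocation."""

    def __init__(self, output_dir: str, config_file: str):
        self.codefile = Path(output_dir) / "hpvm_c.cpp"
        self.compile_args = ["-t", "tensor", "--conf-file", config_file]

    def generate(self) -> "DemoExporter":
        self.codefile.parent.mkdir(parents=True, exist_ok=True)
        self.codefile.write_text("// generated HPVM-C would go here\n")
        return self  # returning self enables generate().compile(...) chains

    def compile(self, binary: str, working_dir: Optional[str] = None) -> "DemoExporter":
        args = ["approxhpvm.py", str(self.codefile), binary, *self.compile_args]
        if working_dir is not None:
            args.extend(["-d", str(working_dir)])
        run(args, check=True)  # raises CalledProcessError on non-zero exit
        return self

# Usage (requires approxhpvm.py on PATH):
# DemoExporter("out", "tuner_confs.txt").generate().compile("out/net", "out/build")
```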
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in
index 178a6d28e942fcd71c6ec18b6b4080866d06f1cb..00a15dc9553fcfab7ecbcda1be61e31cfde9dfd9 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/template_hpvm_inspect.cpp.in
@@ -28,14 +28,19 @@ void fifo_write_finished(const std::string &filename) {
 }
 
 void make_fifo(const std::string &filename) {
-  if (mkfifo(filename.c_str(), 0666) == 0)
+  if (mkfifo(filename.c_str(), 0666) == 0) {
+    std::ofstream file{filename};
+    file << "{{conf_path}}\n"; // Write path to config file in FIFO file
     return;
+  }
+
   if (errno == EEXIST) {
     if (unlink(filename.c_str()) < 0) {
       std::cout << "Error removing existing file: " << strerror(errno) << '\n';
       abort();
     }
     make_fifo(filename);
+    return;
   }
   std::cout << "Error making FIFO file: " << strerror(errno) << '\n';
   abort();
@@ -126,8 +131,8 @@ int main(int argc, char *argv[]){
   args->{{n}}_bytes = 0;
 {% endfor %}
 
-  make_fifo("/tmp/hpvm_fifo");
-  while (fifo_wait("/tmp/hpvm_fifo")) {
+  make_fifo("{{fifo_path}}");
+  while (fifo_wait("{{fifo_path}}")) {
     __hpvm__init();
     startMemTracking();
     for (int i = 0; i < batch_count; i++){
@@ -143,7 +148,7 @@ int main(int argc, char *argv[]){
       freeBatchMemory();
     }
     __hpvm__cleanup();
-    fifo_write_finished("/tmp/hpvm_fifo");
+    fifo_write_finished("{{fifo_path}}");
   }
 
   return 0;
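With the template changes, the generated binary creates the FIFO at the templated `{{fifo_path}}`, immediately writes `{{conf_path}}` into it, then loops: block in `fifo_wait`, run all batches, and report back via `fifo_write_finished`. Below is a driver-side sketch of that handshake; the message tokens are assumptions, since the input `fifo_wait` expects is not shown in this diff, and only the POSIX blocking semantics of FIFOs are guaranteed:

```python
FIFO_PATH = "hpvm_fifo"  # matches ModelExporter.fifo_file_name

def read_conf_path() -> str:
    # The binary writes {{conf_path}} right after mkfifo(); its ofstream
    # open blocks until this reader connects, so startup is synchronized.
    with open(FIFO_PATH) as fifo:
        return fifo.read().strip()

def run_one_iteration() -> None:
    # Wake the binary parked in fifo_wait() ...
    with open(FIFO_PATH, "w") as fifo:
        fifo.write("run\n")  # assumed token; the diff does not show it
    # ... then block until fifo_write_finished() reports completion.
    with open(FIFO_PATH) as fifo:
        fifo.read()
```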
diff --git a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
index 7395136eb5f19adc2ad3450c34b60c911f72747e..19f17366459a7684c6df8a940438b661cf7f6029 100644
--- a/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
+++ b/hpvm/test/dnn_benchmarks/pytorch/test_frontend.py
@@ -42,15 +42,11 @@ for model_cls, nch, img_size, batch_size, pathname in benchmarks:
     checkpoint = self_folder / "../model_params" / f"{pathname}.pth.tar"
     model.load_state_dict(torch.load(checkpoint.as_posix()))
 
-    exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir)
-    exporter.export_all(batch_size=batch_size)
-
-    conf_file = self_folder / "../hpvm-c/benchmarks" / pathname / "data/tuner_confs.txt"
     build_dir = codegen_dir / "build"
     target_binary = build_dir / pathname
-    run([
-        "approxhpvm.py", str(codegen_dir / ModelExporter.source_file_name), str(target_binary),
-        "-d", str(build_dir),
-        "-t", "tensor", "--conf-file", str(conf_file)
-    ], check=True)
+    conf_file = self_folder / "../hpvm-c/benchmarks" / pathname / "data/tuner_confs.txt"
+    exporter = ModelExporter(
+        model, bin_tuneset, bin_testset, codegen_dir, config_file=conf_file
+    )
+    exporter.generate(batch_size=batch_size).compile(target_binary, build_dir)
     run([str(target_binary), "test"], check=True)
diff --git a/hpvm/test/dnn_benchmarks/pytorch/test_tuning.py b/hpvm/test/dnn_benchmarks/pytorch/test_tuning.py
index 1c4e8120ef675a077757c593a0b295f31841a1ea..4da55559a2c2b5abe45d432aaac0930bada5faf5 100644
--- a/hpvm/test/dnn_benchmarks/pytorch/test_tuning.py
+++ b/hpvm/test/dnn_benchmarks/pytorch/test_tuning.py
@@ -53,18 +53,13 @@ model: Module = model_cls()
 checkpoint = self_folder / "../model_params" / f"{pathname}.pth.tar"
 model.load_state_dict(torch.load(checkpoint.as_posix()))
 
-exporter = ModelExporter(model, bin_tuneset, bin_testset, codegen_dir, target="hpvm_tensor_inspect")
-exporter.export_all(batch_size=batch_size)
-
-conf_file = self_folder / "../hpvm-c/benchmarks" / pathname / "data/tuner_confs.txt"
 build_dir = codegen_dir / "build"
 target_binary = build_dir / pathname
-run([
-    "approxhpvm.py", str(codegen_dir / ModelExporter.source_file_name), str(target_binary),
-    "-d", str(build_dir),
-    "-t", "tensor", "--conf-file", str(conf_file)
-], check=True)
-# run([str(target_binary), "test"], check=True)
+exporter = ModelExporter(
+    model, bin_tuneset, bin_testset, codegen_dir, target="hpvm_tensor_inspect"
+)
+exporter.generate(batch_size=batch_size).compile(target_binary, build_dir)
+run([str(target_binary), "test"], check=True)
 
 # build_dir = codegen_dir / "build"
 # print(PipedBinaryApp("test", codegen_dir / "ops.json", build_dir / "lenet_mnist", build_dir))