From e960a7f1b7daeff569b0eb8fa563ce697e3285b5 Mon Sep 17 00:00:00 2001
From: Yifan Zhao <yifanz16@illinois.edu>
Date: Thu, 25 Mar 2021 03:50:02 -0500
Subject: [PATCH] Changed pytorch version requirement

---
 hpvm/env.yaml                                  |  2 +-
 hpvm/projects/predtuner                        |  2 +-
 hpvm/projects/torch2hpvm/setup.py              |  3 ++-
 hpvm/projects/torch2hpvm/torch2hpvm/compile.py | 15 +++++++++------
 4 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/hpvm/env.yaml b/hpvm/env.yaml
index c9cf79fc98..d2dcaa3de2 100644
--- a/hpvm/env.yaml
+++ b/hpvm/env.yaml
@@ -11,7 +11,7 @@ dependencies:
   - pandas=1.1
   - python==3.6.13
   - pip
-  - pytorch=1.7
+  - pytorch==1.6.0
   - torchvision=0.8
   - tqdm=4.59
   - scipy==1.1.0
diff --git a/hpvm/projects/predtuner b/hpvm/projects/predtuner
index fd00663b14..2fbd6f876c 160000
--- a/hpvm/projects/predtuner
+++ b/hpvm/projects/predtuner
@@ -1 +1 @@
-Subproject commit fd00663b145998da06ef861ffafa6c99ac2c0a47
+Subproject commit 2fbd6f876c34bfdbcbddc71cd73646e71bde5748
diff --git a/hpvm/projects/torch2hpvm/setup.py b/hpvm/projects/torch2hpvm/setup.py
index 0c2a89fc19..6d66372c71 100644
--- a/hpvm/projects/torch2hpvm/setup.py
+++ b/hpvm/projects/torch2hpvm/setup.py
@@ -12,7 +12,8 @@ setup(
         "jinja2>=2.11",
         "networkx>=2.5",
         "onnx>=1.8.0",
-        "torch>=1.5",
+        # Starting from 1.7.0 PyTorch starts to do some weird optimizations.
+        "torch>=1.4,<=1.6",
         "onnx-simplifier>=0.2.27",
     ],
 )
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
index 946fd27d39..922b6795ad 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
@@ -81,7 +81,7 @@ class ModelExporter:
                 "tune_labels_path": (self.weight_dir / self.tuneset_name[1]).as_posix(),
                 "conf_path": config_file.as_posix(),
                 "fifo_path_r": (output_dir / self.fifo_file_name_r).as_posix(),
-                "fifo_path_w": (output_dir / self.fifo_file_name_w).as_posix()
+                "fifo_path_w": (output_dir / self.fifo_file_name_w).as_posix(),
             }
             self.compile_args = ["-t", "tensor", "--conf-file", str(config_file)]
             self.codegen = HpvmCodeGen(*args3, "tensor", self.path_params)
@@ -161,7 +161,7 @@ class ModelExporter:
                     "knob_speedup": knob_speedup,
                     "op_knobs": op_knobs,
                     "baseline_knob": baseline_knob,
-                    **self.path_params
+                    **self.path_params,
                 },
                 f,
                 indent=2,
@@ -290,7 +290,9 @@ class ModelExporter:
             raise ValueError(f"Cannot accept model of type {type(model)}")
         if opset is not None:
             onnx_model = check_onnx_version(onnx_model, opset)
-        onnx_model, check = simplify(onnx_model)
+        onnx_model, check = simplify(
+            onnx_model, skip_fuse_bn=True, skipped_optimizers=["fuse_bn_into_conv"]
+        )
         assert check, "Simplified ONNX model could not be validated"
         return onnx.shape_inference.infer_shapes(onnx_model)
 
@@ -318,17 +320,18 @@ def torch_to_onnx(
     output_obj: Union[IO, PathLike],
     opset_version: int = 10,
 ):
+    from torch.onnx import export
+
     # Export the model (must be on CPU, some model only supports this)
-    torch.onnx.export(
+    export(
         module_cpu.eval(),
         model_args_cpu,
         output_obj,
         export_params=True,  # store the trained parameter weights inside the model file
+        do_constant_folding=False,
         opset_version=opset_version,  # the ONNX version to export the model to
-        do_constant_folding=True,  # whether to execute constant folding for optimization
         input_names=["input"],  # the model's input names
         output_names=["output"],  # the model's output names
-        strip_doc_string=False,
     )
 
 
-- 
GitLab