diff --git a/hpvm/projects/torch2hpvm/setup.py b/hpvm/projects/torch2hpvm/setup.py
index fd21f8b59d5373ea6f589e88dfd4a006b9cd1ff5..ae103a2cdf0c0872278c147ddac5774ce79da452 100644
--- a/hpvm/projects/torch2hpvm/setup.py
+++ b/hpvm/projects/torch2hpvm/setup.py
@@ -7,6 +7,6 @@ setup(
     author="Yuanjing Shi, Yifan Zhao",
     author_email="ys26@illinois.edu, yifanz16@illinois.edu",
     packages=["torch2hpvm"],
-    install_requires=["jinja2>=2.11", "networkx>=2.5", "onnx>=1.8.0"],
+    install_requires=["jinja2>=2.11", "networkx>=2.5", "onnx>=1.8.0", "torch"],
     entry_points={"console_scripts": ["torch2hpvm=torch2hpvm:main"]},
 )
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/__init__.py b/hpvm/projects/torch2hpvm/torch2hpvm/__init__.py
index 8e7de3cf9a33aa8ce1d07915b0b2d0a89964411a..dd59cd582e9ab5d23c15a38bf6f68eb258a0253a 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/__init__.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/__init__.py
@@ -1,2 +1,2 @@
-from .compile import compile
+from .compile import compile_onnx_model, compile_torch_module
 from .__main__ import main
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/__main__.py b/hpvm/projects/torch2hpvm/torch2hpvm/__main__.py
index 8eae267f4738c7ca586e4e175d245b270ec65b8a..6b499bfb5a14e8f95f0bafd274cb099dbf0346de 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/__main__.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/__main__.py
@@ -1,7 +1,7 @@
 import os
 from pathlib import Path
 
-from .compile import compile
+from .compile import compile_onnx_model
 
 
 def parse_args():
@@ -15,7 +15,7 @@ def parse_args():
         help="Output folder where source file and weight files are generated",
     )
     parser.add_argument(
-        "input_size", type=int, help="Size of input dataset",
+        "dataset_size", type=int, help="Size of input dataset",
     )
     parser.add_argument(
         "-p",
@@ -51,5 +51,4 @@ hpvmc: HPVM C Interface. Default value is hpvmc.""",
 
 def main():
     args = parse_args()
-    os.makedirs(args.output_dir, exist_ok=True)
-    compile(**vars(args))
+    compile_onnx_model(**vars(args))
diff --git a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
index 694e5f62311b0a415ef4db9987213528ebe4b4e6..f77e9f63b9b17046cf24b215670891f7bfa67746 100644
--- a/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
+++ b/hpvm/projects/torch2hpvm/torch2hpvm/compile.py
@@ -1,8 +1,12 @@
+import os
 from pathlib import Path
-from typing import Optional, Union
+from tempfile import NamedTemporaryFile
+from typing import IO, Optional, Sequence, Union
 
 import onnx
+import torch
 from onnx import version_converter
+from torch.nn import Module
 
 from .codegen_hpvm import HpvmCodeGen
 from .codegen_tensor import TensorCodeGen
@@ -11,7 +15,7 @@ from .graph_builder import DFG
 PathLike = Union[Path, str]
 
 
-def check_version(model, new_version):
+def check_onnx_version(model, new_version):
     try:
         opset = model.opset_import[0].version if model.opset_import else 1
     except AttributeError:
@@ -29,24 +33,73 @@ def check_version(model, new_version):
     return model
 
 
-def compile(
-    onnx_file: Path,
-    output_dir: Path,
-    input_size: int,
-    prefix: Optional[str],
-    batch_size: Optional[int],
-    opset: Optional[int],
+def torch_to_onnx(
+    module_cpu: Module,
+    model_args_cpu: tuple,
+    output_obj: Union[IO, PathLike],
+    opset_version: int = 10,
+):
+    # Export the model (must be on CPU, some model only supports this)
+    torch.onnx.export(
+        module_cpu.eval(),
+        model_args_cpu,
+        output_obj,
+        export_params=True,  # store the trained parameter weights inside the model file
+        opset_version=opset_version,  # the ONNX version to export the model to
+        do_constant_folding=True,  # whether to execute constant folding for optimization
+        input_names=["input"],  # the model's input names
+        output_names=["output"],  # the model's output names
+        dynamic_axes={
+            "input": {0: "batch_size"},  # variable length axes
+            "output": {0: "batch_size"},
+        },
+        strip_doc_string=False,
+    )
+
+
+def compile_onnx_model(
+    file_or_model: Union[PathLike, onnx.ModelProto],
+    output_dir: PathLike,
+    dataset_size: int,
     hpvmc: bool,
+    prefix: Optional[str] = None,
+    batch_size: Optional[int] = None,
+    opset: Optional[int] = None,
 ):
-    model = onnx.load(onnx_file)
+    if isinstance(file_or_model, onnx.ModelProto):
+        model = file_or_model
+    else:
+        model = onnx.load(Path(file_or_model).as_posix())
     if opset is not None:
-        model = check_version(model, opset)
+        model = check_onnx_version(model, opset)
     model = onnx.shape_inference.infer_shapes(model)
     dfg = DFG(model.graph)
+    output_dir = Path(output_dir)
+    os.makedirs(output_dir, exist_ok=True)
     if hpvmc:
-        hpvm_code_gen = HpvmCodeGen(dfg, output_dir, input_size, batch_size, prefix)
+        hpvm_code_gen = HpvmCodeGen(dfg, output_dir, dataset_size, batch_size, prefix)
         hpvm_code_gen.compile()
     else:
-        tensor_code_gen = TensorCodeGen(dfg, output_dir, input_size, batch_size, prefix)
+        tensor_code_gen = TensorCodeGen(dfg, output_dir, dataset_size, batch_size, prefix)
         tensor_code_gen.compile()
     dfg.dump_weights(output_dir)
+
+
+def compile_torch_module(
+    module: Module,
+    input_shape: Sequence[int],
+    output_dir: PathLike,
+    hpvmc: bool,
+    prefix: Optional[str] = None,
+    batch_size: Optional[int] = None,
+):
+    dataset_size, *single_input_shape = input_shape
+    sample_input_shape = 1, *single_input_shape
+    sample_input = torch.rand(sample_input_shape)
+    with NamedTemporaryFile("w+b") as tmp:
+        torch_to_onnx(module, (sample_input, ), tmp)
+        tmp.seek(0)
+        onnx_model = onnx.load_model(tmp)
+        compile_onnx_model(
+            onnx_model, output_dir, dataset_size, hpvmc, prefix, batch_size
+        )