Skip to content
Snippets Groups Projects
Commit c0e45da2 authored by Neta Zmora
Browse files

Bypass Torchvision MobileNet v2 ONNX-export bug

A temporary monkey-patch to get past this Torchvision bug:
https://github.com/pytorch/pytorch/issues/20516

To trigger, try exporting mobilenet v2 to ONNX:
time python3 compress_classifier.py --arch=mobilenet_v2 --pretrained ${IMAGENET_PATH} --export-onnx
parent 032b1f74
No related branches found
No related tags found
No related merge requests found
...@@ -57,6 +57,19 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), ...@@ -57,6 +57,19 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES))) set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
# A temporary monkey-patch to get past this Torchvision bug:
# https://github.com/pytorch/pytorch/issues/20516
from functools import partial
def patch_torchvision_mobilenet_v2_bug(model):
    """Monkey-patch ``forward`` on the model's class to bypass an export bug.

    Temporary workaround for the Torchvision MobileNet v2 ONNX-export issue
    tracked at https://github.com/pytorch/pytorch/issues/20516.

    The replacement forward pass runs ``self.features``, averages the result
    over dimension 3 and then dimension 2 using two chained single-dimension
    ``mean`` calls, and passes that to ``self.classifier``.

    NOTE: the patch is installed on ``model.__class__`` rather than on the
    instance, so every existing and future instance of that class uses the
    patched ``forward``.

    Args:
        model: Object whose class receives the patched ``forward``.
            Instances must provide ``features`` and ``classifier`` callables.

    Returns:
        None. The class is modified in place.
    """
    def patched_forward(self, x):
        x = self.features(x)
        # The upstream code used a single x.mean([2, 3]); presumably that
        # multi-dimension form is what the ONNX export rejects (see the
        # issue linked above), so reduce one dimension at a time instead.
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    model.__class__.forward = patched_forward
def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
"""Create a pytorch model based on the model architecture and dataset """Create a pytorch model based on the model architecture and dataset
...@@ -80,6 +93,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None): ...@@ -80,6 +93,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
elif arch in TORCHVISION_MODEL_NAMES: elif arch in TORCHVISION_MODEL_NAMES:
try: try:
model = getattr(torch_models, arch)(pretrained=pretrained) model = getattr(torch_models, arch)(pretrained=pretrained)
if arch == "mobilenet_v2":
patch_torchvision_mobilenet_v2_bug(model)
except NotImplementedError: except NotImplementedError:
# In torchvision 0.3, trying to download a model that has no # In torchvision 0.3, trying to download a model that has no
# pretrained image available will raise NotImplementedError # pretrained image available will raise NotImplementedError
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment