From c0e45da21ecc3244d4c0b8a854090ffcbae1d12b Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 4 Jul 2019 19:34:39 +0300
Subject: [PATCH] Bypass Torchvision MobileNet v2 ONNX-export bug

A temporary monkey-patch to get past this Torchvision bug:
https://github.com/pytorch/pytorch/issues/20516

To trigger, try exporting mobilenet v2 to ONNX:
time python3 compress_classifier.py --arch=mobilenet_v2 --pretrained ${IMAGENET_PATH} --export-onnx
---
 distiller/models/__init__.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 712e5e5..58240e9 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -57,6 +57,19 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
                             set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
 
 
+# A temporary monkey-patch to get past this Torchvision bug:
+# https://github.com/pytorch/pytorch/issues/20516
+from functools import partial
+def patch_torchvision_mobilenet_v2_bug(model):
+    def patched_forward(self, x):
+        x = self.features(x)
+        #x = x.mean([2, 3])
+        x = x.mean(3).mean(2)
+        x = self.classifier(x)
+        return x
+    model.__class__.forward = patched_forward
+
+
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
 
@@ -80,6 +93,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         elif arch in TORCHVISION_MODEL_NAMES:
             try:
                 model = getattr(torch_models, arch)(pretrained=pretrained)
+                if arch == "mobilenet_v2":
+                    patch_torchvision_mobilenet_v2_bug(model)
             except NotImplementedError:
                 # In torchvision 0.3, trying to download a model that has no
                 # pretrained image available will raise NotImplementedError
-- 
GitLab