From c0e45da21ecc3244d4c0b8a854090ffcbae1d12b Mon Sep 17 00:00:00 2001
From: Neta Zmora <neta.zmora@intel.com>
Date: Thu, 4 Jul 2019 19:34:39 +0300
Subject: [PATCH] Bypass Torchvision MobileNet v2 ONNX-export bug

A temporary monkey-patch to get past this Torchvision bug:
https://github.com/pytorch/pytorch/issues/20516

To trigger, try exporting mobilenet v2 to ONNX:
time python3 compress_classifier.py --arch=mobilenet_v2 --pretrained ${IMAGENET_PATH} --export-onnx
---
 distiller/models/__init__.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/distiller/models/__init__.py b/distiller/models/__init__.py
index 712e5e5..58240e9 100755
--- a/distiller/models/__init__.py
+++ b/distiller/models/__init__.py
@@ -57,6 +57,19 @@ ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(),
                             set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES + MNIST_MODEL_NAMES)))
 
 
+# A temporary monkey-patch to get past this Torchvision bug:
+# https://github.com/pytorch/pytorch/issues/20516
+from functools import partial
+def patch_torchvision_mobilenet_v2_bug(model):
+    def patched_forward(self, x):
+        x = self.features(x)
+        #x = x.mean([2, 3])
+        x = x.mean(3).mean(2)
+        x = self.classifier(x)
+        return x
+    model.__class__.forward = patched_forward
+
+
 def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
     """Create a pytorch model based on the model architecture and dataset
 
@@ -80,6 +93,8 @@ def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
         elif arch in TORCHVISION_MODEL_NAMES:
             try:
                 model = getattr(torch_models, arch)(pretrained=pretrained)
+                if arch == "mobilenet_v2":
+                    patch_torchvision_mobilenet_v2_bug(model)
             except NotImplementedError:
                 # In torchvision 0.3, trying to download a model that has no
                 # pretrained image available will raise NotImplementedError
-- 
GitLab