# A temporary monkey-patch to get past this Torchvision bug:
# https://github.com/pytorch/pytorch/issues/20516
from functools import partial  # file-level import from the original patch; not used by the function below


def patch_torchvision_mobilenet_v2_bug(model):
    """Monkey-patch ``forward`` on the class of ``model``.

    The replacement ``forward`` runs ``self.features``, averages the result
    over axis 3 and then axis 2 using two single-axis ``mean`` calls (instead
    of a single ``x.mean([2, 3])`` call), and feeds that to
    ``self.classifier``.

    The patch is installed on ``model.__class__``, not on the instance, so it
    affects every existing and future instance of that class.

    Args:
        model: an object exposing ``features`` and ``classifier`` attributes
            (presumably a torchvision MobileNetV2 instance -- the caller
            visible in this patch only invokes it for ``arch == "mobilenet_v2"``).

    Raises:
        TypeError: if ``model`` lacks ``features`` or ``classifier``. In that
            case the class is left untouched, rather than being given a
            ``forward`` that could only fail later at call time.
    """
    required = ("features", "classifier")
    missing = [name for name in required if not hasattr(model, name)]
    if missing:
        raise TypeError(
            "Cannot patch {}: missing attribute(s) {}".format(
                type(model).__name__, ", ".join(missing)))

    def patched_forward(self, x):
        x = self.features(x)
        # Reduce one axis at a time; the multi-axis form `x.mean([2, 3])` is
        # what the issue linked above is about. Axis 3 goes first so that
        # axis 2 still refers to the same dimension afterwards.
        x = x.mean(3).mean(2)
        return self.classifier(x)

    model.__class__.forward = patched_forward