diff --git a/distiller/models/imagenet/resnet_earlyexit.py b/distiller/models/imagenet/resnet_earlyexit.py index c4b87423df0bff892afa9fffcaf6d0796a203367..4e6ba9962a42a29ac1913571717576c28868e06a 100644 --- a/distiller/models/imagenet/resnet_earlyexit.py +++ b/distiller/models/imagenet/resnet_earlyexit.py @@ -6,7 +6,7 @@ from torchvision.models.resnet import Bottleneck from torchvision.models.resnet import BasicBlock -__all__ = ['resnet18_earlyexit', 'resnet34_earlyexit', 'resnet50_earlyexit', 'resnet101_earlyexit', 'resnet152_earlyexit'] +__all__ = ['resnet50_earlyexit'] def conv3x3(in_planes, out_planes, stride=1): @@ -66,36 +66,8 @@ class ResNetEarlyExit(models.ResNet): return output -def resnet18_earlyexit(**kwargs): - """Constructs a ResNet-18 model. - """ - model = ResNetEarlyExit(BasicBlock, [2, 2, 2, 2], **kwargs) - return model - - -def resnet34_earlyexit(**kwargs): - """Constructs a ResNet-34 model. - """ - model = ResNetEarlyExit(BasicBlock, [3, 4, 6, 3], **kwargs) - return model - - -def resnet50_earlyexit(**kwargs): +def resnet50_earlyexit(pretrained=False, **kwargs): """Constructs a ResNet-50 model. """ model = ResNetEarlyExit(Bottleneck, [3, 4, 6, 3], **kwargs) return model - - -def resnet101_earlyexit(**kwargs): - """Constructs a ResNet-101 model. - """ - model = ResNetEarlyExit(Bottleneck, [3, 4, 23, 3], **kwargs) - return model - - -def resnet152_earlyexit(**kwargs): - """Constructs a ResNet-152 model. - """ - model = ResNetEarlyExit(Bottleneck, [3, 8, 36, 3], **kwargs) - return model