diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 2ba18521daea0bf4f3ae04c7552d35274d25d1cb..3d0cedc5293242882126fe6dbff0163d509a5cd5 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -24,6 +24,7 @@ import os import re import numpy as np import collections +from copy import deepcopy import torch import torchvision from torch.autograd import Variable @@ -615,6 +616,14 @@ def export_img_classifier_to_onnx(model, onnx_fname, dataset): # Pytorch 0.4 doesn't support exporting modules wrapped in DataParallel if isinstance(model, torch.nn.DataParallel): model = model.module + + # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. + # We make a copy of the model, since we are about to change it (adding softmax). + model = deepcopy(model) + model.original_forward = model.forward + softmax = torch.nn.Softmax(dim=1) + model.forward = lambda input: softmax(model.original_forward(input)) + torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=True) msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname))