diff --git a/apputils/model_summaries.py b/apputils/model_summaries.py index 1531116544987389b9f0fe78a0a7c6f96ff0ef89..3bb94aea9850a74cd0cdd03291a2816ea0ceace0 100755 --- a/apputils/model_summaries.py +++ b/apputils/model_summaries.py @@ -618,7 +618,7 @@ def export_img_classifier_to_onnx(model, onnx_fname, dataset, export_params=True if add_softmax: # Explicitly add a softmax layer, because it is needed for the ONNX inference phase. model.original_forward = model.forward - softmax = torch.nn.Softmax(dim=1) + softmax = torch.nn.Softmax(dim=-1) model.forward = lambda input: softmax(model.original_forward(input)) torch.onnx.export(model, dummy_input, onnx_fname, verbose=False, export_params=export_params) msglogger.info('Exported the model to ONNX format at %s' % os.path.realpath(onnx_fname))