diff --git a/hpvm/projects/onnx/generated/vgg16/approxhpvm_src.cc b/hpvm/projects/onnx/keras_ref/generated/vgg16/approxhpvm_src.cc similarity index 100% rename from hpvm/projects/onnx/generated/vgg16/approxhpvm_src.cc rename to hpvm/projects/onnx/keras_ref/generated/vgg16/approxhpvm_src.cc diff --git a/hpvm/projects/onnx/generated/vgg16/src.cc b/hpvm/projects/onnx/keras_ref/generated/vgg16/src.cc similarity index 100% rename from hpvm/projects/onnx/generated/vgg16/src.cc rename to hpvm/projects/onnx/keras_ref/generated/vgg16/src.cc diff --git a/hpvm/projects/onnx/keras_environment.yml b/hpvm/projects/onnx/keras_ref/keras_environment.yml similarity index 100% rename from hpvm/projects/onnx/keras_environment.yml rename to hpvm/projects/onnx/keras_ref/keras_environment.yml diff --git a/hpvm/projects/onnx/keras_frontend/__init__.py b/hpvm/projects/onnx/keras_ref/keras_frontend/__init__.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/__init__.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/__init__.py diff --git a/hpvm/projects/onnx/keras_frontend/approxhpvm_translator.py b/hpvm/projects/onnx/keras_ref/keras_frontend/approxhpvm_translator.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/approxhpvm_translator.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/approxhpvm_translator.py diff --git a/hpvm/projects/onnx/keras_frontend/hpvm_dfg_translator.py b/hpvm/projects/onnx/keras_ref/keras_frontend/hpvm_dfg_translator.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/hpvm_dfg_translator.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/hpvm_dfg_translator.py diff --git a/hpvm/projects/onnx/keras_frontend/promise_translator.py b/hpvm/projects/onnx/keras_ref/keras_frontend/promise_translator.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/promise_translator.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/promise_translator.py diff --git a/hpvm/projects/onnx/keras_frontend/quantize_utils.py b/hpvm/projects/onnx/keras_ref/keras_frontend/quantize_utils.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/quantize_utils.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/quantize_utils.py diff --git a/hpvm/projects/onnx/keras_frontend/setup.py b/hpvm/projects/onnx/keras_ref/keras_frontend/setup.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/setup.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/setup.py diff --git a/hpvm/projects/onnx/keras_frontend/utils.py b/hpvm/projects/onnx/keras_ref/keras_frontend/utils.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/utils.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/utils.py diff --git a/hpvm/projects/onnx/keras_frontend/weight_utils.py b/hpvm/projects/onnx/keras_ref/keras_frontend/weight_utils.py similarity index 100% rename from hpvm/projects/onnx/keras_frontend/weight_utils.py rename to hpvm/projects/onnx/keras_ref/keras_frontend/weight_utils.py diff --git a/hpvm/projects/onnx/keras_ref/setup.py b/hpvm/projects/onnx/keras_ref/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..9da7193379454d1d5cdc1e63f6b436d3771e15a5 --- /dev/null +++ b/hpvm/projects/onnx/keras_ref/setup.py @@ -0,0 +1,12 @@ + +from setuptools import setup + +setup( + name='frontend', + version='1.0', + description='ApproxHPVM frontend modules', + author='Hashim', + author_email='hsharif3@illinois.edu', + packages=['frontend'], + install_requires=[], +) diff --git a/hpvm/projects/onnx/src/alexnet.py b/hpvm/projects/onnx/keras_ref/src/alexnet.py similarity index 100% rename from hpvm/projects/onnx/src/alexnet.py rename to hpvm/projects/onnx/keras_ref/src/alexnet.py diff --git a/hpvm/projects/onnx/src/alexnet2.py b/hpvm/projects/onnx/keras_ref/src/alexnet2.py similarity index 100% rename from hpvm/projects/onnx/src/alexnet2.py rename to hpvm/projects/onnx/keras_ref/src/alexnet2.py diff --git a/hpvm/projects/onnx/src/lenet.py b/hpvm/projects/onnx/keras_ref/src/lenet.py similarity index 100% rename from hpvm/projects/onnx/src/lenet.py rename to hpvm/projects/onnx/keras_ref/src/lenet.py diff --git a/hpvm/projects/onnx/src/mobilenet_cifar10.py b/hpvm/projects/onnx/keras_ref/src/mobilenet_cifar10.py similarity index 100% rename from hpvm/projects/onnx/src/mobilenet_cifar10.py rename to hpvm/projects/onnx/keras_ref/src/mobilenet_cifar10.py diff --git a/hpvm/projects/onnx/src/mobilenet_shallow.py b/hpvm/projects/onnx/keras_ref/src/mobilenet_shallow.py similarity index 100% rename from hpvm/projects/onnx/src/mobilenet_shallow.py rename to hpvm/projects/onnx/keras_ref/src/mobilenet_shallow.py diff --git a/hpvm/projects/onnx/src/mobilenetv2_cifar10.py b/hpvm/projects/onnx/keras_ref/src/mobilenetv2_cifar10.py similarity index 100% rename from hpvm/projects/onnx/src/mobilenetv2_cifar10.py rename to hpvm/projects/onnx/keras_ref/src/mobilenetv2_cifar10.py diff --git a/hpvm/projects/onnx/src/resnet.py b/hpvm/projects/onnx/keras_ref/src/resnet.py similarity index 100% rename from hpvm/projects/onnx/src/resnet.py rename to hpvm/projects/onnx/keras_ref/src/resnet.py diff --git a/hpvm/projects/onnx/src/vgg16_cifar10.py b/hpvm/projects/onnx/keras_ref/src/vgg16_cifar10.py similarity index 100% rename from hpvm/projects/onnx/src/vgg16_cifar10.py rename to hpvm/projects/onnx/keras_ref/src/vgg16_cifar10.py diff --git a/hpvm/projects/onnx/src/vgg16_cifar100.py b/hpvm/projects/onnx/keras_ref/src/vgg16_cifar100.py similarity index 100% rename from hpvm/projects/onnx/src/vgg16_cifar100.py rename to hpvm/projects/onnx/keras_ref/src/vgg16_cifar100.py diff --git a/hpvm/projects/onnx/onnx_environment.yml b/hpvm/projects/onnx/onnx_environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..b0aa451b76f83618f69e8eb95ff86163b658a33a --- /dev/null +++ b/hpvm/projects/onnx/onnx_environment.yml @@ -0,0 +1,322 @@ +name: approxhpvm_keras +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - absl-py=0.6.1=py36_0 + - anaconda-project=0.8.2=py36_0 + - asn1crypto=0.24.0=py36_0 + - automat=0.7.0=py36_0 + - babel=2.6.0=py36_0 + - backports=1.0=py36_1 + - backports.os=0.1.1=py36_0 + - beautifulsoup4=4.6.3=py36_0 + - bkcharts=0.2=py36_0 + - blaze=0.11.3=py36_0 + - conda=4.5.11=py36_0 + - conda-env=2.6.0=1 + - contextlib2=0.5.5=py36_0 + - cycler=0.10.0=py36_0 + - dill=0.2.8.2=py36_0 + - docutils=0.14=py36_0 + - entrypoints=0.2.3=py36_2 + - et_xmlfile=1.0.1=py36_0 + - idna=2.7=py36_0 + - imageio=2.4.1=py36_0 + - importlib_metadata=0.6=py36_0 + - ipython_genutils=0.2.0=py36_0 + - isort=4.3.4=py36_0 + - jdcal=1.4=py36_0 + - jedi=0.13.1=py36_0 + - jinja2=2.10=py36_0 + - jmespath=0.9.3=py36_0 + - jsonschema=2.6.0=py36_0 + - keyring=16.0.0=py36_0 + - libgcc=7.2.0=h69d50b8_2 + - libgfortran=3.0.0=1 + - locket=0.2.0=py36_1 + - more-itertools=4.3.0=py36_0 + - nbconvert=5.3.1=py36_0 + - nbformat=4.4.0=py36_0 + - nose=1.3.7=py36_2 + - notebook=5.7.0=py36_0 + - numpydoc=0.8.0=py36_0 + - odo=0.5.1=py36_0 + - pathlib2=2.3.2=py36_0 + - pexpect=4.6.0=py36_0 + - pickleshare=0.7.5=py36_0 + - ply=3.11=py36_0 + - ptyprocess=0.6.0=py36_0 + - pycodestyle=2.4.0=py36_0 + - pygments=2.2.0=py36_0 + - pylint=2.1.1=py36_0 + - pyopenssl=18.0.0=py36_0 + - qtconsole=4.4.2=py36_0 + - requests=2.19.1=py36_0 + - s3transfer=0.1.13=py36_0 + - secretstorage=3.1.0=py36_0 + - setuptools=40.5.0=py36_0 + - singledispatch=3.4.0.3=py36_0 + - six=1.11.0=py36_1 + - snowballstemmer=1.2.1=py36_0 + - sortedcollections=1.0.1=py36_0 + - sphinx=1.8.1=py36_0 + - spyder=3.3.1=py36_1 + - sympy=1.3=py36_0 + - tblib=1.3.2=py36_0 + - termcolor=1.1.0=py36_1 + - terminado=0.8.1=py36_1 + - testpath=0.4.2=py36_0 + - torchvision=0.2.1=py36_0 + - traitlets=4.3.2=py36_0 + - typing=3.6.4=py36_0 + - unicodecsv=0.14.1=py36_0 + - urllib3=1.23=py36_0 + - wcwidth=0.1.7=py36_0 + - wheel=0.32.2=py36_0 + - widgetsnbextension=3.4.2=py36_0 + - xlwt=1.3.0=py36_0 + - _license=1.1=py36_1 + - _tflow_select=2.1.0=gpu + - alabaster=0.7.12=py36_0 + - anaconda-client=1.7.2=py36_0 + - anaconda=custom=py36hbbc8b67_0 + - anaconda-navigator=1.9.2=py36_0 + - appdirs=1.4.3=py36h28b3542_0 + - astor=0.7.1=py36_0 + - astroid=2.0.4=py36_0 + - astropy=3.0.5=py36h7b6447c_0 + - atomicwrites=1.2.1=py36_0 + - attrs=18.2.0=py36h28b3542_0 + - backcall=0.1.0=py36_0 + - backports.shutil_get_terminal_size=1.0.0=py36_2 + - bitarray=0.8.3=py36h14c3975_0 + - blas=1.0=mkl + - bleach=3.0.2=py36_0 + - blosc=1.14.4=hdbcaa40_0 + - bokeh=1.0.1=py36_0 + - boto=2.49.0=py36_0 + - boto3=1.9.35=py36_0 + - botocore=1.12.35=py36_0 + - bottleneck=1.2.1=py36h035aef0_1 + - bz2file=0.98=py36_1 + - bzip2=1.0.6=h14c3975_5 + - ca-certificates=2018.03.07=0 + - cairo=1.14.12=h8948797_3 + - certifi=2018.10.15=py36_0 + - cffi=1.11.5=py36he75722e_1 + - chardet=3.0.4=py36_1 + - chest=0.2.3=py36_1 + - click=7.0=py36_0 + - cloudpickle=0.6.1=py36_0 + - clyent=1.2.2=py36_1 + - colorama=0.4.0=py36_0 + - configobj=5.0.6=py36_1 + - constantly=15.1.0=py36h28b3542_0 + - cryptography=2.3.1=py36hc365091_0 + - cudatoolkit=9.0=h13b8566_0 + - cudnn=7.1.2=cuda9.0_0 + - cupti=9.0.176=0 + - curl=7.61.0=h84994c4_0 + - cython=0.29=py36he6710b0_0 + - cytoolz=0.9.0.1=py36h14c3975_1 + - dask=0.20.0=py36_0 + - dask-core=0.20.0=py36_0 + - datashape=0.5.4=py36_1 + - dbus=1.13.2=h714fa37_1 + - decorator=4.3.0=py36_0 + - defusedxml=0.5.0=py36_1 + - distributed=1.24.0=py36_0 + - expat=2.2.6=he6710b0_0 + - fastcache=1.0.2=py36h14c3975_2 + - filelock=3.0.10=py36_0 + - flask=1.0.2=py36_1 + - flask-cors=3.0.6=py36_0 + - fontconfig=2.13.0=h9420a91_0 + - freetype=2.9.1=h8a8886c_1 + - fribidi=1.0.5=h7b6447c_0 + - gast=0.2.0=py36_0 + - gensim=3.4.0=py36h14c3975_0 + - get_terminal_size=1.0.0=haa9412d_0 + - gevent=1.3.7=py36h7b6447c_1 + - glib=2.56.2=hd408876_0 + - glob2=0.6=py36_1 + - gmp=6.1.2=h6c8ec71_1 + - gmpy2=2.0.8=py36h10f8cd9_2 + - graphite2=1.3.12=h23475e2_2 + - greenlet=0.4.15=py36h7b6447c_0 + - grpcio=1.12.1=py36hdbcaa40_0 + - gst-plugins-base=1.14.0=hbbd80ab_1 + - gstreamer=1.14.0=hb453b48_1 + - h5py=2.8.0=py36h989c5e5_3 + - harfbuzz=1.8.8=hffaf4a1_0 + - hdf5=1.10.2=hba1933b_1 + - heapdict=1.0.0=py36_2 + - html5lib=1.0.1=py36_0 + - hyperlink=18.0.0=py36_0 + - icu=58.2=h9c2bf20_1 + - imagesize=1.1.0=py36_0 + - incremental=17.5.0=py36_0 + - ipykernel=5.1.0=py36h39e3cac_0 + - ipython=7.1.1=py36h39e3cac_0 + - ipywidgets=7.4.2=py36_0 + - itsdangerous=1.1.0=py36_0 + - jbig=2.1=hdba287a_0 + - jeepney=0.4=py36_0 + - jpeg=9b=h024ee3a_2 + - keras=2.1.6=py36_0 + - keras-applications=1.0.6=py36_0 + - keras-preprocessing=1.0.5=py36_0 + - kiwisolver=1.0.1=py36hf484d3e_0 + - lazy-object-proxy=1.3.1=py36h14c3975_2 + - libcurl=7.61.0=h1ad7b7a_0 + - libedit=3.1.20170329=h6b74fdf_2 + - libffi=3.2.1=hd88cf55_4 + - libgcc-ng=8.2.0=hdf63c60_1 + - libgfortran-ng=7.3.0=hdf63c60_0 + - libiconv=1.15=h63c8f33_5 + - libpng=1.6.35=hbc83047_0 + - libprotobuf=3.6.1=hd408876_0 + - libsodium=1.0.16=h1bed415_0 + - libssh2=1.8.0=h9cfc8f7_4 + - libstdcxx-ng=8.2.0=hdf63c60_1 + - libtiff=4.0.9=he85c1e1_2 + - libtool=2.4.6=h7b6447c_5 + - libuuid=1.0.3=h1bed415_2 + - libxcb=1.13=h1bed415_1 + - libxml2=2.9.8=h26e45fe_1 + - libxslt=1.1.32=h1312cb7_0 + - llvmlite=0.25.0=py36hd408876_0 + - lxml=4.2.5=py36hefd8a0e_0 + - lzo=2.10=h49e0be7_2 + - markdown=3.0.1=py36_0 + - markupsafe=1.0=py36h14c3975_1 + - matplotlib=3.0.1=py36h5429711_0 + - mccabe=0.6.1=py36_1 + - mistune=0.8.4=py36h7b6447c_0 + - mkl=2018.0.3=1 + - mkl-service=1.1.2=py36h90e4bf4_5 + - mkl_fft=1.0.6=py36h7dd41cf_0 + - mkl_random=1.0.1=py36h4414c95_1 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.1=hdf1c602_3 + - mpmath=1.0.0=py36_2 + - msgpack-python=0.5.6=py36h6bb024c_1 + - multipledispatch=0.6.0=py36_0 + - navigator-updater=0.2.1=py36_0 + - nccl=1.3.5=cuda9.0_0 + - ncurses=6.1=hf484d3e_0 + - networkx=2.2=py36_1 + - ninja=1.8.2=py36h6bb024c_1 + - nltk=3.3.0=py36_0 + - numba=0.40.0=py36h962f231_0 + - numexpr=2.6.8=py36hd89afb7_0 + - numpy=1.15.3=py36h1d66e8a_0 + - numpy-base=1.15.3=py36h81de0dd_0 + - olefile=0.46=py36_0 + - openpyxl=2.5.9=py36_0 + - openssl=1.0.2p=h14c3975_0 + - packaging=18.0=py36_0 + - pandas=0.23.4=py36h04863e7_0 + - pandoc=2.2.3.2=0 + - pandocfilters=1.4.2=py36_1 + - pango=1.42.4=h049681c_0 + - parso=0.3.1=py36_0 + - partd=0.3.9=py36_0 + - patchelf=0.9=he6710b0_3 + - path.py=11.5.0=py36_0 + - patsy=0.5.1=py36_0 + - pcre=8.42=h439df22_0 + - pep8=1.7.1=py36_0 + - pillow=5.3.0=py36h34e0f95_0 + - pip=18.1=py36_0 + - pixman=0.34.0=hceecf20_3 + - pkginfo=1.4.2=py36_1 + - pluggy=0.8.0=py36_0 + - prometheus_client=0.4.2=py36_0 + - prompt_toolkit=2.0.7=py36_0 + - protobuf=3.6.1=py36he6710b0_0 + - psutil=5.4.8=py36h7b6447c_0 + - py=1.7.0=py36_0 + - pyasn1=0.4.4=py36h28b3542_0 + - pyasn1-modules=0.2.2=py36_0 + - pycosat=0.6.3=py36h14c3975_0 + - pycparser=2.19=py36_0 + - pycrypto=2.6.1=py36h14c3975_9 + - pycurl=7.43.0.2=py36hb7f436b_0 + - pyflakes=2.0.0=py36_0 + - pyhamcrest=1.9.0=py36_2 + - pyodbc=4.0.24=py36he6710b0_0 + - pyparsing=2.2.2=py36_0 + - pyqt=5.9.2=py36h05f1152_2 + - pysocks=1.6.8=py36_0 + - pytables=3.4.4=py36ha205bf6_0 + - pytest=3.9.3=py36_0 + - pytest-arraydiff=0.2=py36h39e3cac_0 + - pytest-astropy=0.4.0=py36_0 + - pytest-doctestplus=0.1.3=py36_0 + - pytest-openfiles=0.3.0=py36_0 + - pytest-remotedata=0.3.1=py36_0 + - python=3.6.6=h6e4f718_2 + - python-dateutil=2.7.5=py36_0 + - pytorch=0.4.1=py36ha74772b_0 + - pytz=2018.7=py36_0 + - pywavelets=1.0.1=py36hdd07704_0 + - pyyaml=3.13=py36h14c3975_0 + - pyzmq=17.1.2=py36h14c3975_0 + - qt=5.9.6=h8703b6f_2 + - qtawesome=0.5.2=py36_0 + - qtpy=1.5.2=py36_0 + - readline=7.0=h7b6447c_5 + - redis=5.0.0=h7b6447c_0 + - redis-py=2.10.6=py36_0 + - rope=0.11.0=py36_0 + - ruamel_yaml=0.15.46=py36h14c3975_0 + - scikit-image=0.14.0=py36hf484d3e_1 + - scikit-learn=0.20.0=py36h4989274_1 + - scipy=1.1.0=py36hfa4b5c9_1 + - seaborn=0.9.0=py36_0 + - send2trash=1.5.0=py36_0 + - service_identity=17.0.0=py36h28b3542_0 + - simplegeneric=0.8.1=py36_2 + - sip=4.19.8=py36hf484d3e_0 + - smart_open=1.7.1=py36_0 + - snappy=1.1.7=hbae5bb6_3 + - sockjs-tornado=1.0.6=py36_0 + - sortedcontainers=2.0.5=py36_0 + - sphinxcontrib=1.0=py36_1 + - sphinxcontrib-websupport=1.1.0=py36_1 + - spyder-kernels=0.2.6=py36_0 + - sqlalchemy=1.2.12=py36h7b6447c_0 + - sqlite=3.25.2=h7b6447c_0 + - statsmodels=0.9.0=py36h035aef0_0 + - tensorboard=1.11.0=py36hf484d3e_0 + - tensorflow=1.11.0=gpu_py36h4459f94_0 + - tensorflow-base=1.11.0=gpu_py36h8e0ae2d_0 + - tensorflow-gpu=1.11.0=h0d30ee6_0 + - tk=8.6.8=hbc83047_0 + - toolz=0.9.0=py36_0 + - tornado=5.1.1=py36h7b6447c_0 + - tqdm=4.28.1=py36h28b3542_0 + - twisted=18.9.0=py36h7b6447c_0 + - typed-ast=1.1.0=py36h14c3975_0 + - unixodbc=2.3.7=h14c3975_0 + - webencodings=0.5.1=py36_1 + - werkzeug=0.14.1=py36_0 + - wrapt=1.10.11=py36h14c3975_2 + - xlrd=1.1.0=py36_1 + - xlsxwriter=1.1.2=py36_0 + - xz=5.2.4=h14c3975_4 + - yaml=0.1.7=had09818_2 + - zeromq=4.2.5=hf484d3e_1 + - zict=0.1.3=py36_0 + - zlib=1.2.11=ha838bed_2 + - zope=1.0=py36_1 + - zope.interface=4.6.0=py36h7b6447c_0 + - cuda91=1.0=h4c16780_0 + - pip: + - msgpack==0.5.6 + - tables==3.4.4 + - torch==0.4.1 + diff --git a/hpvm/projects/onnx/setup.py b/hpvm/projects/onnx/setup.py index 9da7193379454d1d5cdc1e63f6b436d3771e15a5..bf007ea33449f70ff88cf5dc186b6a61fd3f6a67 100644 --- a/hpvm/projects/onnx/setup.py +++ b/hpvm/projects/onnx/setup.py @@ -2,11 +2,11 @@ from setuptools import setup setup( - name='frontend', + name='onnx_frontend', version='1.0', - description='ApproxHPVM frontend modules', - author='Hashim', - author_email='hsharif3@illinois.edu', - packages=['frontend'], + description='HPVM onnx frontend modules', + author='Yuanjing Shi', + author_email='ys26@illinois.edu', + packages=['onnx_frontend'], install_requires=[], ) diff --git a/hpvm/projects/onnx/src/.ipynb_checkpoints/mnist-checkpoint.ipynb b/hpvm/projects/onnx/src/.ipynb_checkpoints/mnist-checkpoint.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..2fd64429bf421126b7000c94ce0f6fd186fbd01f --- /dev/null +++ b/hpvm/projects/onnx/src/.ipynb_checkpoints/mnist-checkpoint.ipynb @@ -0,0 +1,6 @@ +{ + "cells": [], + "metadata": {}, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/hpvm/projects/onnx/src/mnist.ipynb b/hpvm/projects/onnx/src/mnist.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..59cd82bf110c52ec0e141b6d0a6d0644bc1d66f5 --- /dev/null +++ b/hpvm/projects/onnx/src/mnist.ipynb @@ -0,0 +1,115 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import numpy as np\n", + "import onnx\n", + "import glob\n", + "from onnxruntime.backend.backend import OnnxRuntimeBackend as backend\n", + "\n", + "from onnx import numpy_helper" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "model = onnx.load('../models/mnist/mnist.onnx')\n", + "test_data_dir = '../models/mnist/test_data_set_0'" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "\nArrays are not almost equal to 7 decimals\n\nMismatch: 80%\nMax absolute difference: 0.00292969\nMax relative difference: 4.834256e-06\n x: array([[ 975.6701 , -618.72394 , 6574.5684 , 668.02893 ,\n -917.27094 , -1671.6359 , -1952.7599 , -61.549873,\n -777.17664 , -1439.5316 ]], dtype=float32)\n y: array([[ 975.67035 , -618.7242 , 6574.5654 , 668.0283 ,\n -917.27106 , -1671.6361 , -1952.7599 , -61.549576,\n -777.17645 , -1439.5316 ]], dtype=float32)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-19-2c0c3f208847>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Compare the results with reference outputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mref_o\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mref_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtesting\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mref_o\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py\u001b[0m in \u001b[0;36massert_almost_equal\u001b[0;34m(actual, desired, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 570\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mactual\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdesired\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 572\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0massert_array_almost_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mactual\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdesired\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0merr_msg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 573\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 574\u001b[0m \u001b[0;31m# If one of desired/actual is not finite, handle it specially here:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py\u001b[0m in \u001b[0;36massert_array_almost_equal\u001b[0;34m(x, y, decimal, err_msg, verbose)\u001b[0m\n\u001b[1;32m 1005\u001b[0m assert_array_compare(compare, x, y, err_msg=err_msg, verbose=verbose,\n\u001b[1;32m 1006\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Arrays are not almost equal to %d decimals'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mdecimal\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1007\u001b[0;31m precision=decimal)\n\u001b[0m\u001b[1;32m 1008\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1009\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py\u001b[0m in \u001b[0;36massert_array_compare\u001b[0;34m(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf)\u001b[0m\n\u001b[1;32m 817\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mheader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mheader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 818\u001b[0m names=('x', 'y'), precision=precision)\n\u001b[0;32m--> 819\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mAssertionError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 820\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtraceback\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: \nArrays are not almost equal to 7 decimals\n\nMismatch: 80%\nMax absolute difference: 0.00292969\nMax relative difference: 4.834256e-06\n x: array([[ 975.6701 , -618.72394 , 6574.5684 , 668.02893 ,\n -917.27094 , -1671.6359 , -1952.7599 , -61.549873,\n -777.17664 , -1439.5316 ]], dtype=float32)\n y: array([[ 975.67035 , -618.7242 , 6574.5654 , 668.0283 ,\n -917.27106 , -1671.6361 , -1952.7599 , -61.549576,\n -777.17645 , -1439.5316 ]], dtype=float32)" + ] + } + ], + "source": [ + "# Load inputs\n", + "inputs = []\n", + "inputs_num = len(glob.glob(os.path.join(test_data_dir, 'input_*.pb')))\n", + "print(inputs_num)\n", + "for i in range(inputs_num):\n", + " input_file = os.path.join(test_data_dir, 'input_{}.pb'.format(i))\n", + " tensor = onnx.TensorProto()\n", + " with open(input_file, 'rb') as f:\n", + " tensor.ParseFromString(f.read())\n", + " inputs.append(numpy_helper.to_array(tensor))\n", + "\n", + "# Load reference outputs\n", + "ref_outputs = []\n", + "ref_outputs_num = len(glob.glob(os.path.join(test_data_dir, 'output_*.pb')))\n", + "for i in range(ref_outputs_num):\n", + " output_file = os.path.join(test_data_dir, 'output_{}.pb'.format(i))\n", + " tensor = onnx.TensorProto()\n", + " with open(output_file, 'rb') as f:\n", + " tensor.ParseFromString(f.read())\n", + " ref_outputs.append(numpy_helper.to_array(tensor))\n", + "\n", + "# Run the model on the backend\n", + "outputs = list(backend.run_model(model, inputs))\n", + "\n", + "# Compare the results with reference outputs.\n", + "for ref_o, o in zip(ref_outputs, outputs):\n", + " np.testing.assert_almost_equal(ref_o, o)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}