Skip to content
Snippets Groups Projects
Unverified Commit 62862a08 authored by Lev Zlotnik's avatar Lev Zlotnik Committed by GitHub
Browse files

PyTorch 1.0.0 support + Proper Packaging (Release 0.3) (#144)

Not backward compatible - re-installation is required

* Fixes for PyTorch==1.0.0
* Refactoring folder structure
* Update installation section in docs
parent ee1d160f
No related branches found
No related tags found
No related merge requests found
......@@ -16,15 +16,12 @@
import torch
import os
import sys
import torch.nn as nn
from copy import deepcopy
from collections import OrderedDict
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
from distiller.quantization import Quantizer
from distiller.quantization.quantizer import QBits, _ParamToQuant
from distiller.quantization.quantizer import FP_BKP_PREFIX
......
......@@ -16,16 +16,14 @@
import logging
import torch
import os
import sys
import pytest
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
sys.path.append(module_path)
import distiller
from models import ALL_MODEL_NAMES, create_model
from apputils import *
from distiller import normalize_module_name, denormalize_module_name
from distiller.models import ALL_MODEL_NAMES, create_model
from distiller.apputils import *
from distiller import normalize_module_name, denormalize_module_name, \
SummaryGraph, onnx_name_2_pytorch_name
from distiller.model_summaries import connectivity_summary, connectivity_summary_verbose
# Logging configuration
logging.basicConfig(level=logging.DEBUG)
......@@ -61,7 +59,7 @@ def test_connectivity():
assert g is not None
op_names = [op['name'] for op in g.ops.values()]
assert 73 == len(op_names)
assert 80 == len(op_names)
edges = g.edges
assert edges[0].src == '0' and edges[0].dst == 'conv1'
......@@ -69,16 +67,11 @@ def test_connectivity():
# Test two sequential calls to predecessors (this was a bug once)
preds = g.predecessors(g.find_op('bn1'), 1)
preds = g.predecessors(g.find_op('bn1'), 1)
assert preds == ['108', '2', '3', '4', '5']
assert preds == ['129', '2', '3', '4', '5']
# Test successors
succs = g.successors(g.find_op('bn1'), 2)
assert succs == ['relu']
op = g.find_op('layer1.0')
assert op is not None
preds = g.predecessors(op, 2)
assert preds == ['layer1.0.bn2', 'relu']
op = g.find_op('layer1.0.relu2')
assert op is not None
succs = g.successors(op, 4)
......@@ -180,10 +173,10 @@ def test_connectivity_summary():
assert g is not None
summary = connectivity_summary(g)
assert len(summary) == 73
assert len(summary) == 80
verbose_summary = connectivity_summary_verbose(g)
assert len(verbose_summary) == 73
assert len(verbose_summary) == 80
if __name__ == '__main__':
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment