Newer
Older
from torch.nn import Linear, Sequential, Tanh
from ._container import Classifier, make_conv_pool_activ
class LeNet(Classifier):
    """LeNet-style classifier with two conv stages and a two-layer dense head.

    The network is handed to ``Classifier`` as two ``Sequential`` modules:
    a convolutional feature extractor followed by a fully connected head
    that ends in 10 ``Tanh``-activated outputs.
    """

    def __init__(self):
        """Build both sub-networks and pass them to ``Classifier``."""
        # Two stages from make_conv_pool_activ: 1 -> 32 and 32 -> 64.
        # NOTE(review): positional args are presumably (in_channels,
        # out_channels, kernel_size, activation, pool_size) -- confirm
        # against the helper in ._container.
        feature_layers = [
            *make_conv_pool_activ(1, 32, 5, Tanh, 2, padding=2),
            *make_conv_pool_activ(32, 64, 5, Tanh, 2, padding=2),
        ]

        # Dense head: 7 * 7 * 64 inputs -> 1024 hidden -> 10 outputs.
        # NOTE(review): 7 * 7 * 64 presumably matches the flattened conv
        # output for 28x28 single-channel input -- confirm with the caller.
        head_layers = [
            Linear(7 * 7 * 64, 1024),
            Tanh(),
            Linear(1024, 10),
            Tanh(),
        ]

        super().__init__(Sequential(*feature_layers), Sequential(*head_layers))