from torch.nn import Linear, Sequential, Tanh

from ._container import Classifier, make_conv_pool_activ


class LeNet(Classifier):
    """LeNet-style classifier with Tanh activations throughout.

    The network is assembled from two parts that are handed to the
    ``Classifier`` base class:

    * a convolutional part made of two ``make_conv_pool_activ`` stages
      (1 -> 32 and 32 -> 64, both called with ``5, Tanh, 2, padding=2``);
    * a fully connected part mapping ``7 * 7 * 64`` features to 1024 and
      then to 10 outputs, with a ``Tanh`` after each linear layer
      (including the last one).
    """

    def __init__(self):
        # NOTE(review): positional arguments presumably mean
        # (in_channels, out_channels, kernel_size, activation, pool_size);
        # confirm against make_conv_pool_activ in ._container.
        first_stage = make_conv_pool_activ(1, 32, 5, Tanh, 2, padding=2)
        second_stage = make_conv_pool_activ(32, 64, 5, Tanh, 2, padding=2)
        feature_extractor = Sequential(*first_stage, *second_stage)

        # The first Linear expects 7 * 7 * 64 input features.
        # NOTE(review): this looks sized for 28x28 single-channel inputs
        # halved twice by pooling; verify against the intended dataset.
        head = Sequential(
            Linear(7 * 7 * 64, 1024),
            Tanh(),
            Linear(1024, 10),
            Tanh(),
        )

        super().__init__(feature_extractor, head)