from torch.nn import Linear, ReLU, Sequential, Tanh from ._container import Classifier, make_conv_pool_activ class AlexNet(Classifier): def __init__(self): convs = Sequential( *make_conv_pool_activ(3, 64, 11, Tanh, pool_size=2, padding=5), *make_conv_pool_activ(64, 192, 5, Tanh, pool_size=2, padding=2), *make_conv_pool_activ(192, 384, 3, Tanh, padding=1), *make_conv_pool_activ(384, 256, 3, Tanh, padding=1), *make_conv_pool_activ(256, 256, 3, Tanh, pool_size=2, padding=1) ) linears = Sequential(Linear(4096, 10)) super().__init__(convs, linears) class AlexNet2(Classifier): def __init__(self): convs = Sequential( *make_conv_pool_activ(3, 32, 3, Tanh, padding=1), *make_conv_pool_activ(32, 32, 3, Tanh, pool_size=2, padding=1), *make_conv_pool_activ(32, 64, 3, Tanh, padding=1), *make_conv_pool_activ(64, 64, 3, Tanh, pool_size=2, padding=1), *make_conv_pool_activ(64, 128, 3, Tanh, padding=1), *make_conv_pool_activ(128, 128, 3, Tanh, pool_size=2, padding=1) ) linears = Sequential(Linear(2048, 10)) super().__init__(convs, linears) class AlexNetImageNet(Classifier): def __init__(self): convs = Sequential( *make_conv_pool_activ( 3, 64, 11, ReLU, padding=2, stride=4, pool_size=3, pool_stride=2 ), *make_conv_pool_activ( 64, 192, 5, ReLU, padding=2, pool_size=3, pool_stride=2 ), *make_conv_pool_activ(192, 384, 3, ReLU, padding=1), *make_conv_pool_activ(384, 256, 3, ReLU, padding=1), *make_conv_pool_activ( 256, 256, 3, ReLU, padding=1, pool_size=3, pool_stride=2 ) ) linears = Sequential( Linear(9216, 4096), ReLU(), Linear(4096, 4096), ReLU(), Linear(4096, 1000), ) super().__init__(convs, linears)