Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Callable, Optional
import torch
from torch.nn import Conv2d, MaxPool2d, Module, Sequential, Softmax
# Optional activation factory: a zero-argument callable that builds a fresh
# Module (for example an activation class such as torch.nn.ReLU), or None
# when no activation layer is wanted.
ActivT = Optional[Callable[[], Module]]


def make_conv_pool_activ(
    in_channels: int,
    out_channels: int,
    kernel_size: int,
    activation: ActivT = None,
    pool_size: Optional[int] = None,
    pool_stride: Optional[int] = None,
    **conv_kwargs
):
    """Build the layers of one conv -> optional max-pool -> optional activation stage.

    Args:
        in_channels: Input channel count passed to ``Conv2d``.
        out_channels: Output channel count passed to ``Conv2d``.
        kernel_size: Convolution kernel size passed to ``Conv2d``.
        activation: Zero-argument factory called once to create the activation
            layer. It is skipped when falsy (e.g. ``None``).
        pool_size: Kernel size of a ``MaxPool2d`` appended after the conv.
            ``None`` means no pooling layer.
        pool_stride: Stride for the ``MaxPool2d``. ``None`` uses PyTorch's
            default (stride equal to the pool size). It is ignored when
            ``pool_size`` is ``None``.
        **conv_kwargs: Extra keyword arguments forwarded verbatim to ``Conv2d``
            (e.g. ``padding``, ``stride``, ``bias``).

    Returns:
        A plain list of modules in order ``[Conv2d, MaxPool2d?, activation?]``,
        ready to be splatted into a ``Sequential``.
    """
    # Modules are created in the same order they appear in the result.
    # This matters for reproducible random weight initialisation.
    conv_layer = Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)
    pooling = (
        [] if pool_size is None else [MaxPool2d(pool_size, stride=pool_stride)]
    )
    activations = [activation()] if activation else []
    return [conv_layer, *pooling, *activations]
class Classifier(Module):
    """Conv feature extractor followed by a flattened linear head.

    The forward pass runs ``convs`` on the input, flattens every dimension
    after the first (per-sample features), runs ``linears``, and optionally
    applies a softmax over dimension 1.

    Args:
        convs: Module stack applied to the raw inputs.
        linears: Module stack applied to the flattened ``convs`` output. Its
            first layer must accept the flattened feature size.
        use_softmax: If True, apply ``Softmax(1)`` to the head's output.
            Otherwise the head's output is returned unchanged.
    """

    def __init__(
        self, convs: Sequential, linears: Sequential, use_softmax: bool = False
    ):
        super().__init__()
        self.convs = convs
        self.linears = linears
        # An empty Sequential returns its input unchanged, so it acts as an
        # identity when softmax is disabled.
        self.softmax = Softmax(1) if use_softmax else Sequential()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Run convs -> flatten(start_dim=1) -> linears -> optional softmax.

        Dimension 0 of the ``convs`` output is kept and all later dimensions
        are merged. The ``convs`` output is assumed to be at least 2-D
        (presumably ``(batch, channels, H, W)`` — TODO confirm against callers).
        """
        outputs = self.convs(inputs)
        # Use flatten rather than ``view(N, -1)``. ``view`` raises on
        # non-contiguous tensors (e.g. channels_last or permuted outputs), and
        # it cannot infer ``-1`` for an empty batch. ``flatten`` copies only
        # when it must; otherwise it still returns a view.
        features = torch.flatten(outputs, start_dim=1)
        return self.softmax(self.linears(features))