Skip to content
Snippets Groups Projects
_container.py 1.06 KiB
Newer Older
  • Learn to ignore specific revisions
  • Yifan Zhao's avatar
    Yifan Zhao committed
    from typing import Callable, Optional
    
    import torch
    from torch.nn import Conv2d, MaxPool2d, Module, Sequential, Softmax
    
    # Type of the `activation` argument of make_conv_pool_activ: either None
    # (no activation layer) or a zero-argument callable returning a Module,
    # e.g. an nn.Module subclass. make_conv_pool_activ calls it with no
    # arguments to build a fresh activation layer each time.
    ActivT = Optional[Callable[[], Module]]
    
    
    def make_conv_pool_activ(
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        activation: ActivT = None,
        pool_size: Optional[int] = None,
        pool_stride: Optional[int] = None,
        **conv_kwargs
    ):
        """Build the layers of one conv -> [max-pool] -> [activation] stage.

        Args:
            in_channels: Input channels of the ``Conv2d``.
            out_channels: Output channels of the ``Conv2d``.
            kernel_size: Kernel size of the ``Conv2d``.
            activation: Optional zero-argument factory; when given (truthy) it
                is called once and the resulting module is appended last.
            pool_size: When not None, a ``MaxPool2d(pool_size)`` is appended
                right after the convolution.
            pool_stride: Stride forwarded to ``MaxPool2d``; None lets PyTorch
                use its default (the pool's kernel size).
            **conv_kwargs: Extra keyword arguments forwarded to ``Conv2d``
                (e.g. ``padding``, ``stride``).

        Returns:
            A plain ``list`` of modules in application order, suitable for
            unpacking into a ``Sequential``.
        """
        # Construct in the same order the layers run: conv, then pool, then
        # activation (keeps any RNG-dependent weight init order stable).
        conv = Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)
        pool = (
            [] if pool_size is None else [MaxPool2d(pool_size, stride=pool_stride)]
        )
        activ = [activation()] if activation else []
        return [conv, *pool, *activ]
    
    
    class Classifier(Module):
        """Feature-extractor stage followed by a fully-connected head.

        ``forward`` runs ``convs`` on the input, flattens everything after the
        first (batch) dimension, runs ``linears`` on the result, and finally
        applies ``self.softmax``.

        Args:
            convs: Module(s) applied to the raw input.
            linears: Module(s) applied to the flattened ``convs`` output.
            use_softmax: If True, the final stage is ``Softmax(dim=1)``;
                otherwise it is an empty ``Sequential``, which returns its
                input unchanged.
        """

        def __init__(
            self, convs: Sequential, linears: Sequential, use_softmax: bool = False
        ):
            super().__init__()
            self.convs = convs
            self.linears = linears
            # An empty Sequential passes its input through unchanged, so
            # forward() can always call self.softmax without branching.
            self.softmax = Softmax(1) if use_softmax else Sequential()

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            """Return ``softmax(linears(flatten(convs(inputs), 1)))``.

            The ``convs`` output is flattened from dimension 1 onward, keeping
            dimension 0 as the batch. Unlike ``Tensor.view``, ``torch.flatten``
            also handles a non-contiguous ``convs`` output (e.g. after a
            transpose/permute or with channels_last memory format) and an
            empty batch.
            """
            outputs = self.convs(inputs)
            # torch.flatten returns a view when the layout allows it and copies
            # otherwise; the previous .view(N, -1) raised on non-contiguous
            # tensors and on N == 0 (ambiguous -1).
            flat = torch.flatten(outputs, 1)
            return self.softmax(self.linears(flat))