Source code for deeprob.flows.layers.resnet

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

import torch
from torch import nn

from deeprob.torch.utils import WeightNormConv2d


class ResidualBlock(nn.Module):
    """
    Residual block made of two (batch norm, ReLU, 3x3 convolution) stages
    whose output is added back to the block input.
    """

    def __init__(self, n_channels: int):
        """
        Create the layers of a basic ResNet-style residual block.

        :param n_channels: The number of channels, used for both the input and
                           the output of every layer in the block.
        """
        super().__init__()

        # The block repeats the same three-layer stage twice, so assemble the
        # stages in a loop and hand the flat list to nn.Sequential.
        layers = []
        for _ in range(2):
            layers.extend([
                nn.BatchNorm2d(n_channels),
                nn.ReLU(inplace=True),
                WeightNormConv2d(n_channels, n_channels, kernel_size=3, padding=1, bias=False),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the residual block to the inputs.

        :param x: The inputs.
        :return: The inputs summed with the output of the inner layers.
        """
        residual = self.block(x)
        return x + residual
class ResidualNetwork(nn.Module):
    """
    Residual network whose output is computed from the sum of 1x1 "skip"
    projections taken after the input convolution and after every residual block.
    """

    def __init__(self, in_channels: int, mid_channels: int, out_channels: int, n_blocks: int):
        """
        Initialize a residual network (ResNet) with skip connections.

        :param in_channels: The number of input channels.
        :param mid_channels: The number of mid channels.
        :param out_channels: The number of output channels.
        :param n_blocks: The number of residual blocks.
        :raises ValueError: If the number of residual blocks is not positive.
        """
        if n_blocks <= 0:
            raise ValueError("The number of residual blocks must be positive")
        super().__init__()
        self.blocks = nn.ModuleList()
        self.skips = nn.ModuleList()

        # Build the input convolutional layer (in_channels -> mid_channels)
        # and the skip layer applied to its output
        self.in_conv = WeightNormConv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)
        self.in_skip = WeightNormConv2d(mid_channels, mid_channels, kernel_size=1, padding=0, bias=True)

        # Build the lists of residual blocks and skip connections,
        # one 1x1 skip layer for each residual block
        for _ in range(n_blocks):
            self.blocks.append(ResidualBlock(mid_channels))
            self.skips.append(WeightNormConv2d(mid_channels, mid_channels, kernel_size=1, padding=0, bias=True))

        # Build the output network (mid_channels -> out_channels)
        self.out_network = nn.Sequential(
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            WeightNormConv2d(mid_channels, out_channels, kernel_size=1, padding=0, bias=True)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the residual network.

        :param x: The inputs.
        :return: The outputs.
        """
        # Pass through the input layers; z accumulates the skip projections
        x = self.in_conv(x)
        z = self.in_skip(x)

        # Pass through the residual blocks, adding the skip projection of
        # each block output to the accumulator
        for block, skip in zip(self.blocks, self.skips):
            x = block(x)
            z += skip(x)

        # Pass the accumulated skip projections through the output network
        x = self.out_network(z)
        return x