# Source code for deeprob.flows.layers.densenet

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import List

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

from deeprob.torch.utils import WeightNormConv2d


class DenseLayer(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, use_checkpoint: bool = False):
        """
        Build a single DenseNet layer: a 1x1 bottleneck followed by a 3x3 convolution.

        :param in_channels: The number of input channels, i.e. the total number of channels of the
                            feature maps concatenated by the bottleneck.
        :param out_channels: The number of output channels.
        :param use_checkpoint: Whether to use a checkpoint in order to reduce memory usage
                               (by increasing training time caused by re-computations).
        """
        super().__init__()
        self.use_checkpoint = use_checkpoint

        # The bottleneck projects the concatenated inputs onto 4 * out_channels feature maps
        bottleneck_channels = 4 * out_channels
        self.bottleneck_network = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            WeightNormConv2d(in_channels, bottleneck_channels, kernel_size=1, padding=0, bias=False),
        )

        # Main 3x3 convolution mapping the bottleneck features to out_channels
        self.network = nn.Sequential(
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            WeightNormConv2d(bottleneck_channels, out_channels, kernel_size=3, padding=1, bias=False),
        )

    def bottleneck(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Concatenate the given feature maps along dim 1 and apply the bottleneck to the result.

        :param inputs: A list of previous feature maps.
        :return: The outputs of the bottleneck.
        """
        concatenated = torch.cat(inputs, dim=1)
        return self.bottleneck_network(concatenated)

    def checkpoint_bottleneck(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Apply the bottleneck through torch.utils.checkpoint, so that its intermediate activations
        are recomputed during the backward pass instead of being stored.

        :param inputs: A list of previous feature maps.
        :return: The outputs of the bottleneck.
        """
        # checkpoint() takes the tensors as positional arguments: unpack the list here and let
        # the wrapper hand them back to bottleneck() as a tuple.
        # NOTE(review): recent PyTorch releases warn when `use_reentrant` is not passed
        # explicitly; confirm the intended variant for the supported torch versions.
        def run_bottleneck(*tensors):
            return self.bottleneck(tensors)

        return checkpoint(run_bottleneck, *inputs)

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Evaluate the dense layer on the list of all the feature maps produced so far.

        :param inputs: A list of previous feature maps.
        :return: The outputs of the layer.
        """
        # Checkpointing is skipped when no input requires grad (e.g. at inference time),
        # since then there is no backward pass to recompute activations for.
        needs_checkpoint = self.use_checkpoint and any(t.requires_grad for t in inputs)
        if needs_checkpoint:
            features = self.checkpoint_bottleneck(inputs)
        else:
            features = self.bottleneck(inputs)
        return self.network(features)
class DenseBlock(nn.Module):
    def __init__(self, n_layers: int, in_channels: int, out_channels: int, use_checkpoint: bool = False):
        """
        Initialize a dense block as in DenseNet.

        The k-th dense layer (0-based) receives the block input concatenated with the outputs of
        the k previous layers, hence it is built with in_channels + k * out_channels input channels.

        :param n_layers: The number of dense layers.
        :param in_channels: The number of input channels.
        :param out_channels: The number of output channels of each dense layer.
        :param use_checkpoint: Whether to use a checkpoint in order to reduce memory usage
                               (by increasing training time caused by re-computations).
        """
        super().__init__()
        self.layers = nn.ModuleList([
            DenseLayer(in_channels + k * out_channels, out_channels, use_checkpoint=use_checkpoint)
            for k in range(n_layers)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the dense block.

        :param x: The inputs.
        :return: The block input concatenated along dim 1 with the outputs of every dense layer.
        """
        features = [x]
        for dense_layer in self.layers:
            # Dense connectivity: each layer sees every feature map produced so far
            features.append(dense_layer(features))
        return torch.cat(features, dim=1)
class Transition(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, bias: bool = True):
        """
        Initialize a transition layer as in DenseNet, i.e. BatchNorm -> ReLU -> 1x1 convolution.

        :param in_channels: The number of input channels.
        :param out_channels: The number of output channels.
        :param bias: Whether to use bias in the last convolutional layer.
        """
        super().__init__()
        # Use the `nn` alias like the rest of the module (same class as torch.nn.Sequential)
        self.network = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            WeightNormConv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=bias),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the layer.

        :param x: The inputs.
        :return: The outputs of the layer.
        """
        return self.network(x)
class DenseNetwork(nn.Module):
    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        n_blocks: int,
        use_checkpoint: bool = False,
        *,
        n_layers: int = 4,
    ):
        """
        Initialize a dense network (DenseNet): an input convolution followed by n_blocks dense
        blocks, each one followed by a transition layer.

        :param in_channels: The number of input channels.
        :param mid_channels: The number of mid channels.
        :param out_channels: The number of output channels.
        :param n_blocks: The number of dense blocks. It must be at least 1.
        :param use_checkpoint: Whether to use a checkpoint in order to reduce memory usage
                               (by increasing training time caused by re-computations).
        :param n_layers: The number of dense layers in each dense block (defaults to 4).
        :raises ValueError: If n_blocks is less than 1 or n_layers is negative.
        """
        super().__init__()
        # Without any block the network would output mid_channels instead of out_channels
        if n_blocks < 1:
            raise ValueError("The number of dense blocks must be at least 1")
        if n_layers < 0:
            raise ValueError("The number of dense layers per block must be non-negative")

        self.blocks = nn.ModuleList()

        # Build the input convolutional layer
        self.in_conv = WeightNormConv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)

        # A dense block with n_layers layers, each producing mid_channels feature maps, outputs its
        # input (mid_channels) concatenated with all the layer outputs
        block_channels = (n_layers + 1) * mid_channels
        for i in range(n_blocks):
            self.blocks.append(DenseBlock(n_layers, mid_channels, mid_channels, use_checkpoint=use_checkpoint))
            # Only the last transition maps to out_channels, and only it has a bias term
            is_last = i == n_blocks - 1
            self.blocks.append(Transition(
                block_channels, out_channels if is_last else mid_channels, bias=is_last
            ))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the dense network.

        :param x: The inputs.
        :return: The outputs.
        """
        # Pass through the input convolutional layer
        x = self.in_conv(x)
        # Pass through the dense blocks and transition layers, in order
        for block in self.blocks:
            x = block(x)
        return x