Source code for deeprob.spn.layers.ratspn

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

import abc
from typing import Optional, Tuple, List

import numpy as np
import torch
from torch import nn
from torch import distributions

from deeprob.torch.initializers import dirichlet_


[docs]class RegionGraphLayer(abc.ABC, nn.Module): def __init__( self, in_features: int, out_channels: int, regions: List[tuple], rg_depth: int, dropout: Optional[float] = None, **kwargs ): """ Initialize a Region Graph-based base distribution. :param in_features: The number of input features. :param out_channels: The number of channels for each base distribution layer. :param regions: The regions of the distributions. :param rg_depth: The depth of the region graph. :param dropout: The leaf nodes dropout rate. It can be None. """ super().__init__() self.in_features = in_features self.in_regions = len(regions) self.out_channels = out_channels self.rg_depth = rg_depth self.dropout = dropout self.distribution = None # Compute the padding and the number of features for each base distribution batch self.pad = -self.in_features % (2 ** self.rg_depth) in_features_pad = self.in_features + self.pad self.dimension = in_features_pad // (2 ** self.rg_depth) # Append dummy variables to regions orderings and update the pad mask mask = regions.copy() if self.pad > 0: pad_mask = np.zeros(shape=(len(regions), 1, self.dimension), dtype=np.bool_) for i, region in enumerate(regions): n_dummy = self.dimension - len(region) if n_dummy > 0: pad_mask[i, :, -n_dummy:] = True mask[i] = mask[i] + (mask[i][-1],) * n_dummy self.register_buffer('pad_mask', torch.tensor(pad_mask)) self.register_buffer('mask', torch.tensor(mask)) # Build the flatten inverse mask inv_mask = torch.argsort(torch.reshape(self.mask, [-1, in_features_pad]), dim=1) self.register_buffer('inv_mask', inv_mask) # Build the flatten inverted pad mask if self.pad > 0: inv_pad_mask = torch.reshape(self.pad_mask, [-1, in_features_pad]) inv_pad_mask = torch.gather(inv_pad_mask, dim=1, index=self.inv_mask) self.register_buffer('inv_pad_mask', inv_pad_mask)
[docs] def unpad_samples(self, x: torch.Tensor, idx_group: torch.Tensor) -> torch.Tensor: """ Reorder and unpad some samples. :param x: The samples. :param idx_group: The group indices. :return: The reordered samples with padding dummy variables removed. """ n_samples = idx_group.shape[0] # Reorder the samples idx_repetitions = torch.div(idx_group[:, 0], 2 ** self.rg_depth, rounding_mode='floor') samples = torch.gather(x, dim=1, index=self.inv_mask[idx_repetitions]) # Remove the padding, if required if self.pad > 0: samples = samples[self.inv_pad_mask[idx_repetitions]].view(n_samples, self.in_features) return samples
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Execute the layer on some inputs. :param x: The inputs. :return: The log-likelihoods of each distribution leaf. """ # Gather the inputs and compute the log-likelihoods x = torch.unsqueeze(x[:, self.mask], dim=2) x = self.distribution.log_prob(x) # Apply the input dropout, if specified if self.training and self.dropout is not None: x[torch.lt(torch.rand_like(x), self.dropout)] = np.nan # Marginalize missing values (denoted with NaNs) torch.nan_to_num_(x) # Pad to zeros if self.pad > 0: x.masked_fill_(self.pad_mask, 0.0) return torch.sum(x, dim=-1)
[docs] @abc.abstractmethod def distribution_mode(self) -> torch.Tensor: """ Get the mode of the distribution. :return: The mode of the distribution. """
[docs] @torch.no_grad() def mpe(self, x: torch.Tensor, idx_group: torch.Tensor, idx_offset: torch.Tensor) -> torch.Tensor: """ Evaluate the layer given some inputs for maximum at posteriori estimation. :param x: The inputs. Random variables can be marginalized using NaN values. :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The samples having maximum at posteriori estimates on marginalized random variables. """ # Get the maximum at posteriori estimation of the base distribution # and filter the base samples by the region and offset indices mode = self.distribution_mode() samples = torch.flatten(mode[idx_group, idx_offset], start_dim=1) # Reorder the variables and remove the padding (if required) samples = self.unpad_samples(samples, idx_group) # Assign the maximum at posteriori estimation to NaN random variables return torch.where(torch.isnan(x), samples, x)
[docs] @torch.no_grad() def sample(self, idx_group: torch.Tensor, idx_offset: torch.Tensor) -> torch.Tensor: """ Sample from a base distribution. :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The computed samples. """ n_samples = idx_group.shape[0] # Sample from the base distribution and filter the samples samples = self.distribution.sample([n_samples]) samples = samples[torch.unsqueeze(torch.arange(n_samples), dim=1), idx_group, idx_offset] samples = torch.flatten(samples, start_dim=1) # Reorder the variables and remove the padding (if required) samples = self.unpad_samples(samples, idx_group) return samples
[docs]class GaussianLayer(RegionGraphLayer): def __init__( self, in_features: int, out_channels: int, regions: List[tuple], rg_depth: int, dropout: Optional[float] = None, uniform_loc: Optional[Tuple[float, float]] = None, optimize_scale: bool = False ): """ Initialize a Gaussian distributions input layer. :param in_features: The number of input features. :param out_channels: The number of channels for each base distribution layer. :param regions: The regions of the distributions. :param rg_depth: The depth of the region graph. :param dropout: The leaf nodes dropout rate. It can be None. :param uniform_loc: The optional uniform distribution parameters for location initialization. :param optimize_scale: Whether to optimize scale and location jointly. """ super().__init__(in_features, out_channels, regions, rg_depth, dropout) # Instantiate the location variable if uniform_loc is None: self.loc = nn.Parameter( torch.randn(self.in_regions, self.out_channels, self.dimension), requires_grad=True ) else: a, b = uniform_loc self.loc = nn.Parameter( a + (b - a) * torch.rand(self.in_regions, self.out_channels, self.dimension), requires_grad=True ) # Instantiate the scale variable if optimize_scale: self.scale = nn.Parameter( 0.5 + 0.1 * torch.tanh(torch.randn(self.in_regions, self.out_channels, self.dimension)), requires_grad=True ) else: self.scale = nn.Parameter( torch.ones(self.in_regions, self.out_channels, self.dimension), requires_grad=False ) # Instantiate the multi-batch normal distribution self.distribution = distributions.Normal(self.loc, self.scale, validate_args=False)
[docs] def distribution_mode(self) -> torch.Tensor: return self.distribution.mean
[docs]class BernoulliLayer(RegionGraphLayer): def __init__( self, in_features: int, out_channels: int, regions: List[tuple], rg_depth: int, dropout: Optional[float] = None ): """ Initialize a Bernoulli distributions input layer. :param in_features: The number of input features. :param out_channels: The number of channels for each base distribution layer. :param regions: The regions of the distributions. :param rg_depth: The depth of the region graph. :param dropout: The leaf nodes dropout rate. It can be None. """ super().__init__(in_features, out_channels, regions, rg_depth, dropout) # Instantiate the logit distribution parameters self.logits = nn.Parameter( torch.randn(self.in_regions, self.out_channels, self.dimension), requires_grad=True ) # Instantiate the multi-batch Bernoulli distribution self.distribution = distributions.Bernoulli(logits=self.logits, validate_args=False)
[docs] def distribution_mode(self) -> torch.Tensor: probs = self.distribution.mean return (probs >= 0.5).float()
[docs]class ProductLayer(nn.Module): def __init__( self, in_regions: int, in_nodes: int ): """ Initialize the Product layer. :param in_regions: The number of input regions. :param in_nodes: The number of input nodes per region. """ super().__init__() self.in_regions = in_regions self.in_nodes = in_nodes self.out_partitions = in_regions // 2 self.out_nodes = in_nodes ** 2 # Initialize the mask used to compute the outer product mask = [True, False] * self.out_partitions self.register_buffer('mask', torch.tensor(mask))
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Evaluate the layer given some inputs. :param x: The inputs. :return: The outputs. """ # Compute the outer product (the "outer sum" in log domain) x1 = x[:, self.mask] # (-1, out_partitions, in_nodes) x2 = x[:, ~self.mask] # (-1, out_partitions, in_nodes) x1 = torch.unsqueeze(x1, dim=3) # (-1, out_partitions, in_nodes, 1) x2 = torch.unsqueeze(x2, dim=2) # (-1, out_partitions, 1, in_nodes) x = x1 + x2 # (-1, out_partitions, in_nodes, in_nodes) x = x.view(-1, self.out_partitions, self.out_nodes) # (-1, out_partitions, out_nodes) return x
[docs] @torch.no_grad() def mpe( self, x: torch.Tensor, idx_group: torch.Tensor, idx_offset: torch.Tensor ) -> [torch.Tensor, torch.Tensor]: """ Evaluate the layer given some inputs for maximum at posteriori estimation. :param x: The inputs (not used here). :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The group and offset indices. """ return self.sample(idx_group, idx_offset)
[docs] @torch.no_grad() def sample( self, idx_group: torch.Tensor, idx_offset: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Sample from a product layer. :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The group and offset indices. """ idx_offset0 = torch.div(idx_offset, self.in_nodes, rounding_mode='floor') idx_offset1 = torch.remainder(idx_offset, self.in_nodes) # Compute the corresponding group and offset indices idx_group = torch.flatten( torch.stack([idx_group * 2, idx_group * 2 + 1], dim=2), start_dim=1 ) idx_offset = torch.flatten( torch.stack([idx_offset0, idx_offset1], dim=2), start_dim=1 ) return idx_group, idx_offset
[docs]class SumLayer(nn.Module): def __init__( self, in_partitions: int, in_nodes: int, out_nodes: int, dropout: Optional[float] = None ): """ Initialize the sum layer. :param in_partitions: The number of input partitions. :param in_nodes: The number of input nodes per partition. :param out_nodes: The number of output nodes per region. :param dropout: The input nodes dropout rate. It can be None. """ super().__init__() self.in_partitions = in_partitions self.in_nodes = in_nodes self.out_regions = in_partitions self.out_nodes = out_nodes self.dropout = dropout # Instantiate the weights self.weight = nn.Parameter( torch.empty(self.out_regions, self.out_nodes, self.in_nodes), requires_grad=True ) dirichlet_(self.weight, alpha=1.0)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Evaluate the layer given some inputs. :param x: The inputs. :return: The outputs. """ # Apply the dropout, if specified if self.training and self.dropout is not None: x[torch.lt(torch.rand_like(x), self.dropout)] = -np.inf # Compute the log-likelihood using the "logsumexp" trick w = torch.log_softmax(self.weight, dim=2) # (out_regions, out_nodes, in_nodes) x = torch.unsqueeze(x, dim=2) # (-1, in_partitions, 1, in_nodes) with in_partitions = out_regions x = torch.logsumexp(x + w, dim=3) # (-1, out_regions, out_nodes) return x
[docs] @torch.no_grad() def mpe( self, x: torch.Tensor, idx_group: torch.Tensor, idx_offset: torch.Tensor ) -> [torch.Tensor, torch.Tensor]: """ Evaluate the layer given some inputs for maximum at posteriori estimation. :param x: The inputs. :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The group and offset indices. """ # Compute the offset indices evaluating the sum nodes as an argmax x = x[torch.unsqueeze(torch.arange(x.shape[0]), dim=1), idx_group] w = torch.log_softmax(self.weight[idx_group, idx_offset], dim=2) idx_offset = torch.argmax(x + w, dim=2) return idx_group, idx_offset
[docs] @torch.no_grad() def sample( self, idx_group: torch.Tensor, idx_offset: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Sample from a sum layer. :param idx_group: The group indices. :param idx_offset: The offset indices. :return: The group and offset indices. """ # Compute the indices by sampling from a categorical distribution that is parametrized by sum layer's weights w = torch.log_softmax(self.weight[idx_group, idx_offset], dim=2) idx_offset = distributions.Categorical(logits=w).sample() return idx_group, idx_offset
[docs]class RootLayer(nn.Module): def __init__( self, in_partitions: int, in_nodes: int, out_classes: int ): """ Initialize the root layer. :param in_partitions: The number of input partitions. :param in_nodes: The number of input nodes per partition. :param out_classes: The number of output nodes. """ super().__init__() self.in_partitions = in_partitions self.in_nodes = in_nodes self.out_classes = out_classes # Instantiate the weights self.weight = nn.Parameter( torch.empty(self.out_classes, self.in_partitions * self.in_nodes), requires_grad=True ) dirichlet_(self.weight, alpha=1.0)
[docs] def forward(self, x): """ Evaluate the layer given some inputs. :param x: The inputs. :return: The outputs. """ # Compute the log-likelihood using the "logsumexp" trick x = torch.flatten(x, start_dim=1) # (-1, in_partitions * in_nodes) w = torch.log_softmax(self.weight, dim=1) # (out_classes, in_partitions * in_nodes) x = torch.unsqueeze(x, dim=1) # (-1, 1, in_partitions * in_nodes) x = torch.logsumexp(x + w, dim=2) # (-1, out_classes) return x
[docs] @torch.no_grad() def mpe(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate the layer given some inputs for maximum at posteriori estimation. :param x: The inputs. :param y: The target classes. :return: The group and offset indices. """ # Compute the layer top-down and get the group and offset indices x = torch.flatten(x, start_dim=1) w = torch.log_softmax(self.weight, dim=1) idx = torch.argmax(x + w[y], dim=1, keepdim=True) idx_group = torch.div(idx, self.in_nodes, rounding_mode='floor') idx_offset = torch.remainder(idx, self.in_nodes) return idx_group, idx_offset
[docs] @torch.no_grad() def sample(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Sample from the root layer. :param y: The target classes. :return: The group and offset indices. """ # Compute the layer top-down and get the indices w = torch.log_softmax(self.weight, dim=1) idx = distributions.Categorical(logits=w[y]).sample().unsqueeze(dim=1) idx_group = torch.div(idx, self.in_nodes, rounding_mode='floor') idx_offset = torch.remainder(idx, self.in_nodes) return idx_group, idx_offset