# Source code for deeprob.torch.initializers

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

import torch
from torch import distributions


def dirichlet_(tensor: torch.Tensor, alpha: float = 1.0, log_space: bool = True, dim: int = -1):
    """
    Initialize a tensor in-place using the symmetric Dirichlet distribution.

    One Dirichlet sample of size ``tensor.shape[dim]`` is drawn for every index
    combination of the remaining dimensions and written along ``dim``. The
    values along ``dim`` therefore sum to one (before the optional logarithm).

    :param tensor: The tensor to initialize. It is overwritten in-place.
    :param alpha: The concentration parameter, shared by all the components.
    :param log_space: Whether to initialize the tensor in the logarithmic space.
    :param dim: The dimension over which to sample. Negative values count from
                the last dimension.
    :raises ValueError: If the tensor has no dimensions.
    :raises IndexError: If ``dim`` is outside ``[-tensor.dim(), tensor.dim() - 1]``.
    """
    shape = tensor.shape
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("Singleton tensors are not valid")

    # Both bounds are inclusive: the last dimension (ndim - 1) is a valid choice.
    min_dim, max_dim = -ndim, ndim - 1
    if not min_dim <= dim <= max_dim:
        raise IndexError(
            "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(min_dim, max_dim, dim)
        )

    # Normalize a possibly negative dimension to its non-negative index.
    idx = dim % ndim

    with torch.no_grad():
        # Cast to float so that an integer alpha still yields a floating point concentration.
        concentration = torch.full([shape[idx]], float(alpha))
        dirichlet = distributions.Dirichlet(concentration)

        # The sampled tensor has the remaining dimensions first (in their
        # original order) and the Dirichlet dimension last.
        batch_shape = [d for i, d in enumerate(shape) if i != idx]
        samples = dirichlet.sample(batch_shape)
        if log_space:
            samples = torch.log(samples)

        # Move (not swap) the Dirichlet dimension back to its position, so
        # that the order of the remaining dimensions is preserved.
        tensor.copy_(torch.movedim(samples, -1, idx))