Source code for deeprob.torch.constraints

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

import torch
from torch import nn


class ScaleClipper(nn.Module):
    def __init__(self, eps: float = 1e-5):
        """
        Constrains the scale to be positive.

        The clipper is applied to a module exposing a ``scale`` tensor attribute;
        calling it clamps that tensor in place so every entry is at least ``eps``.

        :param eps: The epsilon minimum value threshold.
        :raises ValueError: If the epsilon value is out of domain (non-positive or NaN).
        """
        # Written as ``not eps > 0.0`` instead of ``eps <= 0.0`` so that NaN,
        # for which every comparison is False, is rejected too.
        if not eps > 0.0:
            raise ValueError("The epsilon value must be positive")
        super().__init__()
        # Registered as a buffer so it moves with the module across devices
        # (.to/.cuda) and is saved in the state dict.
        self.register_buffer('eps', torch.tensor(eps))

    def forward(self, module: nn.Module) -> None:
        """
        Call the constraint.

        Clamps ``module.scale`` in place to a minimum of ``eps``.

        :param module: The module. It must have a ``scale`` tensor attribute.
        """
        # This is a projection step outside the computational graph; autograd
        # would otherwise refuse an in-place edit of a leaf that requires grad.
        with torch.no_grad():
            module.scale.clamp_(self.eps)