# Source code for deeprob.spn.learning.leaf

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Optional, Union, Type, List, Callable, Any

import numpy as np

from deeprob.utils.random import RandomState
from deeprob.spn.structure.node import Node, Product
from deeprob.spn.structure.leaf import LeafType, Leaf, Bernoulli, Isotonic
from deeprob.spn.structure.cltree import BinaryCLT

#: A signature for a learn SPN leaf function.
#: Such a function receives the data (one column per random variable), the distribution
#: classes, the domains and the scope of the random variables, followed by further
#: arguments, and returns the learned SPN node.
#: NOTE: the trailing ``Any`` only approximates the optional keyword arguments
#: (e.g. ``alpha``, ``random_state``) accepted by the learn functions of this module.
LearnLeafFunc = Callable[
    [
        np.ndarray,                # The data
        List[Type[Leaf]],          # The distributions
        List[Union[list, tuple]],  # The domains
        List[int],                 # The scope
        Any,                       # Other arguments
    ],
    Node,                          # A SPN node
]


def get_learn_leaf_method(learn_leaf: str) -> LearnLeafFunc:
    """
    Resolve a learn leaf method from its name.

    :param learn_leaf: The learn leaf method name, one of 'mle', 'isotonic' or 'binary-clt'.
    :return: The learn leaf function registered under that name.
    :raises ValueError: If the leaf learning method is unknown.
    """
    # Name -> function pairs, compared with == so any input type behaves as with an if-chain
    registry = (
        ('mle', learn_mle),
        ('isotonic', learn_isotonic),
        ('binary-clt', learn_binary_clt),
    )
    for name, method in registry:
        if learn_leaf == name:
            return method
    raise ValueError("Unknown learn leaf method called {}".format(learn_leaf))
def learn_mle(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using Maximum Likelihood Estimate (MLE).
    If the data is multivariate, a naive factorized model is learned, where every factor
    is itself learned by MLE.

    :param data: The data, where each column corresponds to a random variable.
    :param distributions: The distributions of the random variables.
    :param domains: The domains of the random variables.
    :param scope: The scope of the leaf.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None. It is only forwarded to the
                         naive factorization, i.e. it is unused for univariate data.
    :return: A leaf distribution (or a naive factorized model if the data is multivariate).
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    """
    n_variables = len(distributions)
    if not (len(scope) == n_variables and len(domains) == n_variables):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    # Multivariate data: factorize the variables, learning each factor with this same function
    if len(scope) != 1:
        return learn_naive_factorization(
            data, distributions, domains, scope, learn_mle,
            alpha=alpha, random_state=random_state
        )

    # Univariate data: instantiate the distribution on the single variable and fit it
    variable, leaf_cls, domain = scope[0], distributions[0], domains[0]
    fitted = leaf_cls(variable)
    fitted.fit(data, domain, alpha=alpha)
    return fitted
def learn_isotonic(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using the Isotonic method.
    Continuous random variables are modelled by an Isotonic leaf, while any other variable
    uses its own distribution. If the data is multivariate, a naive factorized model is learned.

    :param data: The data.
    :param distributions: The distribution of the random variables.
    :param domains: The domain of the random variables.
    :param scope: The scope of the leaf.
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None. It is only forwarded to the
                         naive factorization, i.e. it is unused for univariate data.
    :return: A leaf distribution (or a naive factorized model if the data is multivariate).
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    """
    n_variables = len(distributions)
    if not (len(scope) == n_variables and len(domains) == n_variables):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")

    # Multivariate data: factorize the variables, learning each factor with this same function
    if len(scope) != 1:
        return learn_naive_factorization(
            data, distributions, domains, scope, learn_isotonic,
            alpha=alpha, random_state=random_state
        )

    # Univariate data: continuous variables get an Isotonic leaf, the others keep their distribution
    variable, leaf_cls, domain = scope[0], distributions[0], domains[0]
    if leaf_cls.LEAF_TYPE == LeafType.CONTINUOUS:
        fitted = Isotonic(variable)
    else:
        fitted = leaf_cls(variable)
    fitted.fit(data, domain, alpha=alpha)
    return fitted
def learn_binary_clt(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    to_pc: bool = False,
    alpha: float = 0.1,
    random_state: Optional[RandomState] = None
) -> Node:
    """
    Learn a leaf using a Binary Chow-Liu Tree (CLT).
    For univariate data no tree is needed, so a Maximum Likelihood Estimate (MLE) leaf is returned instead.

    :param data: The data.
    :param distributions: The distributions of the random variables.
    :param domains: The domains of the random variables.
    :param scope: The scope of the leaf.
    :param to_pc: Whether to convert the CLT into an equivalent PC (ignored for univariate data).
    :param alpha: Laplace smoothing factor.
    :param random_state: The random state. It can be None.
    :return: A leaf distribution.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains.
    :raises ValueError: If the data doesn't follow a Bernoulli distribution.
    """
    n_variables = len(distributions)
    if not (len(scope) == n_variables and len(domains) == n_variables):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")
    if not all(dist == Bernoulli for dist in distributions):
        raise ValueError("Binary Chow-Liu trees are only available for Bernoulli data")

    # A single variable: fall back to a plain MLE leaf
    if len(scope) == 1:
        return learn_mle(data, distributions, domains, scope, alpha=alpha, random_state=random_state)

    # Several variables: fit a binary Chow-Liu tree over the whole scope
    tree = BinaryCLT(scope)
    tree.fit(data, domains, alpha=alpha, random_state=random_state)
    # Optionally hand back the equivalent probabilistic circuit instead of the tree itself
    return tree.to_pc() if to_pc else tree
def learn_naive_factorization(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: List[Union[list, tuple]],
    scope: List[int],
    learn_leaf_func: LearnLeafFunc,
    **learn_leaf_kwargs
) -> Node:
    """
    Learn a leaf as a naive factorized model, i.e. a product node whose children are
    learned independently, one per random variable.

    :param data: The data, a 2-D array whose i-th column corresponds to the random variable scope[i].
    :param distributions: The distribution of the random variables.
    :param domains: The domain of the random variables.
    :param scope: The scope of the leaf.
    :param learn_leaf_func: The function to use to learn the sub-distributions parameters.
    :param learn_leaf_kwargs: Additional parameters for learn_leaf_func.
    :return: A naive factorized model.
    :raises ValueError: If there are inconsistencies between the data, distributions and domains,
                        including data that is not 2-D or whose number of columns differs from
                        the size of the scope.
    """
    if len(scope) != len(distributions) or len(domains) != len(distributions):
        raise ValueError("Each data column should correspond to a random variable having a distribution and a domain")
    # Column i is paired with scope[i] below: without this check, too few columns would surface
    # as an IndexError and extra columns would be silently ignored
    if np.ndim(data) != 2 or np.shape(data)[1] != len(scope):
        raise ValueError(
            "The data should be a 2-D array having one column per random variable in the scope, "
            "got shape {} for a scope of size {}".format(np.shape(data), len(scope))
        )

    node = Product(scope)
    for i, (var, dist, dom) in enumerate(zip(scope, distributions, domains)):
        # Learn each factor on its own column, kept 2-D by indexing with a list
        leaf = learn_leaf_func(data[:, [i]], [dist], [dom], [var], **learn_leaf_kwargs)
        leaf.id = i + 1  # Set the leaves ids sequentially, starting from 1
        node.children.append(leaf)
    return node