Source code for deeprob.spn.learning.wrappers

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Optional, Union, Type, List

import numpy as np
from tqdm import tqdm

from deeprob.spn.structure.leaf import LeafType, Leaf
from deeprob.spn.structure.node import Node, Sum, assign_ids
from deeprob.spn.learning.learnspn import learn_spn
from deeprob.spn.learning.xpc import learn_xpc, learn_expc
from deeprob.spn.algorithms.structure import prune


def learn_estimator(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: Optional[List[Union[list, tuple]]] = None,
    method: str = 'learnspn',
    **kwargs
) -> Node:
    """
    Learn a SPN density estimator from training data, given the distribution class and domain of each feature.

    :param data: The training data.
    :param distributions: A list of distribution classes (one for each feature).
    :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete
                    distributions, or a tuple (consisting of min value and max value), for continuous distributions.
                    If None, the domains are inferred from the data via compute_data_domains.
    :param method: The structure learning algorithm: 'learnspn', 'xpc' or 'ensemble-xpc'.
    :param kwargs: Extra keyword arguments forwarded to the chosen structure learner.
    :return: The learned SPN. With 'learnspn' the learned structure is pruned in place before being returned.
    :raises ValueError: If the structure learning method is not known.
    :raises ValueError: If the method is 'xpc' or 'ensemble-xpc' and some domain is not exactly [0, 1].
    """
    # Domains are resolved before the method is inspected, so data-related errors surface first.
    if domains is None:
        domains = compute_data_domains(data, distributions)

    if method == 'learnspn':
        spn = learn_spn(data, distributions, domains, **kwargs)
        return prune(spn, copy=False)

    if method not in ('xpc', 'ensemble-xpc'):
        raise ValueError("Unknown SPN learning method called {}".format(method))

    # Both XPC learners only accept binary variables; they receive the raw data but not the domains.
    if method == 'xpc':
        learner = learn_xpc
        binary_error = "The domains must be binary for learning a XPC"
    else:
        learner = learn_expc
        binary_error = "The domains must be binary for learning an Ensemble-XPC"

    if not all(domain == [0, 1] for domain in domains):
        raise ValueError(binary_error)

    # The learners return a pair; only the first element (the SPN root) is exposed to callers.
    spn, _ = learner(data, **kwargs)
    return spn
def learn_classifier(
    data: np.ndarray,
    distributions: List[Type[Leaf]],
    domains: Optional[List[Union[list, tuple]]] = None,
    class_idx: int = -1,
    verbose: bool = True,
    **kwargs
) -> Node:
    """
    Learn a SPN classifier from training data, given the features distributions and domains and the class column.

    The resulting SPN has a sum node as root with one child per distinct class value. Each child is a sub-SPN
    learned with LearnSPN on the samples having that class value, then pruned. Its weight is the fraction of
    training samples belonging to that class.

    :param data: The training data.
    :param distributions: A list of distribution classes (one for each feature, including the class).
    :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete
                    distributions, or a tuple (consisting of min value and max value), for continuous distributions.
                    If None, the domains are inferred from the data via compute_data_domains.
    :param class_idx: The index of the class column in the training data.
    :param verbose: Whether to show progress bars (over classes here, and inside each LearnSPN run).
    :param kwargs: Extra keyword arguments forwarded to LearnSPN.
    :return: The learned SPN, with node ids assigned.
    """
    if domains is None:
        domains = compute_data_domains(data, distributions)

    # Unpacking (rather than indexing) also rejects data that is not two-dimensional.
    n_samples, _n_features = data.shape
    labels = data[:, class_idx]

    label_values = np.unique(labels)
    label_iter = (
        tqdm(label_values, bar_format='{l_bar}{bar:24}{r_bar}', unit='class') if verbose else label_values
    )

    # One sub-SPN per class, learned independently on that class's samples.
    children, weights = [], []
    for label in label_iter:
        subset = data[labels == label]
        sub_spn = learn_spn(subset, distributions, domains, verbose=verbose, **kwargs)
        children.append(prune(sub_spn, copy=False))
        weights.append(len(subset) / n_samples)

    return assign_ids(Sum(children=children, weights=weights))
def compute_data_domains(data: np.ndarray, distributions: List[Type[Leaf]]) -> List[Union[list, tuple]]:
    """
    Compute the domains based on the training data and the features distributions.

    :param data: The training data, as a two-dimensional array (one column for each feature).
    :param distributions: A list of distribution classes (one for each feature, i.e. one for each data column).
    :return: A list of domains, one for each feature. Each domain is either a list of values (the sorted distinct
             values found in the column), for discrete distributions, or a tuple (consisting of min value and
             max value, as Python scalars), for continuous distributions.
    :raises ValueError: If the data is not a two-dimensional array.
    :raises ValueError: If the number of distributions differs from the number of features (columns) in the data.
    :raises ValueError: If an unknown distribution type is found.
    """
    if data.ndim != 2:
        raise ValueError(
            "The training data must be a two-dimensional array, got {} dimension(s)".format(data.ndim)
        )
    n_features = data.shape[1]
    if len(distributions) != n_features:
        # Without this check a shorter list would silently yield fewer domains than features, and a
        # longer one would fail with an opaque IndexError on the column access below.
        raise ValueError(
            "Expected one distribution for each of the {} features, got {}".format(n_features, len(distributions))
        )

    domains = []
    for i, d in enumerate(distributions):
        col = data[:, i]
        if d.LEAF_TYPE == LeafType.DISCRETE:
            # np.unique returns the sorted distinct values; tolist() converts them to Python scalars.
            domains.append(np.unique(col).tolist())
        elif d.LEAF_TYPE == LeafType.CONTINUOUS:
            # .item() converts the numpy scalars to plain Python numbers.
            domains.append((np.min(col).item(), np.max(col).item()))
        else:
            raise ValueError("Unknown distribution type {}".format(d.LEAF_TYPE))
    return domains