Source code for deeprob.spn.learning.learnspn

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from enum import Enum
from collections import deque
from typing import Optional, Union, Type, List, NamedTuple

import numpy as np
from tqdm import tqdm

from deeprob.utils.random import RandomState, check_random_state
from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.structure.node import Node, Sum, Product, assign_ids
from deeprob.spn.learning.leaf import LearnLeafFunc, get_learn_leaf_method, learn_naive_factorization
from deeprob.spn.learning.splitting.rows import SplitRowsFunc, get_split_rows_method, split_rows_clusters
from deeprob.spn.learning.splitting.cols import SplitColsFunc, get_split_cols_method, split_cols_clusters


[docs]class OperationKind(Enum): """ Operation kind used by LearnSPN algorithm. """ REM_FEATURES = 1 CREATE_LEAF = 2 SPLIT_NAIVE = 3 SPLIT_ROWS = 4 SPLIT_COLS = 5
[docs]class Task(NamedTuple): """ Recursive task information used by LearnSPN algorithm. """ parent: Node data: np.ndarray scope: List[int] no_cols_split: bool = False no_rows_split: bool = False is_first: bool = False
[docs]def learn_spn( data: np.ndarray, distributions: List[Type[Leaf]], domains: List[Union[list, tuple]], learn_leaf: Union[str, LearnLeafFunc] = 'mle', split_rows: Union[str, SplitRowsFunc] = 'kmeans', split_cols: Union[str, SplitColsFunc] = 'rdc', learn_leaf_kwargs: dict = None, split_rows_kwargs: dict = None, split_cols_kwargs: dict = None, min_rows_slice: int = 256, min_cols_slice: int = 2, random_state: Optional[RandomState] = None, verbose: bool = True ) -> Node: """ Learn the structure and parameters of a SPN given some training data and several hyperparameters. :param data: The training data. :param distributions: A list of distribution classes (one for each feature). :param domains: A list of domains (one for each feature). Each domain is either a list of values, for discrete distributions, or a tuple (consisting of min value and max value), for continuous distributions. :param learn_leaf: The method to use to learn a distribution leaf node, It can be either 'mle', 'isotonic', 'binary-clt' or a custom LearnLeafFunc. :param split_rows: The rows splitting method. It can be either 'kmeans', 'gmm', 'rdc', 'random' or a custom SplitRowsFunc function. :param split_cols: The columns splitting method. It can be either 'gvs', 'rgvs', 'wrgvs', 'ebvs', 'ebvs_ae', 'gbvs', 'gbvs_ag', 'rdc', 'random' or a custom SplitColsFunc function. :param learn_leaf_kwargs: The parameters of the learn leaf method. :param split_rows_kwargs: The parameters of the rows splitting method. :param split_cols_kwargs: The parameters of the cols splitting method. :param min_rows_slice: The minimum number of samples required to split horizontally. :param min_cols_slice: The minimum number of features required to split vertically. :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState. :param verbose: Whether to enable verbose mode. :return: A learned valid SPN. :raises ValueError: If a parameter is out of scope. """ if len(distributions) == 0: raise ValueError("The list of distribution classes must be non-empty") if len(domains) == 0: raise ValueError("The list of domains must be non-empty") if min_rows_slice <= 0: raise ValueError("The minimum number of samples required to split horizontally must be positive") if min_cols_slice <= 0: raise ValueError("The minimum number of samples required to split vertically must be positive") n_samples, n_features = data.shape if len(distributions) != n_features or len(domains) != n_features: raise ValueError("Each data column should correspond to a random variable having a distribution and a domain") # Setup the learn leaf, split rows and split cols functions learn_leaf_func = get_learn_leaf_method(learn_leaf) if isinstance(learn_leaf, str) else learn_leaf split_rows_func = get_split_rows_method(split_rows) if isinstance(split_rows, str) else split_rows split_cols_func = get_split_cols_method(split_cols) if isinstance(split_cols, str) else split_cols if learn_leaf_kwargs is None: learn_leaf_kwargs = dict() if split_rows_kwargs is None: split_rows_kwargs = dict() if split_cols_kwargs is None: split_cols_kwargs = dict() # Setup the initial scope as [0, # of features - 1] initial_scope = list(range(n_features)) # Check the random state random_state = check_random_state(random_state) # Add the random state to learning leaf parameters learn_leaf_kwargs['random_state'] = random_state # Initialize the progress bar (with unspecified total), if verbose is enabled if verbose: tk = tqdm( total=np.inf, leave=None, unit='node', bar_format='{n_fmt}/{total_fmt} [{elapsed}, {rate_fmt}]' ) tasks = deque() tmp_node = Product(initial_scope) tasks.append(Task(tmp_node, data, initial_scope, is_first=True)) while tasks: # Get the next task task = tasks.popleft() # Select the operation to apply n_samples, n_features = task.data.shape # Get the indices of uninformative features zero_var_idx = np.isclose(np.var(task.data, axis=0), 0.0) # If all the features are uninformative, then split using Naive Bayes model if np.all(zero_var_idx): op = OperationKind.SPLIT_NAIVE # If only some of the features are uninformative, then remove them elif np.any(zero_var_idx): op = OperationKind.REM_FEATURES # Create a leaf node if the data split dimension is small or last rows splitting failed elif task.no_rows_split or n_features < min_cols_slice or n_samples < min_rows_slice: op = OperationKind.CREATE_LEAF # Use rows splitting if previous columns splitting failed or it is the first task elif task.no_cols_split or task.is_first: op = OperationKind.SPLIT_ROWS # Defaults to columns splitting else: op = OperationKind.SPLIT_COLS if op == OperationKind.REM_FEATURES: node = Product(task.scope) # Model the removed features using Naive Bayes rem_scope = [task.scope[i] for i, in np.argwhere(zero_var_idx)] dists, doms = [distributions[s] for s in rem_scope], [domains[s] for s in rem_scope] naive = learn_naive_factorization( task.data[:, zero_var_idx], dists, doms, rem_scope, learn_leaf_func=learn_leaf_func, **learn_leaf_kwargs ) node.children.append(naive) # Add the tasks regarding non-removed features is_first = task.is_first and len(tasks) == 0 oth_scope = [task.scope[i] for i, in np.argwhere(~zero_var_idx)] tasks.append(Task(node, task.data[:, ~zero_var_idx], oth_scope, is_first=is_first)) task.parent.children.append(node) elif op == OperationKind.CREATE_LEAF: # Create a leaf node dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope] leaf = learn_leaf_func(task.data, dists, doms, task.scope, **learn_leaf_kwargs) task.parent.children.append(leaf) elif op == OperationKind.SPLIT_NAIVE: # Split the data using a naive factorization dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope] node = learn_naive_factorization( task.data, dists, doms, task.scope, learn_leaf_func=learn_leaf_func, **learn_leaf_kwargs ) task.parent.children.append(node) elif op == OperationKind.SPLIT_ROWS: # Split the data by rows (sum node) dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope] clusters = split_rows_func(task.data, dists, doms, random_state, **split_rows_kwargs) slices, weights = split_rows_clusters(task.data, clusters) # Check whether only one partitioning is returned if len(slices) == 1: tasks.append(Task(task.parent, task.data, task.scope, no_cols_split=False, no_rows_split=True)) continue # Add sub-tasks and append Sum node node = Sum(task.scope, weights=weights) for local_data in slices: tasks.append(Task(node, local_data, task.scope)) task.parent.children.append(node) elif op == OperationKind.SPLIT_COLS: # Split the data by columns (product node) dists, doms = [distributions[s] for s in task.scope], [domains[s] for s in task.scope] clusters = split_cols_func(task.data, dists, doms, random_state, **split_cols_kwargs) slices, scopes = split_cols_clusters(task.data, clusters, task.scope) # Check whether only one partitioning is returned if len(slices) == 1: tasks.append(Task(task.parent, task.data, task.scope, no_cols_split=True, no_rows_split=False)) continue # Add sub-tasks and append Product node node = Product(task.scope) for i, local_data in enumerate(slices): tasks.append(Task(node, local_data, scopes[i])) task.parent.children.append(node) else: raise NotImplementedError("Operation of kind {} not implemented".format(op)) if verbose: tk.update() tk.refresh() if verbose: tk.close() root = tmp_node.children[0] return assign_ids(root)