# Source code for deeprob.spn.learning.splitting.cols

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala, Federico Luzzi

from typing import Union, Type, Tuple, List, Callable, Any

import numpy as np

from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.learning.splitting.gvs import gvs_cols, rgvs_cols, wrgvs_cols
from deeprob.spn.learning.splitting.entropy import entropy_cols, entropy_adaptive_cols
from deeprob.spn.learning.splitting.gini import gini_cols, gini_adaptive_cols
from deeprob.spn.learning.splitting.rdc import rdc_cols
from deeprob.spn.learning.splitting.random import random_cols

#: Type of a columns splitting function. It is called with the data matrix,
#: the leaf distribution classes, the feature domains, a random state and any
#: further method-specific argument, and it returns an array of column ids.
SplitColsFunc = Callable[
    [
        np.ndarray,                 # data
        List[Type[Leaf]],           # distributions
        List[Union[list, tuple]],   # domains
        np.random.RandomState,      # random state
        Any,                        # other arguments
    ],
    np.ndarray,                     # column ids
]


def split_cols_clusters(
    data: np.ndarray,
    clusters: np.ndarray,
    scope: List[int]
) -> Tuple[List[np.ndarray], List[List[int]]]:
    """
    Partition the columns of the data according to their cluster assignment.

    Each distinct value of ``clusters`` (in the sorted order produced by
    ``np.unique``) yields one group, made of the columns of ``data`` and the
    entries of ``scope`` whose position carries that value.

    :param data: The data, whose columns are grouped.
    :param clusters: The cluster assigned to each column, aligned with the columns of data.
    :param scope: The original scope, aligned with the columns of data.
    :return: (slices, scopes) where slices is a list of partial data and scopes is a list
             of partial scopes, one entry per distinct cluster.
    """
    scope_arr = np.asarray(scope)
    # One boolean column mask per distinct cluster value.
    masks = [clusters == label for label in np.unique(clusters)]
    slices = [data[:, mask] for mask in masks]
    scopes = [scope_arr[mask].tolist() for mask in masks]
    return slices, scopes
def get_split_cols_method(split_cols: str) -> SplitColsFunc:
    """
    Get the columns splitting method given a string.

    :param split_cols: The string of the method to get. One of 'gvs', 'rgvs', 'wrgvs',
                       'ebvs', 'ebvs_ae', 'gbvs', 'gbvs_ag', 'rdc' or 'random'.
    :return: The corresponding columns splitting function.
    :raises ValueError: If the columns splitting method is unknown.
    """
    methods = {
        'gvs': gvs_cols,
        'rgvs': rgvs_cols,
        'wrgvs': wrgvs_cols,
        'ebvs': entropy_cols,
        'ebvs_ae': entropy_adaptive_cols,
        'gbvs': gini_cols,
        'gbvs_ag': gini_adaptive_cols,
        'rdc': rdc_cols,
        'random': random_cols,
    }
    try:
        return methods[split_cols]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, so that any value that does not name a
        # known method is reported uniformly as a ValueError.
        raise ValueError("Unknown split columns method called {}".format(split_cols)) from None