# Source code for deeprob.spn.learning.splitting.rows

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala, Federico Luzzi

from typing import Union, Type, Tuple, List, Callable, Any

import numpy as np

from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.learning.splitting.cluster import gmm, kmeans, kmeans_mb, dbscan, wald
from deeprob.spn.learning.splitting.rdc import rdc_rows
from deeprob.spn.learning.splitting.random import random_rows

#: A signature for a rows splitting function: a callable that receives the
#: data, the leaf distribution classes, the domains, a random state and any
#: other method-specific arguments, and returns the rows ids as an array.
SplitRowsFunc = Callable[
    [np.ndarray,                # The data
     List[Type[Leaf]],          # The distributions (leaf classes)
     List[Union[list, tuple]],  # The domains
     np.random.RandomState,     # The random state
     Any],                      # Other arguments
    np.ndarray                  # The rows ids
]


def split_rows_clusters(
    data: np.ndarray,
    clusters: np.ndarray
) -> Tuple[List[np.ndarray], List[float]]:
    """
    Partition the rows of the data according to their cluster labels.

    One slice is produced for every distinct value found in ``clusters``,
    in the (sorted) order returned by ``np.unique``.

    :param data: The data, indexed as a two-dimensional array (rows, columns).
    :param clusters: The cluster label of each row of the data.
    :return: (slices, weights) where slices is the list of row subsets of the
             data, one per cluster, and weights is the list of the fractions
             of rows of the original data that fall in each slice.
    """
    total_rows = len(data)

    # A boolean mask per label selects the rows that belong to that cluster.
    slices = [data[clusters == label, :] for label in np.unique(clusters)]
    weights = [len(part) / total_rows for part in slices]
    return slices, weights


[docs]def get_split_rows_method(split_rows: str) -> SplitRowsFunc: """ Get the rows splitting method given a string. :param split_rows: The string of the method do get. :return: The corresponding rows splitting function. :raises ValueError: If the rows splitting method is unknown. """ if split_rows == 'kmeans': return kmeans if split_rows == 'kmeans_mb': return kmeans_mb if split_rows == 'dbscan': return dbscan if split_rows == 'wald': return wald if split_rows == 'gmm': return gmm if split_rows == 'rdc': return rdc_rows if split_rows == 'random': return random_rows raise ValueError("Unknown split rows method called {}".format(split_rows))