# Source code for deeprob.spn.learning.em

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Optional

import numpy as np
from tqdm import tqdm

from deeprob.context import ContextState
from deeprob.utils.random import RandomState, check_random_state
from deeprob.spn.utils.filter import filter_nodes_by_type
from deeprob.spn.utils.validity import check_spn
from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.structure.node import Node, Sum
from deeprob.spn.algorithms.inference import log_likelihood
from deeprob.spn.algorithms.gradient import eval_backward


def expectation_maximization(
    root: Node,
    data: np.ndarray,
    num_iter: int = 100,
    batch_perc: float = 0.1,
    step_size: float = 0.5,
    random_init: bool = True,
    random_state: Optional[RandomState] = None,
    verbose: bool = True
) -> Node:
    """
    Learn the parameters of a SPN by batch Expectation-Maximization (EM).
    See https://arxiv.org/abs/1604.07243 and https://arxiv.org/abs/2004.06231 for details.

    At every iteration a batch is drawn uniformly without replacement from the data,
    a forward pass computes the log-likelihoods at each node, a backward pass computes
    the log-gradients, and the resulting statistics are passed to the ``em_step`` method
    of every sum node and every leaf node.

    :param root: The spn structure.
    :param data: The data to use to learn the parameters.
    :param num_iter: The number of iterations.
    :param batch_perc: The percentage of data to use for each step.
                       The resulting batch always contains at least one sample.
    :param step_size: The step size for batch EM.
    :param random_init: Whether to random initialize the weights of the SPN.
    :param random_state: The random state. It can be either None, a seed integer or a Numpy RandomState.
    :param verbose: Whether to enable verbose learning.
    :return: The spn with learned parameters (the same ``root`` object that was passed in).
    :raises ValueError: If a parameter is out of domain or the data is empty.
    """
    if num_iter <= 0:
        raise ValueError("The number of iterations must be positive")
    if batch_perc <= 0.0 or batch_perc >= 1.0:
        raise ValueError("The batch percentage must be in (0, 1)")
    if step_size <= 0.0 or step_size >= 1.0:
        raise ValueError("The step size must be in (0, 1)")

    # An empty data set would only produce empty batches (and NaN statistics), fail early instead
    n_samples = len(data)
    if n_samples == 0:
        raise ValueError("The data must contain at least one sample")

    # Check the SPN
    check_spn(root, labeled=True, smooth=True, decomposable=True)

    # Compute the batch size.
    # The integer truncation can yield zero on small data sets, hence use at least one sample
    batch_size = max(1, int(batch_perc * n_samples))

    # Compute a list-based cache for accessing nodes
    cached_nodes = {
        'sum': filter_nodes_by_type(root, Sum),
        'leaf': filter_nodes_by_type(root, Leaf)
    }

    # Check the random state
    random_state = check_random_state(random_state)

    # Random initialize the parameters of the SPN, if specified
    if random_init:
        # Initialize the sum parameters
        for node in cached_nodes['sum']:
            node.em_init(random_state)

        # Initialize the leaf parameters
        for node in cached_nodes['leaf']:
            node.em_init(random_state)

    # The children identifiers of each sum node do not change across iterations, compute them once
    sum_children_ids = [
        [child.id for child in node.children]
        for node in cached_nodes['sum']
    ]

    # Initialize the tqdm bar, if verbose is specified
    iterator = range(num_iter)
    if verbose:
        iterator = tqdm(
            iterator, leave=None, unit='batch',
            bar_format='{desc}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
        )

    for _ in iterator:
        # Sample a batch of data randomly with uniform distribution
        batch_indices = random_state.choice(n_samples, size=batch_size, replace=False)
        batch_data = data[batch_indices]

        # Prevent checking the SPN at every forward inference step, we already did that!
        with ContextState(check_spn=False):
            # Forward step, obtaining the LLs at each node
            root_ll, lls = log_likelihood(root, batch_data, return_results=True)
            mean_ll = np.mean(root_ll)

            # Backward step, compute the log-gradients required to compute the sufficient statistics
            grads = eval_backward(root, lls)

        # Update the weights of each sum node
        for node, children_ids in zip(cached_nodes['sum'], sum_children_ids):
            children_ll = lls[children_ids]
            stats = np.exp(children_ll - root_ll + grads[node.id])
            node.em_step(stats, step_size)

        # Update the parameters of each leaf node
        for node in cached_nodes['leaf']:
            stats = np.exp(lls[node.id] - root_ll + grads[node.id])
            node.em_step(stats, batch_data[:, node.scope], step_size)

        # Update the progress bar
        if verbose:
            iterator.set_description(f'Batch Avg. LL: {mean_ll:.4f}')

    return root