Source code for deeprob.spn.structure.node

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from __future__ import annotations
import abc
from typing import Optional, Union, List, Iterator
from collections import deque, defaultdict

import numpy as np
from scipy.special import logsumexp


class Node(abc.ABC):
    """Abstract base class shared by every node of a SPN."""

    def __init__(self, scope: List[int], children: Optional[List[Node]] = None):
        """
        Build a SPN node from its scope and, optionally, its children.

        :param scope: The scope, i.e. the variable indices covered by the node.
        :param children: The children nodes. When None, a new empty list is used.
        :raises ValueError: If the scope is empty.
        :raises ValueError: If the scope contains duplicates.
        """
        # Validate the scope before touching any attribute
        if not scope:
            raise ValueError("The scope must not be empty")
        if len(set(scope)) != len(scope):
            raise ValueError("The scope must not contain duplicates")

        # The identifier defaults to zero until it is assigned explicitly
        self.id = 0
        self.scope = scope
        self.children = children if children is not None else []

    @abc.abstractmethod
    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the likelihood of the node on some inputs.

        :param x: The inputs.
        :return: The resulting likelihoods.
        """

    @abc.abstractmethod
    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the logarithmic likelihood of the node on some inputs.

        :param x: The inputs.
        :return: The resulting log-likelihoods.
        """
class Sum(Node):
    """Sum node of a SPN: a weighted mixture of children sharing the same scope."""

    def __init__(
        self,
        scope: Optional[List[int]] = None,
        children: Optional[List[Node]] = None,
        weights: Optional[Union[List[float], np.ndarray]] = None,
    ):
        """
        Initialize a SPN sum node given a list of children and their weights and a scope.

        :param scope: The scope. If None, the scope is initialized based on children scopes.
        :param children: A list of nodes. If None, children are initialized as an empty list.
        :param weights: The weights associated to each children node. It can be None.
        :raises ValueError: If the scope is None and children is None or empty.
        :raises ValueError: If children nodes have different scopes, or a scope
                            different from the one explicitly given.
        :raises ValueError: If the length of weights and children are different.
        :raises ValueError: If weights don't sum up to 1.
        """
        if children is None:
            if scope is None:
                raise ValueError("Cannot infer Sum node's scope without children")
        else:
            if scope is None:
                # An empty children list carries no scope to infer from
                if not children:
                    raise ValueError("Cannot infer Sum node's scope without children")
                scope = children[0].scope
            s_scope = set(scope)
            # Every child is checked, including the first one: when the scope is
            # given explicitly nothing else guarantees that the first child matches it
            if any(set(c.scope) != s_scope for c in children):
                raise ValueError("Children of Sum node have different scopes")
            if weights is not None and len(weights) != len(children):
                raise ValueError("Weights and children length mismatch")

        if weights is not None:
            # Plain lists are converted to float32 arrays; arrays are kept as given
            if isinstance(weights, list):
                weights = np.array(weights, dtype=np.float32)
            if not np.isclose(np.sum(weights), 1.0):
                raise ValueError("Weights don't sum up to 1")

        self.weights = weights
        super().__init__(scope, children)

    def em_init(self, random_state: np.random.RandomState):
        """
        Random initialize the node's parameters for Expectation-Maximization (EM).

        The weights are sampled from a symmetric Dirichlet distribution (all
        concentrations equal to one) having one component per child.

        :param random_state: The random state.
        """
        weights = random_state.dirichlet(np.ones(len(self.children)))
        self.weights = weights.astype(np.float32)

    def em_step(self, stats: np.ndarray, step_size: float):
        """
        Compute a batch Expectation-Maximization (EM) step.

        :param stats: The sufficient statistics of each sample. They are summed over
                      axis 1, so presumably the shape is (n_children, n_samples)
                      -- TODO confirm against the caller.
        :param step_size: The step size of update.
        """
        # The float32 machine epsilon keeps the normalization away from a division by zero
        unnorm_weights = self.weights * np.sum(stats, axis=1) + np.finfo(np.float32).eps
        weights = unnorm_weights / np.sum(unnorm_weights)

        # Update the parameters by interpolating between the old and the new weights
        self.weights = (1.0 - step_size) * self.weights + step_size * weights

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the likelihood of the node as the weighted sum of its inputs.

        :param x: The inputs, presumably the children likelihoods arranged as
                  (n_samples, n_children) -- TODO confirm against the caller.
        :return: The dot product of the inputs and the weights, with a new axis
                 of size one inserted at position 1.
        """
        return np.expand_dims(np.dot(x, self.weights), axis=1)

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the logarithmic likelihood of the node.

        :param x: The inputs, presumably the children log-likelihoods arranged as
                  (n_samples, n_children) -- TODO confirm against the caller.
        :return: The weighted log-sum-exp of the inputs along axis 1, keeping that axis.
        """
        return logsumexp(x, b=self.weights, axis=1, keepdims=True)
class Product(Node):
    """Product node of a SPN: a factorization over children having disjoint scopes."""

    def __init__(
        self,
        scope: Optional[List[int]] = None,
        children: Optional[List[Node]] = None
    ):
        """
        Initialize a product node given a list of children and its scope.

        :param scope: The scope. If None, the scope is initialized based on children scopes.
        :param children: A list of nodes. If None, children are initialized as an empty list.
        :raises ValueError: If both scope and children are None.
        :raises ValueError: If children nodes don't have disjointed scopes.
        :raises ValueError: If the given scope differs from the union of the children scopes.
        """
        if children is None:
            if scope is None:
                raise ValueError("Cannot infer Product node's scope without children")
        else:
            # Concatenate the children scopes (linear time, preserving the order)
            c_scope = [v for c in children for v in c.scope]
            s_scope = set(c_scope)
            # A repeated variable means that two children overlap. This is checked
            # regardless of whether the scope has been given explicitly or not
            if len(c_scope) != len(s_scope):
                raise ValueError("Children of Product node don't have disjointed scopes")
            if scope is None:
                scope = c_scope
            elif set(scope) != s_scope:
                raise ValueError(
                    "Scope of Product node doesn't match the union of its children scopes"
                )
        super().__init__(scope, children)

    def likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the likelihood of the node as the product of its inputs.

        :param x: The inputs, presumably the children likelihoods arranged as
                  (n_samples, n_children) -- TODO confirm against the caller.
        :return: The product of the inputs along axis 1, keeping that axis.
        """
        return np.prod(x, axis=1, keepdims=True)

    def log_likelihood(self, x: np.ndarray) -> np.ndarray:
        """
        Compute the logarithmic likelihood of the node as the sum of its inputs.

        :param x: The inputs, presumably the children log-likelihoods arranged as
                  (n_samples, n_children) -- TODO confirm against the caller.
        :return: The sum of the inputs along axis 1, keeping that axis.
        """
        return np.sum(x, axis=1, keepdims=True)
def assign_ids(root: Node) -> Node:
    """
    Number the nodes of a SPN following a topological ordering.

    The ids are consecutive integers starting from zero, assigned in the order
    given by the topological ordering computed from the root.

    :param root: The root of the SPN.
    :return: The same SPN with each node having modified ids.
    :raises ValueError: If the SPN structure is not a DAG.
    """
    ordering = topological_order(root)
    if ordering is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    # The position in the ordering is the identifier of the node
    for index, node in enumerate(ordering):
        node.id = index
    return root
def bfs(root: Node) -> Iterator[Node]:
    """
    Visit the nodes of a SPN following a Breadth First Search (BFS).

    Each node reachable from the root is yielded exactly once.

    :param root: The root of the SPN.
    :return: The BFS nodes iterator.
    """
    visited = {root}
    frontier = deque((root,))
    while frontier:
        current = frontier.popleft()
        yield current
        for child in current.children:
            # Nodes are marked when enqueued, hence they are never enqueued twice
            if child in visited:
                continue
            visited.add(child)
            frontier.append(child)
def dfs_post_order(root: Node) -> Iterator[Node]:
    """
    Compute Depth First Search (DFS) Post-Order ordering for a SPN.

    Each node reachable from the root is yielded exactly once and, as long as the
    structure is acyclic, only after all of its descendants have been yielded.
    This also holds when a node is shared among multiple parents.

    :param root: The root of the SPN.
    :return: The DFS Post-Order nodes iterator.
    """
    # Each stack entry pairs a node with an iterator over its not yet explored
    # children. Children are explored from the last to the first one.
    seen = {root}
    stack = [(root, reversed(root.children))]
    while stack:
        node, children = stack[-1]
        for c in children:
            if c not in seen:
                # Descend into the child; the parent stays on the stack and it is
                # resumed from the next child once the whole sub-graph is done
                seen.add(c)
                stack.append((c, reversed(c.children)))
                break
        else:
            # No unexplored children left: the whole sub-graph has been yielded
            stack.pop()
            yield node
def topological_order(root: Node) -> Optional[List[Node]]:
    """
    Compute a Topological Ordering for a SPN by means of the Kahn's Algorithm.

    :param root: The root of the SPN.
    :return: A list of nodes that form a topological ordering, starting from the root.
             If the SPN graph is not acyclic, it returns None.
    """
    # Count, for each node, the edges that point to it from the visited nodes
    pending_edges = defaultdict(int)
    pending_edges[root] = 0
    for parent in bfs(root):
        for child in parent.children:
            pending_edges[child] += 1

    # An edge pointing back to the root is already a (trivial) cycle
    if pending_edges[root] != 0:
        return None

    # Non-layered topological ordering: a node is released as soon as
    # every edge pointing to it has been consumed
    result = []
    ready = deque([root])
    while ready:
        current = ready.popleft()
        result.append(current)
        for child in current.children:
            pending_edges[child] -= 1
            if pending_edges[child] == 0:
                ready.append(child)

    # Edges that have not been consumed reveal the presence of a cycle
    if any(pending_edges.values()):
        return None
    return result
def topological_order_layered(root: Node) -> Optional[List[List[Node]]]:
    """
    Compute a Layered Topological Ordering for a SPN by means of the Kahn's Algorithm.

    :param root: The root of the SPN.
    :return: A list of layers that form a topological ordering, the first one
             containing only the root. If the SPN graph is not acyclic, it returns None.
    """
    # Count, for each node, the edges that point to it from the visited nodes
    pending_edges = defaultdict(int)
    pending_edges[root] = 0
    for parent in bfs(root):
        for child in parent.children:
            pending_edges[child] += 1

    # An edge pointing back to the root is already a (trivial) cycle
    if pending_edges[root] != 0:
        return None

    # Layered topological ordering: each layer collects the nodes that are
    # released while consuming the edges leaving the previous layer
    layers = [[root]]
    frontier = layers[0]
    while frontier:
        released = []
        for node in frontier:
            for child in node.children:
                pending_edges[child] -= 1
                if pending_edges[child] == 0:
                    released.append(child)
        if released:
            layers.append(released)
        frontier = released

    # Edges that have not been consumed reveal the presence of a cycle
    if any(pending_edges.values()):
        return None
    return layers