Source code for deeprob.spn.utils.validity

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Optional, List, Tuple

from deeprob.context import is_check_spn_enabled
from deeprob.spn.structure.node import Node, Sum, Product
from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.structure.cltree import BinaryCLT
from deeprob.spn.utils.filter import collect_nodes


def check_spn(
    root: Node,
    labeled: bool = True,
    smooth: bool = False,
    decomposable: bool = False,
    structured_decomposable: bool = False
):
    """
    Check that an SPN satisfies the requested structural properties.
    By default, only the node labeling is checked.

    The nodes of the SPN are collected only once and shared among all the requested checks,
    which avoids traversing the SPN multiple times. The checks are run in the order
    labeled, smooth, decomposable, structured decomposable, and the first failing one raises.
    If SPN checks are disabled (see is_check_spn_enabled), nothing is checked.

    :param root: The root node of the SPN.
    :param labeled: Whether to check if the SPN is correctly labeled.
    :param smooth: Whether to check if the SPN is smooth.
    :param decomposable: Whether to check if the SPN is decomposable.
    :param structured_decomposable: Whether to check if the SPN is structured decomposable.
    :raises ValueError: If the SPN lacks one of the requested properties.
    """
    if not is_check_spn_enabled():
        return

    # Traverse the SPN once and reuse the node list for every check
    nodes = collect_nodes(root)

    # (enabled flag, checker, error message prefix), in the order the checks are performed
    requested_checks = (
        (labeled, is_labeled, "SPN is not correctly labeled"),
        (smooth, is_smooth, "SPN is not smooth"),
        (decomposable, is_decomposable, "SPN is not decomposable"),
        (structured_decomposable, is_structured_decomposable, "SPN is not structured decomposable"),
    )
    for enabled, checker, error_prefix in requested_checks:
        if not enabled:
            continue
        reason = checker(root, nodes=nodes)
        if reason is not None:
            raise ValueError(f"{error_prefix}: {reason}")
def is_labeled(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check whether the nodes of the SPN are labeled with the ids 0, 1, ..., n - 1.

    The checks are performed in this order: no id is None, ids are unique,
    the smallest id is 0, and the largest id is the number of nodes minus one.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is labeled correctly, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)

    unique_ids = {node.id for node in nodes}
    if None in unique_ids:
        return "Some nodes have missing ids"
    if len(unique_ids) != len(nodes):
        return "Some nodes have repeated ids"

    # With unique (presumably integer) ids, min == 0 and max == n - 1 means exactly 0..n-1
    lowest, highest = min(unique_ids), max(unique_ids)
    if lowest != 0:
        return "Node ids are not starting at 0"
    if highest != len(unique_ids) - 1:
        return "Node ids are not consecutive"
    return None
def is_smooth(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the SPN is smooth (or complete).

    For every sum node it checks that the node has at least one child, that it has as many
    weights as children, and that every child has the same scope (as a set) as the sum node.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is smooth, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)

    for sum_node in (n for n in nodes if isinstance(n, Sum)):
        children = sum_node.children
        if len(children) == 0:
            return f"Sum node #{sum_node.id} has no children"
        if len(children) != len(sum_node.weights):
            return f"Weights and children length mismatch in node #{sum_node.id}"
        expected_scope = set(sum_node.scope)
        if any(set(child.scope) != expected_scope for child in children):
            return f"Children of Sum node #{sum_node.id} have different scopes"
    return None
def is_decomposable(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the SPN is decomposable (or consistent).

    For every product node it checks that the node has at least one child, that the scopes
    of its children are pairwise disjoint, and that their union is the scope of the product node.

    :param root: The root of the SPN.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the SPN is decomposable, a reason otherwise.
    """
    if nodes is None:
        nodes = collect_nodes(root)

    product_nodes: List[Product] = [n for n in nodes if isinstance(n, Product)]
    for node in product_nodes:
        if len(node.children) == 0:
            return f"Product node #{node.id} has no children"
        child_scopes = [set(c.scope) for c in node.children]
        union_scope = set().union(*child_scopes)
        # The children scopes are pairwise disjoint iff their sizes add up to the size of
        # their union; comparing only the union with the node scope would miss overlaps
        # (e.g. children scopes {0, 1} and {1, 2} under a product with scope {0, 1, 2})
        overlapping = sum(map(len, child_scopes)) != len(union_scope)
        if overlapping or set(node.scope) != union_scope:
            return f"Children of Product node #{node.id} don't have disjointed scopes"
    return None
def is_structured_decomposable(root: Node, nodes: Optional[List[Node]] = None) -> Optional[str]:
    """
    Check if the PC is structured decomposable, i.e. if its product nodes follow a vtree.
    Note that a structured decomposable PC is also decomposable.

    The check is performed by testing the PC for compatibility with itself (are_compatible).
    That test also validates smoothness and decomposability, so a reason is returned
    when the PC is not smooth, too.

    :param root: The root of the PC.
    :param nodes: The list of nodes. If None, it will be retrieved starting from the root node.
    :return: None if the PC is structured decomposable, a reason otherwise.
    """
    all_nodes = collect_nodes(root) if nodes is None else nodes
    # Shortcut: a PC is structured decomposable if it is compatible with itself
    return are_compatible(root, root, nodes_a=all_nodes, nodes_b=all_nodes)
def are_compatible(
    root_a: Node,
    root_b: Node,
    nodes_a: Optional[List[Node]] = None,
    nodes_b: Optional[List[Node]] = None
) -> Optional[str]:
    """
    Check if two PCs are compatible.

    Both PCs must first be smooth and decomposable. Then the scopes collected by collect_scopes
    from each PC are compared pairwise across the two PCs. Two scopes are incompatible if they
    overlap but neither is contained in the other.

    :param root_a: The root of the first PC.
    :param root_b: The root of the second PC.
    :param nodes_a: The list of nodes of the first PC. If None, it will be retrieved starting from the root node.
    :param nodes_b: The list of nodes of the second PC. If None, it will be retrieved starting from the root node.
    :return: None if the two PCs are compatible, a reason otherwise.
    """
    if nodes_a is None:
        nodes_a = collect_nodes(root_a)
    if nodes_b is None:
        nodes_b = collect_nodes(root_b)
    # The same node list object is passed e.g. when checking a PC against itself
    same_nodes = nodes_b is nodes_a

    # Check smoothness and decomposability first
    res = is_smooth(root_a, nodes_a)
    if res is not None:
        return f'First PC: {res}'
    res = is_decomposable(root_a, nodes_a)
    if res is not None:
        return f'First PC: {res}'
    # The node lists are the only input to these checks here, so an identical list has
    # already been validated above
    if not same_nodes:
        res = is_smooth(root_b, nodes_b)
        if res is not None:
            return f'Second PC: {res}'
        res = is_decomposable(root_b, nodes_b)
        if res is not None:
            return f'Second PC: {res}'

    # Get the scopes as sets. Duplicated scopes (many product nodes usually share a scope) are
    # dropped, keeping first-occurrence order. The first incompatible pair found below is then
    # the same one the non-deduplicated scan would report.
    scopes_a = [set(s) for s in dict.fromkeys(collect_scopes(nodes_a))]
    if same_nodes:
        scopes_b = scopes_a
    else:
        scopes_b = [set(s) for s in dict.fromkeys(collect_scopes(nodes_b))]

    # Quadratic in the number of distinct scopes
    for s1 in scopes_a:
        for s2 in scopes_b:
            int_len = len(s1.intersection(s2))
            # Compatible scopes are either disjoint or nested
            if int_len != 0 and int_len != min(len(s1), len(s2)):
                return f"Incompatibility found between scope {s1} and scope {s2}"
    return None
def collect_scopes(nodes: List[Node]) -> List[Tuple[int, ...]]:
    """
    Collect the scopes of the product nodes and of the binary CLT leaves.

    Each scope is returned as a sorted tuple of variable indices. A product node contributes its
    own scope, and a BinaryCLT contributes every scope returned by its get_scopes() method.
    Sum nodes and other leaves contribute nothing.

    :param nodes: The list of nodes.
    :return: A list of scopes, each a sorted tuple of arbitrary length.
    :raises NotImplementedError: If a node is neither a Product, a BinaryCLT, a Sum nor a Leaf.
    """
    scopes: List[Tuple[int, ...]] = []
    for n in nodes:
        if isinstance(n, Product):
            scopes.append(tuple(sorted(n.scope)))
        # NOTE: BinaryCLT is tested before the generic Leaf exclusion below; presumably it is a
        # Leaf subclass, so this order must be kept for its scopes to be collected
        elif isinstance(n, BinaryCLT):
            scopes.extend(tuple(sorted(scope)) for scope in n.get_scopes())
        elif not isinstance(n, (Sum, Leaf)):
            raise NotImplementedError(f"Case not considered for {type(n)} nodes")
    return scopes