Source code for deeprob.spn.utils.filter

# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala

from typing import Union, List, Tuple, Type

from deeprob.spn.structure.leaf import Leaf
from deeprob.spn.structure.node import Node, Sum, Product, bfs


[docs]def collect_nodes(root: Node) -> List[Node]: """ Get all the nodes in a SPN. :param root: The root of the SPN. :return: A list of nodes. """ return filter_nodes_by_type(root)
[docs]def filter_nodes_by_type( root: Node, ntype: Union[Type[Node], Tuple[Type[Node], ...]] = Node ) -> List[Union[Node, Leaf, Sum, Product]]: """ Get the nodes of some specified types in a SPN. :param root: The root of the SPN. :param ntype: The node type. Multiple node types can be specified as a tuple. :return: A list of nodes of some specific types. """ return list(filter(lambda n: isinstance(n, ntype), bfs(root)))