# MIT License: Copyright (c) 2021 Lorenzo Loconte, Gennaro Gala
import os
import json
from typing import Optional, Union, Type, List, Dict, IO
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from networkx.drawing import nx_pydot
from networkx.drawing.layout import rescale_layout_dict
from networkx.algorithms.tree import is_arborescence
from networkx.algorithms.dag import is_directed_acyclic_graph
from networkx.algorithms.traversal import bfs_predecessors
from networkx.readwrite.json_graph import node_link_data, node_link_graph
from deeprob.spn.structure.node import Node, Sum, Product, topological_order
from deeprob.spn.structure.leaf import Leaf, Bernoulli, Categorical, Isotonic, Uniform, Gaussian
from deeprob.spn.structure.cltree import BinaryCLT
def save_digraph_json(graph: nx.DiGraph, f: Union[IO, os.PathLike, str]):
    """
    Serialize a NetworkX directed graph to JSON, using the node-link representation.

    :param graph: The NetworkX directed graph.
    :param f: A writable text file-like object, or the filepath of the output JSON file.
              Filepaths are opened in write mode with UTF-8 encoding.
    """
    # Build the whole JSON document up-front, so nothing is written if serialization fails
    payload = json.dumps(node_link_data(graph))

    # An already opened stream is written to directly and left open for the caller
    if not isinstance(f, (str, os.PathLike)):
        f.write(payload)
        return

    # Otherwise treat the argument as a path
    with open(f, 'w', encoding='utf-8') as stream:
        stream.write(payload)
def load_digraph_json(f: Union[IO, os.PathLike, str]) -> nx.DiGraph:
    """
    Deserialize a NetworkX directed graph stored as node-link JSON.

    :param f: A readable file-like object, or the filepath of the input JSON file.
              Filepaths are opened in read mode with UTF-8 encoding.
    :return: The NetworkX directed graph (built as a directed, non-multi graph).
    """
    # Parse the JSON document, either from a path or from an already opened stream
    if isinstance(f, (str, os.PathLike)):
        with open(f, 'r', encoding='utf-8') as stream:
            data = json.load(stream)
    else:
        data = json.load(f)

    # Rebuild the graph from its node-link description
    return node_link_graph(data, directed=True, multigraph=False)
def save_spn_json(root: Node, f: Union[IO, os.PathLike, str]):
    """
    Write a SPN to a JSON file.

    The SPN is first converted to a NetworkX directed graph, which is then serialized.

    :param root: The root node of the SPN.
    :param f: A file-like object or a filepath of the output JSON file.
    """
    save_digraph_json(spn_to_digraph(root), f)
def load_spn_json(f: Union[IO, os.PathLike, str], leaves: Optional[List[Type[Leaf]]] = None) -> Node:
    """
    Read a SPN from a JSON file.

    :param f: A file-like object or a filepath of the input JSON file.
    :param leaves: An optional list of custom leaf classes. Useful when dealing with user-defined leaves.
    :return: The loaded SPN, with the ids of its nodes initialized from the file.
    :raises ValueError: If a custom leaf class has the same name as an already registered leaf class.
    """
    # Register the built-in leaf classes by their class name
    leaf_map: Dict[str, Type[Leaf]] = dict()
    for builtin_cls in (Bernoulli, Categorical, Isotonic, Uniform, Gaussian, BinaryCLT):
        leaf_map[builtin_cls.__name__] = builtin_cls

    # Register the user-defined leaf classes, refusing any name clash
    custom_leaves = [] if leaves is None else leaves
    for custom_cls in custom_leaves:
        name = custom_cls.__name__
        if name in leaf_map:
            raise ValueError("Custom leaf class {} already defined".format(name))
        leaf_map[name] = custom_cls

    # Load the graph and rebuild the SPN out of it
    return digraph_to_spn(load_digraph_json(f), leaf_map)
def save_binary_clt_json(clt: BinaryCLT, f: Union[IO, os.PathLike, str]):
    """
    Write a Binary Chow-Liu Tree (CLT) to a JSON file.

    The CLT is first converted to a NetworkX directed graph, which is then serialized.

    :param clt: The binary CLT.
    :param f: A file-like object or a filepath of the output JSON file.
    """
    save_digraph_json(binary_clt_to_digraph(clt), f)
def load_binary_clt_json(f: Union[IO, os.PathLike, str]) -> BinaryCLT:
    """
    Read a Binary Chow-Liu Tree (CLT) from a JSON file.

    :param f: A file-like object or a filepath of the input JSON file.
    :return: The loaded binary CLT.
    """
    return digraph_to_binary_clt(load_digraph_json(f))
def spn_to_digraph(root: Node) -> nx.DiGraph:
    """
    Build the NetworkX directed graph describing a SPN.

    Every SPN node becomes a graph node keyed by its id and annotated with its class name,
    its scope and either its weights (sum nodes) or its parameters (leaves).
    Edges point from a child to its parent and store the position of the child
    among the parent's children in the ``idx`` attribute.
    Floating point values are rounded to 8 decimals.

    :param root: The root node of the SPN.
    :return: The corresponding NetworkX directed graph.
    :raises ValueError: If the SPN structure is not a directed acyclic graph (DAG).
    :raises ValueError: If a node is neither a sum, a product nor a leaf.
    """
    # A missing topological ordering means that the structure is not a DAG
    ordered_nodes = topological_order(root)
    if ordered_nodes is None:
        raise ValueError("SPN structure is not a directed acyclic graph (DAG)")

    graph = nx.DiGraph()

    # First pass: insert all the nodes with their attributes
    for node in ordered_nodes:
        if isinstance(node, Sum):
            attr = {
                'class': Sum.__name__,
                'scope': node.scope,
                'weights': [round(float(w), 8) for w in node.weights]
            }
        elif isinstance(node, Product):
            attr = {'class': Product.__name__, 'scope': node.scope}
        elif isinstance(node, Leaf):
            # Normalize the parameters in place; only existing keys are re-assigned
            params = node.params_dict()
            for key in params:
                value = params[key]
                if isinstance(value, np.ndarray):
                    # Numpy arrays become (nested) lists, rounding the floating point ones
                    if value.dtype in (np.float32, np.float64):
                        params[key] = np.around(value.astype(np.float64), 8).tolist()
                    else:
                        params[key] = value.tolist()
                elif isinstance(value, (np.float32, np.float64, float)):
                    # Numpy and Python floats become rounded Python floats
                    params[key] = round(float(value), 8)
            attr = {'class': node.__class__.__name__, 'scope': node.scope, 'params': params}
        else:
            raise ValueError("Unknown node of type {}".format(node.__class__.__name__))
        graph.add_node(node.id, **attr)

    # Second pass: insert the child -> parent edges, remembering the children ordering
    for parent in ordered_nodes:
        for position, child in enumerate(parent.children):
            graph.add_edge(child.id, parent.id, idx=position)
    return graph
def digraph_to_spn(graph: nx.DiGraph, leaf_map: Dict[str, Type[Leaf]]) -> Node:
    """
    Convert a NetworkX directed graph to a SPN.

    Each graph node must carry a ``class`` and a ``scope`` attribute, plus ``weights`` for sum nodes
    and ``params`` for leaves. Edges must point from a child to its parent and carry an ``idx``
    attribute, i.e. the position of the child among the parent's children.

    :param graph: The NetworkX directed graph.
    :param leaf_map: The leaf distributions mapper dictionary, from class names to leaf classes.
    :return: The corresponding SPN, i.e. its root node. The ids of the nodes are set to the graph node ids.
    :raises ValueError: If the graph is not a directed acyclic graph (DAG).
    :raises ValueError: If the graph does not have exactly one root node.
    :raises ValueError: If a node has an unknown class name.
    """
    # Check the graph
    if not is_directed_acyclic_graph(graph):
        raise ValueError("The graph is not a directed acyclic graph (DAG)")

    # Edges go from children to parents, hence the root is the only node without outgoing edges.
    # Do not assume that the root has id 0: that only holds for some id assignments,
    # and it fails e.g. on a sub-SPN saved without re-assigning the ids.
    root_ids = [node_id for node_id, degree in graph.out_degree() if degree == 0]
    if len(root_ids) != 1:
        raise ValueError("The graph must have exactly one root node, but {} found".format(len(root_ids)))
    root_id = root_ids[0]

    # Instantiate the nodes in the graph
    nodes: Dict[int, Node] = dict()
    for node_id in graph.nodes:
        attr = graph.nodes[node_id]
        name = attr['class']
        scope = attr['scope']
        if name == Sum.__name__:
            node = Sum(scope, weights=attr['weights'])
        elif name == Product.__name__:
            node = Product(scope)
        elif name in leaf_map:
            node = leaf_map[name](scope, **attr['params'])
        else:
            raise ValueError("Unknown node of type {}".format(name))
        node.id = node_id
        nodes[node_id] = node

    # Build the edges between the nodes as parent-children dependencies
    for child_id, parent_id in graph.edges:
        idx = graph.edges[child_id, parent_id]['idx']
        children = nodes[parent_id].children
        # Edges come in no particular order, so grow the children list up to the needed position
        if idx >= len(children):
            children.extend([None] * (idx - len(children) + 1))
        children[idx] = nodes[child_id]

    # Get the root of the SPN
    return nodes[root_id]
def binary_clt_to_digraph(clt: BinaryCLT) -> nx.DiGraph:
    """
    Build the NetworkX directed graph describing a binary Chow-Liu Tree (CLT).

    Every position of the CLT's tree array becomes a graph node annotated with its scope
    and its parameters (rounded to 8 decimals, stored in the ``weight`` attribute).
    Edges point from a parent to its child.

    :param clt: The binary CLT.
    :return: The corresponding NetworkX directed graph.
    :raises ValueError: If the CLT is not initialized.
    """
    if clt.tree is None:
        raise ValueError("The CLT's structure must be already initialized")

    graph = nx.DiGraph()

    # Insert one node per entry of the tree array
    n_nodes = len(clt.tree)
    for idx in range(n_nodes):
        rounded = np.around(clt.params[idx].astype(np.float64), 8)
        graph.add_node(int(idx), scope=clt.scope[idx], weight=rounded.tolist())

    # Insert the parent -> child edges; the root is marked by a parent equal to -1
    for idx, parent in enumerate(clt.tree):
        if parent == -1:
            continue
        graph.add_edge(int(parent), idx)
    return graph
def digraph_to_binary_clt(graph: nx.DiGraph) -> BinaryCLT:
    """
    Build a binary Chow-Liu Tree (CLT) out of a NetworkX directed graph.

    Node ids are used as positions in the scope, tree and parameters lists of the CLT.
    The root is the node without incoming edges, and its parent is set to -1.

    :param graph: The NetworkX directed graph, having ``scope`` and ``weight`` node attributes.
    :return: The corresponding Chow-Liu Tree.
    :raises ValueError: If the graph is not a tree.
    """
    # The graph must be a directed rooted tree
    if not is_arborescence(graph):
        raise ValueError("The graph is not a tree")
    root_id = next(node_id for node_id, degree in graph.in_degree() if degree == 0)

    n_nodes = len(graph)
    scope: list = [None] * n_nodes
    tree: list = [None] * n_nodes
    params: list = [None] * n_nodes

    # Visit the root first (with no parent), then every other node in BFS order
    visits = [(root_id, -1)]
    visits.extend(bfs_predecessors(graph, source=root_id))
    for node_id, parent_id in visits:
        attr = graph.nodes[node_id]
        scope[node_id] = attr['scope']
        tree[node_id] = parent_id
        params[node_id] = attr['weight']

    # Instantiate a Binary CLT
    return BinaryCLT(scope, tree=tree, params=params)
def plot_spn(root: Node, f: Union[IO, os.PathLike, str]):
    """
    Plot a SPN into file.

    Sum nodes are drawn as '+', product nodes as 'x' and leaves are labelled with their scope.
    The edges entering a sum node are labelled with the corresponding weight, rounded to 2 decimals.
    The figure created for the plot is always closed before returning.

    :param root: The SPN root node.
    :param f: A file-like object or a filepath of the output file.
    :raises ValueError: If an unknown node type is found.
    :raises ValueError: If the SPN structure is not a DAG.
    """
    # Convert the SPN to a NetworkX directed graph
    graph = spn_to_digraph(root)

    # Build the dictionaries of node labels and colors
    labels = dict()
    colors = dict()
    for node_id in graph.nodes:
        attr = graph.nodes[node_id]
        name = attr['class']
        if name == Sum.__name__:
            label = '+'
            color = '#083d77'
            # Edges go from children to parents, so the in-edges of a sum node come from its children
            for child_id, _ in graph.in_edges(node_id):
                idx = graph.edges[child_id, node_id]['idx']
                graph.edges[child_id, node_id]['weight'] = round(attr['weights'][idx], ndigits=2)
        elif name == Product.__name__:
            label = 'x'
            color = '#bf3100'
        else:
            label = repr(attr['scope']).replace(',', '')
            color = '#542188'
        labels[node_id] = label
        colors[node_id] = color

    # Compute the nodes positions using PyDot + Graphviz
    pos = nx_pydot.graphviz_layout(graph, prog='dot')
    # Flip the vertical axis of the layout before rescaling it
    pos = {node_id: (x, -y) for node_id, (x, y) in pos.items()}
    pos = rescale_layout_dict(pos)

    # Set the figure size, growing with the size of the graph
    figdim = np.maximum(2, np.sqrt(graph.number_of_nodes() + 2 * graph.number_of_edges()))
    fig = plt.figure(figsize=(figdim, figdim))
    try:
        # Draw the nodes and edges
        nx.draw_networkx(
            graph, pos=pos, node_color=[colors[node_id] for node_id in graph.nodes],
            labels=labels, arrows=True, font_size=8, font_color='#ffffff'
        )
        nx.draw_networkx_edge_labels(
            graph, pos=pos, edge_labels=nx.get_edge_attributes(graph, 'weight'),
            rotate=False, font_size=8, font_color='#000000'
        )

        # Plot the final figure
        plt.tight_layout()
        plt.axis('off')
        plt.savefig(f, bbox_inches='tight', pad_inches=0)
    finally:
        # Figures created via pyplot stay registered until closed: clearing them is not enough,
        # and repeated calls would otherwise accumulate open figures
        plt.close(fig)
def plot_binary_clt(clt: BinaryCLT, f: Union[IO, os.PathLike, str], show_weights: bool = True):
    """
    Plot a binary Chow-Liu Tree (CLT) into file.

    Nodes are labelled with their scope. If requested, every edge is labelled with the conditional
    probability table (CPT) of the node it enters, and the CPT of the root is drawn next to it.
    The figure created for the plot is always closed before returning.

    :param clt: The binary CLT.
    :param f: A file-like object or a filepath of the output file.
    :param show_weights: Whether to show the conditional probability tables (CPTs).
    :raises ValueError: If the CLT is not initialized.
    """
    # Convert the CLT to a NetworkX directed graph
    graph = binary_clt_to_digraph(clt)

    # Build the dictionary of node labels
    labels = {node_id: clt.scope[node_id] for node_id in graph.nodes}

    # Compute the nodes positions using PyDot + Graphviz
    pos = nx_pydot.graphviz_layout(graph, prog='dot')
    pos = rescale_layout_dict(pos)

    # Set the figure size, growing with the size of the graph
    figdim = np.maximum(2, np.sqrt(graph.number_of_nodes() + 3 * graph.number_of_edges()))
    fig = plt.figure(figsize=(figdim, figdim))
    try:
        # Draw the nodes and edges
        nx.draw_networkx(
            graph, pos=pos, node_color='#542188',
            labels=labels, arrows=True, font_size=8, font_color='#ffffff'
        )

        if show_weights:
            # Initialize the edges labels, using the CPTs
            for node_id in graph.nodes:
                attr = graph.nodes[node_id]
                scope = attr['scope']
                weight = attr['weight']
                # Edges go from parents to children, so an in-edge comes from the parent of the node.
                # The root has no in-edges, hence its CPT is handled separately below.
                for parent_id, _ in graph.in_edges(node_id):
                    # The stored parameters are exponentiated to get the displayed values
                    # (presumably they are log-probabilities -- confirm against BinaryCLT)
                    cpt = np.around(np.exp(weight), 2)
                    label = "$P(X_{{{sc}}}|0)$ {:.2f} {:.2f}\n$P(X_{{{sc}}}|1)$ {:.2f} {:.2f}".format(
                        cpt[0, 0], cpt[0, 1], cpt[1, 0], cpt[1, 1], sc=scope
                    )
                    graph.edges[parent_id, node_id]['weight'] = label

            # Initialize the root node label, using the root CPT
            root_id = next(node_id for node_id, c in graph.in_degree() if c == 0)
            attr = graph.nodes[root_id]
            scope = attr['scope']
            weight = attr['weight']
            cpt = np.around(np.exp(weight), 2)
            root_label = "$P(X_{{{}}})$ {:.2f} {:.2f}".format(scope, cpt[0, 0], cpt[0, 1])

            # Draw root node CPT and other nodes CPTs
            cpt_style_kwargs = {
                'font_size': 6, 'font_color': '#000000', 'font_family': 'monospace',
                'bbox': {'boxstyle': 'round', 'ec': '#444444', 'fc': '#ffffff', 'pad': 0.2}
            }
            # Scale the root's vertical coordinate so that its CPT does not overlap the root node itself
            root_label_delta = 1.0 + 0.667 / figdim
            nx.draw_networkx_labels(
                graph, pos={root_id: (pos[root_id][0], pos[root_id][1] * root_label_delta)},
                labels={root_id: root_label}, **cpt_style_kwargs
            )
            nx.draw_networkx_edge_labels(
                graph, pos=pos, edge_labels=nx.get_edge_attributes(graph, 'weight'),
                rotate=False, **cpt_style_kwargs
            )

        # Plot the final figure
        plt.tight_layout()
        plt.axis('off')
        plt.savefig(f, bbox_inches='tight', pad_inches=0)
    finally:
        # Figures created via pyplot stay registered until closed: clearing them is not enough,
        # and repeated calls would otherwise accumulate open figures
        plt.close(fig)