Skip to content
Snippets Groups Projects
Commit 526e1c91 authored by crides's avatar crides
Browse files

add(AnalysisTree): `visualize_dot`

- Adds `graphviz` to dependencies
- Uses the `graphviz` package for visualizing the AnalysisTree (needs
  the `dot` etc. commands installed, usually available from the
  `graphviz` on most Linuxes)
parent b0d339df
No related branches found
No related tags found
No related merge requests found
...@@ -16,4 +16,5 @@ torch~=1.12.1 ...@@ -16,4 +16,5 @@ torch~=1.12.1
tqdm~=4.64.1 tqdm~=4.64.1
z3-solver~=4.8.17.0 z3-solver~=4.8.17.0
treelib~=1.6.1 treelib~=1.6.1
portion~=2.3.1 portion~=2.3.1
\ No newline at end of file graphviz~=0.20
...@@ -48,6 +48,7 @@ setup( ...@@ -48,6 +48,7 @@ setup(
"z3-solver~=4.8.17.0", "z3-solver~=4.8.17.0",
"treelib~=1.6.1", "treelib~=1.6.1",
"portion~=2.3.1", "portion~=2.3.1",
"graphviz~=0.20",
], ],
classifiers=[ classifiers=[
'Development Status :: 2 - Pre-Alpha', 'Development Status :: 2 - Pre-Alpha',
......
from functools import reduce from functools import reduce
import pickle import pickle
from typing import Iterable, List, Dict, Any, Optional, Tuple, TypeVar from typing import Iterable, List, Dict, Any, Optional, Tuple, TypeVar, Literal
import json import json
from treelib import Tree from treelib import Tree
import numpy.typing as nptyp, numpy as np, portion import numpy.typing as nptyp, numpy as np, portion
...@@ -8,7 +8,8 @@ import numpy.typing as nptyp, numpy as np, portion ...@@ -8,7 +8,8 @@ import numpy.typing as nptyp, numpy as np, portion
from verse.analysis.dryvr import _EPSILON from verse.analysis.dryvr import _EPSILON
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib as mpl, matplotlib.pyplot as plt
import graphviz
TraceType = nptyp.NDArray[np.float_] TraceType = nptyp.NDArray[np.float_]
...@@ -113,6 +114,9 @@ class AnalysisTreeNode: ...@@ -113,6 +114,9 @@ class AnalysisTreeNode:
type = data['type'], type = data['type'],
) )
def color_interp(c1: str, c2: str, mix: float) -> str:
return mpl.colors.to_hex((1 - mix) * np.array(mpl.colors.to_rgb(c1)) + mix * np.array(mpl.colors.to_rgb(c2)))
class AnalysisTree: class AnalysisTree:
def __init__(self, root): def __init__(self, root):
self.root:AnalysisTreeNode = root self.root:AnalysisTreeNode = root
...@@ -181,6 +185,7 @@ class AnalysisTree: ...@@ -181,6 +185,7 @@ class AnalysisTree:
nid = AnalysisTree._dump_tree(child, tree, id, nid) nid = AnalysisTree._dump_tree(child, tree, id, nid)
return nid + 1 return nid + 1
# TODO Generalize to different timesteps
def contains(self, other: "AnalysisTree", strict: bool = True, tol: Optional[float] = None) -> bool: def contains(self, other: "AnalysisTree", strict: bool = True, tol: Optional[float] = None) -> bool:
""" """
Returns, for reachability, whether the current tree fully contains the other tree or not; Returns, for reachability, whether the current tree fully contains the other tree or not;
...@@ -243,7 +248,16 @@ class AnalysisTree: ...@@ -243,7 +248,16 @@ class AnalysisTree:
# bloat and containment # bloat and containment
return all(other_tree[aid][i][j] in this_tree[aid][i][j].apply(lambda x: x.replace(lower=lambda v: v - tol, upper=lambda v: v + tol)) for aid in other_agents for i in range(total_len) for j in range(cont_num)) return all(other_tree[aid][i][j] in this_tree[aid][i][j].apply(lambda x: x.replace(lower=lambda v: v - tol, upper=lambda v: v + tol)) for aid in other_agents for i in range(total_len) for j in range(cont_num))
@staticmethod
def _get_len(node: AnalysisTreeNode, lens: Dict[int, int]) -> int:
res = len(next(iter(node.trace.values()))) + (0 if len(node.child) == 0 else max(AnalysisTree._get_len(c, lens) for c in node.child))
lens[node.id] = res
return res
def visualize(self): def visualize(self):
lens = {}
total_len = AnalysisTree._get_len(self.root, lens)
gradient = lambda id: color_interp("red", "blue", lens[id] / total_len)
G = nx.Graph() G = nx.Graph()
for node in self.nodes: for node in self.nodes:
G.add_node(node.id,time=(node.id,node.start_time)) G.add_node(node.id,time=(node.id,node.start_time))
...@@ -251,9 +265,33 @@ class AnalysisTree: ...@@ -251,9 +265,33 @@ class AnalysisTree:
G.add_node(child.id,time=(child.id,child.start_time)) G.add_node(child.id,time=(child.id,child.start_time))
G.add_edge(node.id, child.id) G.add_edge(node.id, child.id)
labels = nx.get_node_attributes(G, 'time') labels = nx.get_node_attributes(G, 'time')
nx.draw_planar(G,labels=labels) colors = [gradient(id) for id in G]
nx.draw_planar(G, node_color=colors, labels=labels)
plt.show() plt.show()
def visualize_dot(self, filename: str, otype: Literal["png", "svg", "pdf", "jpg"] = "png", font: Optional[str] = None):
"""
`filename` is the prefix, i.e. doesn't include extensions. `filename.dot` will be saved as well as `filename.png`
"""
def diff(a: AnalysisTreeNode, b: AnalysisTreeNode) -> List[str]:
return [aid for aid in a.agent if a.mode[aid] != b.mode[aid]]
lens = {}
total_len = AnalysisTree._get_len(self.root, lens)
gradient = lambda id: color_interp("red", "blue", lens[id] / total_len)
graph = graphviz.Digraph()
for node in self.nodes:
tooltip = "\n".join(f"{aid}: {[*node.mode[aid], *node.init[aid]]}" for aid in node.agent)
graph.node(str(node.id), label=str(node.id), color=gradient(node.id), tooltip=tooltip)
for c in node.child:
d = diff(node, c)
tooltip = "\n".join(f"{aid}: {node.mode[aid]} -> {c.mode[aid]}" for aid in d)
graph.edge(str(node.id), str(c.id), label=", ".join(d), tooltip=tooltip)
if font != None:
graph.node_attr.update(fontname=font)
graph.edge_attr.update(fontname=font)
graph.graph_attr.update(fontname=font)
graph.render(filename + ".dot", format=otype, outfile=filename + "." + otype, engine="twopi")
def is_equal(self, other:"AnalysisTree"): def is_equal(self, other:"AnalysisTree"):
return self.contains(other) and other.contains(self) return self.contains(other) and other.contains(self)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment