webber.viz
Visualization library for Webber DAGs.
"""
Visualization library for Webber DAGs.
"""
import sys as _sys
import json as _json
import types as _types
import os.path as _path
import typing as _typing
import flask as _flask
import networkx as _nx
import webber.xcoms as _xcoms
from webber.edges import Condition
from jinja2 import Environment as _Environment, FileSystemLoader as _FileSystemLoader

# Wait for preload thread to finish to avoid circular imports with matplotlib/IPython.
# Uses a non-blocking check first; only blocks if the preload is still running.
import webber as _webber
if not _webber._viz_ready.is_set():
    _webber.wait_for_viz_ready(timeout=3.0)

import matplotlib.pyplot as _plt
from pyvis.network import Network as _Network
from netgraph import InteractiveGraph as _IGraph

__all__ = ["generate_pyvis_network", "visualize_plt", "visualize_browser"]

# Edge color for each Webber dependency condition.
edge_colors: _typing.Dict[Condition, str] = {
    Condition.Success: 'grey',
    Condition.AnyCase: 'blue',
    Condition.Failure: 'red'
}


def edge_color(c: Condition) -> str:
    """
    Given a Webber Condition, return corresponding color for edge visualizations.

    Raises:
        KeyError: if `c` has no entry in `edge_colors`.
    """
    return edge_colors[c]


def node_color(c: _typing.Callable) -> str:
    """
    Given a callable, return the color used for it in visualizations,
    based on the callable's kind:

    - classes: '#71C6B1'
    - lambdas: '#679AD1'
    - other Python functions and built-in functions: '#DCDCAF'
    - anything else (e.g. partials, callable instances): '#AADAFB'
    """
    if isinstance(c, type):
        return '#71C6B1'
    if isinstance(c, _types.FunctionType):
        # types.LambdaType is an alias of types.FunctionType, so an isinstance
        # check cannot tell lambdas apart; their __name__ is always '<lambda>'.
        return '#679AD1' if c.__name__ == '<lambda>' else '#DCDCAF'
    if isinstance(c, _types.BuiltinFunctionType):
        return '#DCDCAF'
    return '#AADAFB'


def get_layers(graph: _nx.DiGraph) -> _typing.List[_typing.List[str]]:
    """
    Group the nodes of a directed acyclic graph into topological layers.

    Layer 0 holds the nodes without predecessors; every node's ancestors lie
    in earlier layers (see networkx.topological_generations).

    Raises:
        networkx.NetworkXUnfeasible: if the graph contains a cycle.
    """
    return list(_nx.topological_generations(graph))


def _describe_value(value: _typing.Any) -> str:
    """Render one task argument as a short, human-readable annotation string."""
    try:
        return _json.dumps(value)
    except (TypeError, ValueError, RecursionError):
        # Not JSON-serializable (or circular / too deeply nested): fall back
        # to a descriptive placeholder below.
        pass
    if isinstance(value, _xcoms.Promise):
        key = value.key
        name = key.split('__')[0] if isinstance(key, str) else key.__name__
        return f'Promise({name})'
    return f'Object({value.__class__})'


def annotate_node(node: _typing.Dict[str, _typing.Any]) -> str:
    """
    Given a Webber node, construct an annotation to be used in graph visualizations.

    The annotation is a multi-line string:
      - "<name>: <first line of the callable's docstring>" (the summary is
        omitted when the callable has no docstring),
      - "uuid: <suffix of the node id after its last '__'>",
      - "posargs: ..." when the node has positional arguments,
      - "kwargs: key=value, ..." when the node has keyword arguments.

    Arguments are rendered as JSON when possible, as Promise(<name>) for
    Webber promises, and as Object(<class>) otherwise.
    """
    args = [_describe_value(a) for a in node['args']]
    kwargs = {str(k): _describe_value(v) for k, v in node['kwargs'].items()}

    # Compute the summary outside the f-string: a backslash inside an f-string
    # expression is a SyntaxError before Python 3.12. Stripping first also
    # handles docstrings that start with a newline.
    doc = getattr(node['callable'], '__doc__', None)
    doc = doc.strip() if isinstance(doc, str) else ''
    summary = doc.splitlines()[0] if doc else ''

    title = f"{node['name']}:"
    if summary:
        title += f" {summary}"

    lines = [title, f"uuid: {node['id'].split('__')[-1]}"]
    if args:
        lines.append(f"posargs: {', '.join(args)}")
    if kwargs:
        lines.append("kwargs: " + ", ".join(f"{k}={v}" for k, v in kwargs.items()))

    return "\n".join(lines)


def visualize_plt(
    graph: _nx.DiGraph,
    interactive: bool = True,
    optimize_layout: bool = True
) -> _IGraph:
    """
    Generates a basic network visualization using the NetGraph library.

    Nodes are laid out in multipartite layers given by get_layers(), labelled
    with their callable's __name__ (falling back to the node's 'name'
    attribute when the callable has none), colored by node_color(), and
    annotated with annotate_node(). Edges are colored by their 'Condition'
    attribute via edge_color().

    Args:
        graph: NetworkX DiGraph to visualize
        interactive: If True, enables interactive mode in notebooks.
            Ignored when using a non-interactive backend (Agg, PDF, SVG, ...).
        optimize_layout: If True, reduces edge crossings (slower but prettier).
            Set to False for faster rendering on large graphs.

    Returns:
        The netgraph InteractiveGraph instance.
    """
    import matplotlib

    # Matplotlib's built-in non-interactive (file-output) backends.
    non_interactive_backends = ('agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template')
    is_interactive_backend = matplotlib.get_backend().lower() not in non_interactive_backends

    if _in_notebook() and interactive and is_interactive_backend:
        _plt.ion()
    # NOTE(review): presumably closes any current figure so the graph is not
    # drawn onto a stale one -- confirm.
    _plt.close()

    node_data = dict(graph.nodes.data(data=True))
    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=optimize_layout),
        node_labels={
            node_id: getattr(data['callable'], '__name__', data['name'])
            for node_id, data in node_data.items()
        },
        node_color={node_id: node_color(data['callable']) for node_id, data in node_data.items()},
        edge_color={
            (source, dest): edge_color(condition)
            for source, dest, condition in graph.edges.data(data='Condition')
        },
        annotations={node_id: annotate_node(data) for node_id, data in node_data.items()},
        annotation_fontdict=dict(horizontalalignment='left')
    )


def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Build a PyVis network (rendered by Vis.js) from a Webber DAG.

    Each node becomes a box labelled with its 'name' attribute, titled with
    annotate_node(), colored by node_color(), and placed on the hierarchical
    level of its topological generation. Each edge is colored by its
    'Condition' attribute via edge_color().

    Depends on the PyVis and Vis.js libraries -- both have only
    legacy/minimal community support.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # Map every node to its generation index in a single pass, instead of
    # scanning all generations for each node (quadratic on large DAGs).
    levels = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, node in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=levels[node_id]
        )

    for source, dest, condition in graph.edges.data('Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network


def generate_vis_js_script(graph: _nx.DiGraph) -> str:
    """
    Generates script for modeling Vis.js network graphs from a NetworkX DiGraph.
    Conformant to: Vis.js 4.20.1-SNAPSHOT

    Node and edge data come from generate_pyvis_network(); the Vis.js options
    are fixed below. The script renders into the element with id "mynetwork".
    """
    network: _Network = generate_pyvis_network(graph)
    network_data = dict(
        zip(["nodes", "edges", "heading", "height", "width", "options"],
            network.get_network_data())
    )

    def _to_js(data: _typing.Any) -> str:
        # json.dumps does not escape "</", so a node title containing
        # "</script>" would close the surrounding <script> tag early.
        # "<\/" is an equivalent escape inside JSON/JS string literals.
        return _json.dumps(data).replace("</", "<\\/")

    script_js = "var nodes = new vis.DataSet(" + _to_js(network_data['nodes']) + ");\n"
    script_js += "var edges = new vis.DataSet(" + _to_js(network_data['edges']) + ");\n"
    script_js += """var container = document.getElementById("mynetwork");\n"""
    script_js += """var data = { nodes: nodes, edges: edges, };\n"""
    # NOTE(review): arrows are drawn at both ends ("to" and "from") of each
    # edge; confirm this is intended for a directed graph.
    script_js += """var options = {
    "autoResize": true,
    "configure": {
        "enabled": false
    },
    "edges": {
        "color": {
            "inherit": false
        },
        "smooth": {
            "enabled": false,
        },
        "arrows": {
            "to": true,
            "from": true
        }
    },
    "interaction": {
        "dragNodes": true,
        "hideEdgesOnDrag": false,
        "hideNodesOnDrag": false
    },
    "layout": {
        "hierarchical": {
            "direction": "UD",
            "blockShifting": true,
            "edgeMinimization": false,
            "enabled": true,
            "parentCentralization": true,
            "sortMethod": "hubsize",
        },
        "improvedLayout": true,
        "randomSeed": 0,
    },
    "physics": {
        "enabled": true,
        "stabilization": {
            "enabled": true,
            "fit": true,
            "iterations": 1000,
            "onlyDynamicEdges": false,
            "updateInterval": 50
        }
    }
};\n"""
    script_js += """var network = new vis.Network(container, data, options);\n"""

    return script_js


def generate_vis_html(graph: _nx.DiGraph) -> str:
    """
    Generates HTML wrapper for Vis.js visualization -- used on both browser and GUI.

    Renders the "vis_gui.html" Jinja template from the package's "templates"
    directory, passing the generated <script> element as `network_script`.

    Raises:
        RuntimeError: if the generated script is empty.
    """
    script = generate_vis_js_script(graph)
    if not script:
        err_msg = "Empty JavaScript string given for Vis.js visualization."
        raise RuntimeError(err_msg)

    script = """<script type="text/javascript">\n""" + script + """</script>\n"""

    root = _path.dirname(_path.abspath(__file__))
    templates_dir = _path.join(root, 'templates')
    env = _Environment(loader=_FileSystemLoader(templates_dir))
    template = env.get_template("vis_gui.html")

    return template.render(
        network_script=script,
    )


def visualize_browser(graph: _nx.DiGraph):
    """
    Serve an interactive Vis.js rendering of the graph from a local Flask app.

    The page is served at http://127.0.0.1:5000/; app.run() blocks until the
    server is stopped, after which a closing message is printed.

    Raises:
        RuntimeError: if the platform is not macOS ('darwin'), Windows
            ('win32') or Linux.
    """
    supported_platforms = ('darwin', 'win32', 'linux', 'linux2')
    if _sys.platform not in supported_platforms:
        raise RuntimeError("Unknown/unsupported operating system for GUI visualizations.")

    page = generate_vis_html(graph)

    app = _flask.Flask(__name__)

    def index():
        # The page is rendered once up front; every request gets the same HTML.
        return page

    app.add_url_rule("/", "index", index)

    print('Serving visualization...\n')
    app.run(host="127.0.0.1", port=5000)
    print('\nVisualization closed.')


def _in_notebook() -> bool:
    """
    Internal only. Helper to default to interactive notebooks when available
    if visualization type is not specified.

    Returns True only when IPython is importable, a shell is active, and its
    config contains 'IPKernelApp' (i.e. running under an IPython kernel).
    """
    try:
        from IPython.core.getipython import get_ipython
    except ImportError:
        return False
    shell = get_ipython()  # None when not running inside IPython
    if shell is None:
        return False
    try:
        return 'IPKernelApp' in shell.config
    except AttributeError:
        return False
def
generate_pyvis_network(graph: networkx.classes.digraph.DiGraph) -> pyvis.network.Network:
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Build a PyVis network (rendered by Vis.js) from a Webber DAG.

    Each node becomes a box labelled with its 'name' attribute, titled with
    annotate_node(), colored by node_color(), and placed on the hierarchical
    level of its topological generation. Each edge is colored by its
    'Condition' attribute via edge_color().

    Depends on the PyVis and Vis.js libraries -- both have only
    legacy/minimal community support.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # Map every node to its generation index in a single pass, instead of
    # scanning all generations for each node (quadratic on large DAGs).
    levels = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, node in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=levels[node_id]
        )

    for source, dest, condition in graph.edges.data('Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network
Generates a basic network for visualization with the Vis.js library. Depends on the PyVis and Vis.js libraries, both of which have only legacy/minimal community support.
def
visualize_plt( graph: networkx.classes.digraph.DiGraph, interactive: bool = True, optimize_layout: bool = True) -> netgraph._main.InteractiveGraph:
def visualize_plt(
    graph: _nx.DiGraph,
    interactive: bool = True,
    optimize_layout: bool = True
) -> _IGraph:
    """
    Generates a basic network visualization using the NetGraph library.

    Nodes are laid out in multipartite layers given by get_layers(), labelled
    with their callable's __name__ (falling back to the node's 'name'
    attribute when the callable has none), colored by node_color(), and
    annotated with annotate_node(). Edges are colored by their 'Condition'
    attribute via edge_color().

    Args:
        graph: NetworkX DiGraph to visualize
        interactive: If True, enables interactive mode in notebooks.
            Ignored when using a non-interactive backend (Agg, PDF, SVG, ...).
        optimize_layout: If True, reduces edge crossings (slower but prettier).
            Set to False for faster rendering on large graphs.

    Returns:
        The netgraph InteractiveGraph instance.
    """
    import matplotlib

    # Matplotlib's built-in non-interactive (file-output) backends.
    non_interactive_backends = ('agg', 'cairo', 'pdf', 'pgf', 'ps', 'svg', 'template')
    is_interactive_backend = matplotlib.get_backend().lower() not in non_interactive_backends

    if _in_notebook() and interactive and is_interactive_backend:
        _plt.ion()
    # NOTE(review): presumably closes any current figure so the graph is not
    # drawn onto a stale one -- confirm.
    _plt.close()

    node_data = dict(graph.nodes.data(data=True))
    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=optimize_layout),
        node_labels={
            node_id: getattr(data['callable'], '__name__', data['name'])
            for node_id, data in node_data.items()
        },
        node_color={node_id: node_color(data['callable']) for node_id, data in node_data.items()},
        edge_color={
            (source, dest): edge_color(condition)
            for source, dest, condition in graph.edges.data(data='Condition')
        },
        annotations={node_id: annotate_node(data) for node_id, data in node_data.items()},
        annotation_fontdict=dict(horizontalalignment='left')
    )
Generates a basic network visualization using the NetGraph library.
Args:
- graph: NetworkX DiGraph to visualize.
- interactive: If True, enables interactive mode in notebooks. Ignored when using a non-interactive backend (e.g. Agg).
- optimize_layout: If True, reduces edge crossings (slower but prettier). Set to False for faster rendering on large graphs.
def
visualize_browser(graph: networkx.classes.digraph.DiGraph):
def visualize_browser(graph: _nx.DiGraph):
    """
    Serve an interactive Vis.js rendering of the graph from a local Flask app.

    The page is served at http://127.0.0.1:5000/; app.run() blocks until the
    server is stopped, after which a closing message is printed.

    Raises:
        RuntimeError: if the platform is not macOS ('darwin'), Windows
            ('win32') or Linux.
    """
    supported_platforms = ('darwin', 'win32', 'linux', 'linux2')
    if _sys.platform not in supported_platforms:
        raise RuntimeError("Unknown/unsupported operating system for GUI visualizations.")

    page = generate_vis_html(graph)

    app = _flask.Flask(__name__)

    def index():
        # The page is rendered once up front; every request gets the same HTML.
        return page

    app.add_url_rule("/", "index", index)

    print('Serving visualization...\n')
    app.run(host="127.0.0.1", port=5000)
    print('\nVisualization closed.')
Visualizes Network graphs using a Flask app served to a desktop browser.