webber.viz
Visualization library for Webber DAGs.
Last updated: Jan 22, 2024 (v0.0.2)
"""
Visualization library for Webber DAGs.

Last updated: Jan 22, 2024 (v0.0.2)
"""
import sys as _sys
import json as _json
import types as _types
import os.path as _path
import typing as _typing
import flask as _flask
import networkx as _nx
import webber.xcoms as _xcoms
import matplotlib.pyplot as _plt
from webber.edges import Condition
from pyvis.network import Network as _Network
from netgraph import InteractiveGraph as _IGraph
# from PyQt6.QtWidgets import QApplication as _QApplication
# from PyQt6.QtWebEngineCore import QWebEnginePage as _QWebEnginePage
# from PyQt6.QtWebEngineWidgets import QWebEngineView as _QWebEngineView

from jinja2 import Environment as _Environment, FileSystemLoader as _FileSystemLoader

__all__ = ["generate_pyvis_network", "visualize_plt", "visualize_browser"]

edge_colors: dict[Condition, str] = {
    Condition.Success: 'grey',
    Condition.AnyCase: 'blue',
    Condition.Failure: 'red',
}


def edge_color(c: Condition) -> str:
    """
    Given a Webber Condition, return corresponding color for edge visualizations.

    Raises:
        KeyError: if ``c`` has no entry in ``edge_colors``.
    """
    return edge_colors[c]


def node_color(c: _typing.Callable) -> str:
    """
    Given a callable, return a color to be used in visualizations
    mapping to the callable's type (lambda, function, built-in, class).

    Classes whose metaclass is exactly ``type`` map to '#71C6B1', lambdas to
    '#679AD1', named Python functions and built-ins to '#DCDCAF', and any
    other callable to '#AADAFB'.
    """
    kind = type(c)
    if kind is type:
        return '#71C6B1'
    if kind is _types.FunctionType:
        # types.LambdaType is an alias of types.FunctionType, so isinstance()
        # cannot distinguish them; lambdas are recognised by their name.
        return '#679AD1' if c.__name__ == '<lambda>' else '#DCDCAF'
    if kind is _types.BuiltinFunctionType:
        return '#DCDCAF'
    return '#AADAFB'


def get_layers(graph: _nx.DiGraph) -> list[list[str]]:
    """
    Generates ordered list of node identifiers given a directed network graph.

    Each inner list is one topological generation: nodes whose predecessors
    all appear in earlier generations. networkx raises NetworkXUnfeasible if
    the graph contains a cycle.
    """
    return [list(generation) for generation in _nx.topological_generations(graph)]


def _promise_name(promise: _xcoms.Promise) -> str:
    """
    Internal only. Display name for a Webber Promise: the part of a string key
    before its first '__', or the ``__name__`` of a non-string key.
    """
    key = promise.key
    if isinstance(key, str):
        return key.split('__')[0]
    return key.__name__


def annotate_node(node: dict) -> str:
    """
    Given a Webber node, construct an annotation to be used in graph visualizations.

    The annotation holds the node's name, the first non-blank line of its
    callable's docstring (if any), the UUID suffix of its identifier, and a
    summary of its positional and keyword arguments. Arguments that cannot be
    JSON-encoded are shown as a Promise placeholder (for Webber promises) or
    as ``Object(<class>)``.
    """
    args: list[str] = []
    for a in node['args']:
        try:
            args.append(_json.dumps(a))
        except (TypeError, ValueError, RecursionError):
            if isinstance(a, _xcoms.Promise):
                args.append(f'Promise({_promise_name(a)})')
            else:
                args.append(f'Object({a.__class__})')

    kwargs: dict[str, _typing.Any] = {}
    for k, v in node['kwargs'].items():
        key = str(k)
        try:
            _json.dumps(v)
        except (TypeError, ValueError, RecursionError):
            if isinstance(v, _xcoms.Promise):
                kwargs[key] = f"Promise('{_promise_name(v)}')"
            else:
                kwargs[key] = f'Object({v.__class__})'
        else:
            # Keep the raw value: the final json.dumps below encodes it once.
            # Pre-encoding keys/values here would double-escape them.
            kwargs[key] = v

    node_title = f"{node['name']}:"

    # Computed outside the f-string: a backslash inside an f-string
    # expression is a SyntaxError before Python 3.12.
    doc_lines = (node['callable'].__doc__ or '').strip().splitlines()
    if doc_lines:
        node_title += f" {doc_lines[0].strip()}"

    node_title += f"\nuuid: {node['id'].split('__')[-1]}"
    if args:
        node_title += f"\nposargs: {', '.join(args)}"
    if kwargs:
        node_title += f"\nkwargs: {_json.dumps(kwargs)}"

    return node_title


def visualize_plt(graph: _nx.DiGraph, interactive=True) -> _IGraph:
    """
    Generates a basic network visualization using the NetGraph library.

    Nodes are placed in a multipartite layout with one layer per topological
    generation, labelled by their callable's ``__name__``, and coloured via
    ``node_color``. Edges are coloured by their ``Condition``.

    When running inside an IPython kernel and ``interactive`` is true,
    matplotlib interactive mode is enabled first.
    """
    if _in_notebook() and interactive:
        _plt.ion()
    _plt.close()
    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=True),
        node_labels={node_id: c.__name__ for node_id, c in graph.nodes.data(data='callable')},
        node_color={node_id: node_color(c) for node_id, c in graph.nodes.data(data='callable')},
        edge_color={(u, v): edge_color(cond) for u, v, cond in graph.edges.data(data='Condition')},
        annotations={node_id: annotate_node(n) for node_id, n in graph.nodes.data(data=True)},
        annotation_fontdict=dict(horizontalalignment='left'),
    )


def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates a basic network for visualization with the Vis.js library.
    Depends on the PyVis and Vis.js libraries, both of which have only
    legacy/minimal community support.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(directed=True, layout='hierarchical')
    network.inherit_edge_colors(False)

    # One pass mapping each node to its topological generation index,
    # instead of scanning every generation for every node.
    levels = {
        node_id: depth
        for depth, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, attrs in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=attrs['name'],
            shape='box',
            title=annotate_node(attrs),
            labelHighlightBold=True,
            color=node_color(attrs['callable']),
            level=levels[node_id],
        )

    for source, dest, condition in graph.edges(data='Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network


def generate_vis_js_script(graph: _nx.DiGraph) -> str:
    """
    Generates script for modeling Vis.js network graphs from a NetworkX DiGraph.
    Conformant to: Vis.js 4.20.1-SNAPSHOT
    """
    network: _Network = generate_pyvis_network(graph)
    network_data = dict(
        zip(["nodes", "edges", "heading", "height", "width", "options"],
            network.get_network_data())
    )

    script_js = "var nodes = new vis.DataSet(" + _json.dumps(network_data['nodes']) + """);\n"""
    script_js += "var edges = new vis.DataSet(" + _json.dumps(network_data['edges']) + """);\n"""
    script_js += """var container = document.getElementById("mynetwork");\n"""
    script_js += """var data = { nodes: nodes, edges: edges, };\n"""
    script_js += """var options = {
    "autoResize": true,
    "configure": {
        "enabled": false
    },
    "edges": {
        "color": {
            "inherit": false
        },
        "smooth": {
            "enabled": false,
        },
        "arrows": {
            "to": true,
            "from": true
        }
    },
    "interaction": {
        "dragNodes": true,
        "hideEdgesOnDrag": false,
        "hideNodesOnDrag": false
    },
    "layout": {
        "hierarchical": {
            "direction": "UD",
            "blockShifting": true,
            "edgeMinimization": false,
            "enabled": true,
            "parentCentralization": true,
            "sortMethod": "hubsize",
        },
        "improvedLayout": true,
        "randomSeed": 0,
    },
    "physics": {
        "enabled": true,
        "stabilization": {
            "enabled": true,
            "fit": true,
            "iterations": 1000,
            "onlyDynamicEdges": false,
            "updateInterval": 50
        }
    }
};\n"""
    script_js += """var network = new vis.Network(container, data, options);\n"""

    return script_js


def generate_vis_html(graph: _nx.DiGraph) -> str:
    """
    Generates HTML wrapper for Vis.js visualization -- used on both browser and GUI.

    Renders the ``vis_gui.html`` Jinja2 template from the ``templates``
    directory next to this module, injecting the Vis.js script.
    """
    script = generate_vis_js_script(graph)
    if not script:
        err_msg = "Empty JavaScript string given for Vis.js visualization."
        raise RuntimeError(err_msg)

    script = """<script type="text/javascript">\n""" + script + """</script>\n"""

    root = _path.dirname(_path.abspath(__file__))
    templates_dir = _path.join(root, 'templates')
    env = _Environment(loader=_FileSystemLoader(templates_dir))
    template = env.get_template("vis_gui.html")

    return template.render(network_script=script)


def visualize_browser(graph: _nx.DiGraph):
    """
    Visualizes network graphs using a Flask app served to a desktop browser.

    Serves the page at http://127.0.0.1:5000/ until the server is stopped.

    Raises:
        RuntimeError: if the current platform is not supported, or if the
            graph has no nodes.
    """
    # 'linux2' is the legacy sys.platform value on old interpreters.
    if _sys.platform not in ('darwin', 'win32', 'linux', 'linux2'):
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host="127.0.0.1", port=5000)

    print('\nVisualization closed.')


def _in_notebook() -> bool:
    """
    Internal only. Helper to default to interactive notebooks when available
    if visualization type is not specified.

    Returns True only when IPython is importable and the active shell's
    config contains an 'IPKernelApp' section.
    """
    try:
        from IPython import get_ipython
        # get_ipython() returns None outside IPython -> AttributeError below.
        if 'IPKernelApp' not in get_ipython().config:
            return False
    except (ImportError, AttributeError):
        return False
    return True

# def visualize_gui(graph: _nx.DiGraph):
#     """
#     Visualizes Network graphs using a desktop GUI generated by the PyQt6 library.
#     """
#     if _sys.platform not in ['darwin', 'win32', 'linux', 'linux2']:
#         err_msg = "Unknown/unsupported operating system for GUI visualizations."
#         raise RuntimeError(err_msg)
#
#     gui_html = generate_vis_html(graph)
#
#     class WebEngineView(_QWebEngineView):
#         """
#         A small Qt-based WebEngineView to generate a GUI using embedded HTML and JavaScript.
#         """
#         def __init__(self, parent=None):
#             super().__init__(parent)
#             self.webpage = _QWebEnginePage()
#             self.setPage(self.webpage)
#             self.webpage.setHtml(gui_html)
#
#     app = _QApplication([])
#     web_engine_view = WebEngineView()
#     web_engine_view.showNormal()
#     app.exec()
def
generate_pyvis_network(graph: networkx.classes.digraph.DiGraph) -> pyvis.network.Network:
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates a basic network for visualization with the Vis.js library.
    Depends on the PyVis and Vis.js libraries, both of which have only
    legacy/minimal community support.

    Each node is drawn as a box labelled with its ``name`` attribute, coloured
    via ``node_color``, and placed on the hierarchical level equal to the index
    of its topological generation. Each edge is coloured by its ``Condition``.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(directed=True, layout='hierarchical')
    network.inherit_edge_colors(False)

    # One pass mapping each node to its topological generation index,
    # instead of scanning every generation for every node.
    levels = {
        node_id: depth
        for depth, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, attrs in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=attrs['name'],
            shape='box',
            title=annotate_node(attrs),
            labelHighlightBold=True,
            color=node_color(attrs['callable']),
            level=levels[node_id],
        )

    for source, dest, condition in graph.edges(data='Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network
Generates a basic network for visualization with the Vis.js library. Depends on the PyVis and Vis.js libraries, both of which have only legacy/minimal community support.
def
visualize_plt( graph: networkx.classes.digraph.DiGraph, interactive=True) -> netgraph._main.InteractiveGraph:
def visualize_plt(graph: _nx.DiGraph, interactive=True) -> _IGraph:
    """
    Generates a basic network visualization using the NetGraph library.

    Nodes sit in a multipartite layout with one layer per topological
    generation. Each node is labelled with its callable's ``__name__``,
    coloured via ``node_color``, and annotated via ``annotate_node``. Edges are
    coloured by their ``Condition``.

    Inside an IPython kernel with ``interactive`` true, matplotlib interactive
    mode is switched on first. The current matplotlib figure is then closed
    before the graph is built.
    """
    if _in_notebook() and interactive:
        _plt.ion()
    _plt.close()

    layout_options = dict(layers=get_layers(graph), reduce_edge_crossings=True)
    callables = graph.nodes.data(data='callable')
    labels = {node_id: func.__name__ for node_id, func in callables}
    fills = {node_id: node_color(func) for node_id, func in callables}
    strokes = {
        (src, dst): edge_color(cond)
        for src, dst, cond in graph.edges.data(data='Condition')
    }
    notes = {node_id: annotate_node(attrs) for node_id, attrs in graph.nodes.data(data=True)}

    return _IGraph(
        graph,
        arrows=True,
        node_shape='o',
        node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=layout_options,
        node_labels=labels,
        node_color=fills,
        edge_color=strokes,
        annotations=notes,
        annotation_fontdict=dict(horizontalalignment='left'),
    )
Generates a basic network visualization using the NetGraph library.
def
visualize_browser(graph: networkx.classes.digraph.DiGraph):
def visualize_browser(graph: _nx.DiGraph):
    """
    Visualizes network graphs using a Flask app served to a desktop browser.

    Builds the Vis.js page with ``generate_vis_html`` and serves it at
    http://127.0.0.1:5000/ with a Flask development server, printing a message
    before serving and after the server returns.

    Raises:
        RuntimeError: if the current platform is not supported, or if the
            graph has no nodes (raised while generating the page).
    """
    # 'linux2' is the legacy sys.platform value reported by old interpreters;
    # the previous 'linxu2' entry was a typo that could never match.
    if _sys.platform not in ('darwin', 'win32', 'linux', 'linux2'):
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host="127.0.0.1", port=5000)

    print('\nVisualization closed.')
Visualizes network graphs using a Flask app served to a desktop browser.