webber.viz

Visualization library for Webber DAGs.

Last updated: Jan 22, 2024 (v0.0.2)

  1"""
  2Visualization library for Webber DAGs.
  3
  4Last updated by: Jan 22, 2024 (v0.0.2)
  5"""
  6import sys as _sys
  7import json as _json
  8import types as _types
  9import os.path as _path
 10import typing as _typing
 11import flask as _flask
 12import networkx as _nx
 13import webber.xcoms as _xcoms
 14import matplotlib.pyplot as _plt
 15from webber.edges import Condition
 16from pyvis.network import Network as _Network
 17from netgraph import InteractiveGraph as _IGraph
 18# from PyQt6.QtWidgets import QApplication as _QApplication
 19# from PyQt6.QtWebEngineCore import QWebEnginePage as _QWebEnginePage
 20# from PyQt6.QtWebEngineWidgets import QWebEngineView as _QWebEngineView
 21
 22from jinja2 import Environment as _Environment, FileSystemLoader as _FileSystemLoader
 23
 24__all__ = ["generate_pyvis_network", "visualize_plt", "visualize_browser"]
 25
 26edge_colors: dict[Condition, str] = {
 27    Condition.Success: 'grey',
 28    Condition.AnyCase: 'blue',
 29    Condition.Failure: 'red'
 30}
 31
 32def edge_color(c: Condition):
 33    """
 34    Given a Webber Condition, return corresponding color for edge visualizations. 
 35    """
 36    return edge_colors[c]
 37
 38def node_color(c: _typing.Callable):
 39    """
 40    Given a callable, return a color that to be used in visualizations
 41    mapping to the callable's type (lambda, function, built-in, class).
 42    """
 43    _class = str(c.__class__).strip("<class '").rstrip("'>")
 44    match _class:
 45        case 'type':
 46            return '#71C6B1'
 47        case 'function':
 48            return '#679AD1' if isinstance(c, _types.LambdaType) else '#DCDCAF'
 49        case 'builtin_function_or_method':
 50            return '#DCDCAF'
 51    return '#AADAFB'
 52
 53def get_layers(graph: _nx.DiGraph) -> list[list[str]]:
 54    """
 55    Generates ordered list of node identifiers given a directed network graph.
 56    """
 57    layers = []
 58    for nodes in _nx.topological_generations(graph):
 59        layers.append(nodes)
 60    return layers
 61
 62def annotate_node(node: dict):
 63    """
 64    Given a Webber node, construct an annotation to be used in graph visualizations.
 65    """
 66    args, kwargs = [], {}
 67    for a in node['args']:
 68        try:
 69            args.append(_json.dumps(a))
 70        except:
 71            if isinstance(a, _xcoms.Promise):
 72                if isinstance(a.key, str):
 73                    name = a.key.split('__')[0]
 74                else:
 75                    name = a.key.__name__
 76                args.append(f'Promise({name})')
 77            else:
 78                args.append(f'Object({str(a.__class__)})')
 79    for k,v in node['kwargs'].items():
 80        try:
 81            _json.dumps(k)
 82            try:
 83                kwargs[_json.dumps(k)] = _json.dumps(v)
 84            except:
 85                if isinstance(v, _xcoms.Promise):
 86                    if isinstance(v.key, str):
 87                        name = v.key.split('__')[0]
 88                    else:
 89                        name = v.key.__name__
 90                    kwargs[k] = f"Promise('{name}')"
 91                else:
 92                    kwargs[k] = f'Object({str(v.__class__)})'
 93        except:
 94            pass
 95    node_title  = f"{node['name']}:"
 96
 97    try:
 98        node_title += f" {node['callable'].__doc__.split('\n')[0]}"
 99    except:
100        pass
101
102    node_title += f"\nuuid:    {node['id'].split('__')[-1]}"
103    node_title += f"\nposargs: {', '.join(args)}" if args else ""
104    node_title += f"\nkwargs:  {_json.dumps(kwargs)}" if kwargs else ""
105
106    return node_title
107
def visualize_plt(graph: _nx.DiGraph, interactive: bool = True) -> _IGraph:
    """
    Generates basic network for visualization using the NetGraph library.

    Nodes are placed in topological layers (``get_layers``), labelled with
    their callable's ``__name__`` (falling back to the node's 'name' attribute
    for callables without one, e.g. partials or callable instances), colored
    with ``node_color`` and annotated with ``annotate_node``. Edges are colored
    from their 'Condition' attribute via ``edge_color``.

    When running inside an IPython kernel and ``interactive`` is true,
    matplotlib interactive mode is switched on and the current figure is
    closed before the graph is drawn.
    """
    if _in_notebook() and interactive:
        _plt.ion()
        _plt.close()

    labels, colors, annotations = {}, {}, {}
    for node_id, attrs in graph.nodes(data=True):
        func = attrs.get('callable')
        labels[node_id] = getattr(func, '__name__', None) or attrs.get('name', str(node_id))
        colors[node_id] = node_color(func)
        annotations[node_id] = annotate_node(attrs)

    edge_colors_by_edge = {
        (source, dest): edge_color(condition)
        for source, dest, condition in graph.edges.data('Condition')
    }

    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=True),
        node_labels=labels,
        node_color=colors,
        edge_color=edge_colors_by_edge,
        annotations=annotations,
        annotation_fontdict=dict(horizontalalignment='left')
    )
125
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates basic network for visualization in Vis.js library.
    Depends on PyVis, Vis.js modules/libraries -- both are under legacy/minimal community support.

    Each node becomes a box labelled with its 'name' attribute, titled with
    ``annotate_node``'s text, colored with ``node_color`` and placed on the
    hierarchical level equal to the index of its topological generation.
    Each edge is colored from its 'Condition' attribute via ``edge_color``.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # One pass mapping node -> generation index; scanning every generation for
    # every node would be O(n^2), and no ordering of node ids is required.
    level_of = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, node in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=level_of[node_id],
        )

    for source, dest, condition in graph.edges.data('Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network
161
162
def _json_for_script(value) -> str:
    """
    Internal only. Serialize ``value`` as JSON that is safe to inline inside an
    HTML ``<script>`` element: '<', '>' and '&' are emitted as \\u escapes, so
    data containing e.g. ``</script>`` cannot terminate the element early.
    In JSON these characters can only occur inside strings, where the \\u
    escapes decode back to the same characters.
    """
    return (
        _json.dumps(value)
        .replace('<', '\\u003c')
        .replace('>', '\\u003e')
        .replace('&', '\\u0026')
    )


def generate_vis_js_script(graph: _nx.DiGraph) -> str:
    """
    Generates script for modeling Vis.js network graphs from a NetworkX DiGraph.
    Conformant to: Vis.js 4.20.1-SNAPSHOT

    The script expects a DOM element with id "mynetwork" and a global ``vis``
    object, and assigns the globals ``nodes``, ``edges``, ``container``,
    ``data``, ``options`` and ``network``.

    Raises:
        RuntimeError: propagated from ``generate_pyvis_network`` for empty graphs.
    """
    network: _Network = generate_pyvis_network(graph)
    # get_network_data() returns (nodes, edges, heading, height, width, options);
    # only nodes and edges are used here.
    nodes, edges, *_ = network.get_network_data()

    # Node titles embed user data (args, docstrings), so escape for inline <script>.
    script_js = "var nodes = new vis.DataSet(" + _json_for_script(nodes) + """);\n"""
    script_js += "var edges = new vis.DataSet(" + _json_for_script(edges) + """);\n"""
    script_js += """var container = document.getElementById("mynetwork");\n"""
    script_js += """var data = { nodes: nodes, edges: edges, };\n"""
    # NOTE(review): "from": true also draws an arrowhead at each edge's source,
    # so edges render as bidirectional; confirm this is intended for a DAG.
    script_js += """var options = {
                    "autoResize": true,
                    "configure": {
                        "enabled": false
                    },
                    "edges": {
                        "color": {
                            "inherit": false
                        },
                        "smooth": {
                            "enabled": false,
                        },
                        "arrows": {
                            "to": true,
                            "from": true
                        }
                    },
                    "interaction": {
                        "dragNodes": true,
                        "hideEdgesOnDrag": false,
                        "hideNodesOnDrag": false
                    },
                    "layout": {
                        "hierarchical": {
                            "direction": "UD",
                            "blockShifting": true,
                            "edgeMinimization": false,
                            "enabled": true,
                            "parentCentralization": true,
                            "sortMethod": "hubsize",
                        },
                        "improvedLayout": true,
                        "randomSeed": 0,
                    },
                    "physics": {
                        "enabled": true,
                        "stabilization": {
                            "enabled": true,
                            "fit": true,
                            "iterations": 1000,
                            "onlyDynamicEdges": false,
                            "updateInterval": 50
                        }
                    }
                };\n"""
    script_js += """var network = new vis.Network(container, data, options);\n"""

    return script_js
226
227
def generate_vis_html(graph: _nx.DiGraph) -> str:
    """
    Render the HTML page wrapping the Vis.js visualization of ``graph`` --
    used on both browser and GUI.

    The generated script is wrapped in a <script> element and passed as
    ``network_script`` to the ``vis_gui.html`` Jinja2 template, loaded from the
    ``templates`` directory next to this module.

    Raises:
        RuntimeError: if the generated JavaScript is empty.
    """
    js_code = generate_vis_js_script(graph)
    if not js_code:
        raise RuntimeError("Empty JavaScript string given for Vis.js visualization.")

    script_tag = '<script type="text/javascript">\n' + js_code + '</script>\n'

    module_dir = _path.dirname(_path.abspath(__file__))
    jinja_env = _Environment(loader=_FileSystemLoader(_path.join(module_dir, 'templates')))
    page = jinja_env.get_template("vis_gui.html")
    return page.render(network_script=script_tag)
252
253
def visualize_browser(graph: _nx.DiGraph):
    """
    Visualizes Network graphs using a Flask app served to a desktop browser.

    Renders ``graph`` with ``generate_vis_html`` and serves the page at
    http://127.0.0.1:5000/ using Flask's development server; the call does not
    return until that server stops.

    Raises:
        RuntimeError: if the platform is not macOS, Windows or Linux.
    """
    # sys.platform is 'linux' on Python 3 (it was 'linux2' on Python 2), so
    # match the prefix instead of enumerating spellings.
    platform = _sys.platform
    if platform not in ('darwin', 'win32') and not platform.startswith('linux'):
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    # The pre-rendered page is the app's only route.
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host="127.0.0.1", port=5000)

    print('\nVisualization closed.')
272
273def _in_notebook() -> bool:
274    """
275    Internal only. Helper to default to interactive notebooks when available
276    if visualization type is not specified.
277    """
278    try:
279        from IPython import get_ipython
280        if 'IPKernelApp' not in get_ipython().config:
281            return False
282    except ImportError:
283        return False
284    except AttributeError:
285        return False
286    return True
287
288# def visualize_gui(graph: _nx.DiGraph):
289#     """
290#     Visualizes Network graphs using a desktop GUI generated by the PyQt6 library.
291#     """
292#     if sys.platform not in ['darwin', 'win32', 'linux', 'linxu2']:
293#         err_msg = "Unknown/unsupported operating system for GUI visualizations."
294#         raise RuntimeError(err_msg)
295
296#     gui_html = generate_vis_html(graph)
297
298#     class WebEngineView(_QWebEngineView):
299#         """
300#         A small Qt-based WebEngineView to generate a GUI using embedded HTML and JavaScript.
301#         """
302#         def __init__(self, parent=None):
303#             super().__init__(parent)
304#             self.webpage = _QWebEnginePage()
305#             self.setPage(self.webpage)
306#             self.webpage.setHtml(gui_html)
307
308#     app = _QApplication([])
309#     web_engine_view = WebEngineView()
310#     web_engine_view.showNormal()
311#     app.exec()
def generate_pyvis_network(graph: networkx.classes.digraph.DiGraph) -> pyvis.network.Network:
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates basic network for visualization in Vis.js library.
    Depends on PyVis, Vis.js modules/libraries -- both are under legacy/minimal community support.

    Each node becomes a box labelled with its 'name' attribute, titled with
    ``annotate_node``'s text, colored with ``node_color`` and placed on the
    hierarchical level equal to the index of its topological generation.
    Each edge is colored from its 'Condition' attribute via ``edge_color``.

    Raises:
        RuntimeError: if the graph has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # One pass mapping node -> generation index; scanning every generation for
    # every node would be O(n^2), and no ordering of node ids is required.
    level_of = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for node_id, node in graph.nodes(data=True):
        network.add_node(
            node_id,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=level_of[node_id],
        )

    for source, dest, condition in graph.edges.data('Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network

Generates a basic network for visualization with the Vis.js library. Depends on the PyVis and Vis.js libraries, both of which receive only legacy/minimal community support.

def visualize_plt( graph: networkx.classes.digraph.DiGraph, interactive=True) -> netgraph._main.InteractiveGraph:
def visualize_plt(graph: _nx.DiGraph, interactive: bool = True) -> _IGraph:
    """
    Generates basic network for visualization using the NetGraph library.

    Nodes are placed in topological layers (``get_layers``), labelled with
    their callable's ``__name__`` (falling back to the node's 'name' attribute
    for callables without one, e.g. partials or callable instances), colored
    with ``node_color`` and annotated with ``annotate_node``. Edges are colored
    from their 'Condition' attribute via ``edge_color``.

    When running inside an IPython kernel and ``interactive`` is true,
    matplotlib interactive mode is switched on and the current figure is
    closed before the graph is drawn.
    """
    if _in_notebook() and interactive:
        _plt.ion()
        _plt.close()

    labels, colors, annotations = {}, {}, {}
    for node_id, attrs in graph.nodes(data=True):
        func = attrs.get('callable')
        labels[node_id] = getattr(func, '__name__', None) or attrs.get('name', str(node_id))
        colors[node_id] = node_color(func)
        annotations[node_id] = annotate_node(attrs)

    edge_colors_by_edge = {
        (source, dest): edge_color(condition)
        for source, dest, condition in graph.edges.data('Condition')
    }

    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=True),
        node_labels=labels,
        node_color=colors,
        edge_color=edge_colors_by_edge,
        annotations=annotations,
        annotation_fontdict=dict(horizontalalignment='left')
    )

Generates basic network for visualization using the NetGraph library.

def visualize_browser(graph: networkx.classes.digraph.DiGraph):
def visualize_browser(graph: _nx.DiGraph):
    """
    Visualizes Network graphs using a Flask app served to a desktop browser.

    Renders ``graph`` with ``generate_vis_html`` and serves the page at
    http://127.0.0.1:5000/ using Flask's development server; the call does not
    return until that server stops.

    Raises:
        RuntimeError: if the platform is not macOS, Windows or Linux.
    """
    # sys.platform is 'linux' on Python 3 (it was 'linux2' on Python 2), so
    # match the prefix instead of enumerating spellings.
    platform = _sys.platform
    if platform not in ('darwin', 'win32') and not platform.startswith('linux'):
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    # The pre-rendered page is the app's only route.
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host="127.0.0.1", port=5000)

    print('\nVisualization closed.')

Visualizes Network graphs using a Flask app served to a desktop browser.