webber.viz

Visualization library for Webber DAGs.

  1"""
  2Visualization library for Webber DAGs.
  3"""
  4import sys as _sys
  5import json as _json
  6import types as _types
  7import os.path as _path
  8import typing as _typing
  9import flask as _flask
 10import networkx as _nx
 11import webber.xcoms as _xcoms
 12from webber.edges import Condition
 13from jinja2 import Environment as _Environment, FileSystemLoader as _FileSystemLoader
 14
 15# Wait for preload thread to finish to avoid circular imports with matplotlib/IPython
 16# Uses non-blocking check first; only blocks if preload is still running
 17import webber as _webber
 18if not _webber._viz_ready.is_set():
 19    _webber.wait_for_viz_ready(timeout=3.0)
 20
 21import matplotlib.pyplot as _plt
 22from pyvis.network import Network as _Network
 23from netgraph import InteractiveGraph as _IGraph
 24from jinja2 import Environment as _Environment, FileSystemLoader as _FileSystemLoader
 25
 26__all__ = ["generate_pyvis_network", "visualize_plt", "visualize_browser"]
 27
 28edge_colors: _typing.Dict[Condition, str] = {
 29    Condition.Success: 'grey',
 30    Condition.AnyCase: 'blue',
 31    Condition.Failure: 'red'
 32}
 33
 34def edge_color(c: Condition):
 35    """
 36    Given a Webber Condition, return corresponding color for edge visualizations.
 37    """
 38    return edge_colors[c]
 39
 40def node_color(c: _typing.Callable):
 41    """
 42    Given a callable, return a color that to be used in visualizations
 43    mapping to the callable's type (lambda, function, built-in, class).
 44    """
 45    _class = str(c.__class__).strip("<class '").rstrip("'>")
 46    match _class:
 47        case 'type':
 48            return '#71C6B1'
 49        case 'function':
 50            return '#679AD1' if isinstance(c, _types.LambdaType) else '#DCDCAF'
 51        case 'builtin_function_or_method':
 52            return '#DCDCAF'
 53    return '#AADAFB'
 54
 55def get_layers(graph: _nx.DiGraph) -> _typing.List[_typing.List[str]]:
 56    """
 57    Generates ordered list of node identifiers given a directed network graph.
 58    """
 59    layers = []
 60    for nodes in _nx.topological_generations(graph):
 61        layers.append(nodes)
 62    return layers
 63
 64def annotate_node(node: _typing.Dict[str, _typing.Any]) -> str:
 65    """
 66    Given a Webber node, construct an annotation to be used in graph visualizations.
 67    """
 68    args, kwargs = [], {}
 69    for a in node['args']:
 70        try:
 71            args.append(_json.dumps(a))
 72        except:
 73            if isinstance(a, _xcoms.Promise):
 74                if isinstance(a.key, str):
 75                    name = a.key.split('__')[0]
 76                else:
 77                    name = a.key.__name__
 78                args.append(f'Promise({name})')
 79            else:
 80                args.append(f'Object({str(a.__class__)})')
 81    for k,v in node['kwargs'].items():
 82        try:
 83            _json.dumps(k)
 84            try:
 85                kwargs[_json.dumps(k)] = _json.dumps(v)
 86            except:
 87                if isinstance(v, _xcoms.Promise):
 88                    if isinstance(v.key, str):
 89                        name = v.key.split('__')[0]
 90                    else:
 91                        name = v.key.__name__
 92                    kwargs[k] = f"Promise('{name}')"
 93                else:
 94                    kwargs[k] = f'Object({str(v.__class__)})'
 95        except:
 96            pass
 97    node_title  = f"{node['name']}:"
 98
 99    try:
100        node_title += f" {node['callable'].__doc__.split('\n')[0]}"
101    except:
102        pass
103
104    node_title += f"\nuuid:    {node['id'].split('__')[-1]}"
105    node_title += f"\nposargs: {', '.join(args)}" if args else ""
106    node_title += f"\nkwargs:  {_json.dumps(kwargs)}" if kwargs else ""
107
108    return node_title
109
def visualize_plt(
    graph: _nx.DiGraph,
    interactive: bool = True,
    optimize_layout: bool = True
) -> _IGraph:
    """
    Generates basic network for visualization using the NetGraph library.

    Args:
        graph: NetworkX DiGraph to visualize
        interactive: If True, enables interactive mode in notebooks.
                     Ignored when using non-interactive backend (Agg).
        optimize_layout: If True, reduces edge crossings (slower but prettier).
                         Set to False for faster rendering on large graphs.

    Returns:
        The netgraph InteractiveGraph drawn for ``graph``, laid out in
        topological layers.

    Raises:
        RuntimeError: If ``graph`` has no nodes (same check as
                      ``generate_pyvis_network``).
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    # Check if we're using a non-interactive backend
    import matplotlib
    backend = matplotlib.get_backend().lower()
    is_interactive_backend = backend not in ('agg', 'pdf', 'svg', 'ps', 'cairo')

    if _in_notebook() and interactive and is_interactive_backend:
        _plt.ion()
        _plt.close()

    def _label(data: _typing.Dict[str, _typing.Any]) -> str:
        # Callables such as functools.partial objects have no __name__;
        # fall back to the node's own 'name' attribute for those.
        func = data.get('callable')
        name = getattr(func, '__name__', None)
        return name if name else str(data.get('name', func))

    node_data = graph.nodes(data=True)
    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=optimize_layout),
        node_labels={n: _label(d) for n, d in node_data},
        node_color={n: node_color(d.get('callable')) for n, d in node_data},
        edge_color={(u, v): edge_color(cond) for u, v, cond in graph.edges(data='Condition')},
        annotations={n: annotate_node(d) for n, d in node_data},
        annotation_fontdict=dict(horizontalalignment='left')
    )
143
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates basic network for visualization in Vis.js library.
    Depends on PyVis, Vis.js modules/libraries -- both are under legacy/minimal community support.

    Args:
        graph: NetworkX DiGraph whose nodes carry 'name' and 'callable'
               attributes and whose edges carry a 'Condition' attribute.

    Returns:
        A directed, hierarchical PyVis Network. Each graph node becomes a box
        node whose level is its topological generation index. Each graph edge
        becomes an edge colored by its Condition.

    Raises:
        RuntimeError: If ``graph`` has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # Map every node to its generation index in a single pass. The previous
    # per-node scan over all generation lists was quadratic in node count.
    levels = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for n, node in graph.nodes(data=True):
        network.add_node(
            n,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=levels[n],
        )

    for source, dest, condition in graph.edges(data='Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network
179
180
def generate_vis_js_script(graph: _nx.DiGraph) -> str:
    """
    Generates script for modeling Vis.js network graphs from a NetworkX DiGraph.
    Conformant to: Vis.js 4.20.1-SNAPSHOT

    The node and edge data are embedded as JSON literals in a script that is
    later placed inside an HTML <script> element. '<', '>' and '&' are
    therefore written as JSON unicode escapes. Otherwise a node annotation
    containing '</script>' (e.g. from a string argument or a docstring) would
    end the script block early.

    Returns:
        JavaScript source that creates the ``nodes``/``edges`` DataSets, the
        ``options`` object, and a ``vis.Network`` in the element with id
        "mynetwork".
    """
    def _script_json(obj: _typing.Any) -> str:
        # Inside JSON strings, \u003c etc. decode back to the same characters,
        # so the data is unchanged while the HTML parser never sees '</'.
        return (
            _json.dumps(obj)
            .replace('<', '\\u003c')
            .replace('>', '\\u003e')
            .replace('&', '\\u0026')
        )

    network: _Network = generate_pyvis_network(graph)
    network_data = dict(
        zip(["nodes", "edges", "heading", "height", "width", "options"],
        network.get_network_data())
    )

    # NOTE(review): arrows.from is true, so edges are drawn with arrowheads at
    # both ends, which hides the direction of DAG edges. Confirm this is intended.
    options_js = """var options = {
                    "autoResize": true,
                    "configure": {
                        "enabled": false
                    },
                    "edges": {
                        "color": {
                            "inherit": false
                        },
                        "smooth": {
                            "enabled": false,
                        },
                        "arrows": {
                            "to": true,
                            "from": true
                        }
                    },
                    "interaction": {
                        "dragNodes": true,
                        "hideEdgesOnDrag": false,
                        "hideNodesOnDrag": false
                    },
                    "layout": {
                        "hierarchical": {
                            "direction": "UD",
                            "blockShifting": true,
                            "edgeMinimization": false,
                            "enabled": true,
                            "parentCentralization": true,
                            "sortMethod": "hubsize",
                        },
                        "improvedLayout": true,
                        "randomSeed": 0,
                    },
                    "physics": {
                        "enabled": true,
                        "stabilization": {
                            "enabled": true,
                            "fit": true,
                            "iterations": 1000,
                            "onlyDynamicEdges": false,
                            "updateInterval": 50
                        }
                    }
                };\n"""

    return "".join([
        "var nodes = new vis.DataSet(" + _script_json(network_data['nodes']) + ");\n",
        "var edges = new vis.DataSet(" + _script_json(network_data['edges']) + ");\n",
        'var container = document.getElementById("mynetwork");\n',
        "var data = { nodes: nodes, edges: edges, };\n",
        options_js,
        "var network = new vis.Network(container, data, options);\n",
    ])
244
245
def generate_vis_html(graph: _nx.DiGraph) -> str:
    """
    Build the complete HTML page for the Vis.js view of ``graph``.

    The generated Vis.js script is wrapped in a <script> tag and rendered as
    ``network_script`` into the ``vis_gui.html`` Jinja template. The template
    is loaded from the ``templates`` directory next to this module. The page
    is shared by the browser and GUI front ends.

    Raises:
        RuntimeError: If the generated script is empty.
    """
    js_source = generate_vis_js_script(graph)
    if not js_source:
        raise RuntimeError("Empty JavaScript string given for Vis.js visualization.")

    script_tag = "".join(('<script type="text/javascript">\n', js_source, '</script>\n'))

    module_dir = _path.dirname(_path.abspath(__file__))
    loader = _FileSystemLoader(_path.join(module_dir, 'templates'))
    template = _Environment(loader=loader).get_template("vis_gui.html")

    return template.render(network_script=script_tag)
265
266
def visualize_browser(graph: _nx.DiGraph, *, host: str = "127.0.0.1", port: int = 5000):
    """
    Visualizes Network graphs using a Flask app served to a desktop browser.

    Serves the Vis.js page at ``http://<host>:<port>/`` and blocks while the
    Flask server runs.

    Args:
        graph: NetworkX DiGraph to visualize.
        host: Interface the server binds to (default: loopback only).
        port: TCP port to serve on (default 5000). Pass another port if 5000
              is already in use on this machine (e.g. by the macOS AirPlay
              Receiver).

    Raises:
        RuntimeError: If the operating system is not darwin, win32 or linux.
    """
    if _sys.platform not in ['darwin', 'win32', 'linux', 'linux2']:
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    # The page is fully pre-rendered, so the view just returns it.
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host=host, port=port)

    print('\nVisualization closed.')
285
286def _in_notebook() -> bool:
287    """
288    Internal only. Helper to default to interactive notebooks when available
289    if visualization type is not specified.
290    """
291    try:
292        from IPython.core.getipython import get_ipython
293        if 'IPKernelApp' not in get_ipython().config:
294            return False
295    except ImportError:
296        return False
297    except AttributeError:
298        return False
299    return True
def generate_pyvis_network(graph: networkx.classes.digraph.DiGraph) -> pyvis.network.Network:
def generate_pyvis_network(graph: _nx.DiGraph) -> _Network:
    """
    Generates basic network for visualization in Vis.js library.
    Depends on PyVis, Vis.js modules/libraries -- both are under legacy/minimal community support.

    Args:
        graph: NetworkX DiGraph whose nodes carry 'name' and 'callable'
               attributes and whose edges carry a 'Condition' attribute.

    Returns:
        A directed, hierarchical PyVis Network. Each graph node becomes a box
        node whose level is its topological generation index. Each graph edge
        becomes an edge colored by its Condition.

    Raises:
        RuntimeError: If ``graph`` has no nodes.
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    network = _Network(
        directed=True,
        layout='hierarchical'
    )
    network.inherit_edge_colors(False)

    # Map every node to its generation index in a single pass. The previous
    # per-node scan over all generation lists was quadratic in node count.
    levels = {
        node_id: level
        for level, generation in enumerate(_nx.topological_generations(graph))
        for node_id in generation
    }

    for n, node in graph.nodes(data=True):
        network.add_node(
            n,
            label=node['name'],
            shape='box',
            title=annotate_node(node),
            labelHighlightBold=True,
            color=node_color(node['callable']),
            level=levels[n],
        )

    for source, dest, condition in graph.edges(data='Condition'):
        network.add_edge(source, dest, color=edge_color(condition))

    return network

Generates a basic network for visualization with the Vis.js library. Depends on the PyVis and Vis.js libraries, both of which have only legacy/minimal community support.

def visualize_plt( graph: networkx.classes.digraph.DiGraph, interactive: bool = True, optimize_layout: bool = True) -> netgraph._main.InteractiveGraph:
def visualize_plt(
    graph: _nx.DiGraph,
    interactive: bool = True,
    optimize_layout: bool = True
) -> _IGraph:
    """
    Generates basic network for visualization using the NetGraph library.

    Args:
        graph: NetworkX DiGraph to visualize
        interactive: If True, enables interactive mode in notebooks.
                     Ignored when using non-interactive backend (Agg).
        optimize_layout: If True, reduces edge crossings (slower but prettier).
                         Set to False for faster rendering on large graphs.

    Returns:
        The netgraph InteractiveGraph drawn for ``graph``, laid out in
        topological layers.

    Raises:
        RuntimeError: If ``graph`` has no nodes (same check as
                      ``generate_pyvis_network``).
    """
    if graph.number_of_nodes() == 0:
        err_msg = "Visualizations cannot be generated for DAGs without nodes."
        raise RuntimeError(err_msg)

    # Check if we're using a non-interactive backend
    import matplotlib
    backend = matplotlib.get_backend().lower()
    is_interactive_backend = backend not in ('agg', 'pdf', 'svg', 'ps', 'cairo')

    if _in_notebook() and interactive and is_interactive_backend:
        _plt.ion()
        _plt.close()

    def _label(data: _typing.Dict[str, _typing.Any]) -> str:
        # Callables such as functools.partial objects have no __name__;
        # fall back to the node's own 'name' attribute for those.
        func = data.get('callable')
        name = getattr(func, '__name__', None)
        return name if name else str(data.get('name', func))

    node_data = graph.nodes(data=True)
    return _IGraph(
        graph, arrows=True, node_shape='o', node_size=5,
        node_layout='multipartite',
        node_layout_kwargs=dict(layers=get_layers(graph), reduce_edge_crossings=optimize_layout),
        node_labels={n: _label(d) for n, d in node_data},
        node_color={n: node_color(d.get('callable')) for n, d in node_data},
        edge_color={(u, v): edge_color(cond) for u, v, cond in graph.edges(data='Condition')},
        annotations={n: annotate_node(d) for n, d in node_data},
        annotation_fontdict=dict(horizontalalignment='left')
    )

Generates basic network for visualization using the NetGraph library.

Args:
- graph: NetworkX DiGraph to visualize.
- interactive: If True, enables interactive mode in notebooks. Ignored when using a non-interactive backend (e.g. Agg).
- optimize_layout: If True, reduces edge crossings (slower but prettier). Set to False for faster rendering on large graphs.

def visualize_browser(graph: networkx.classes.digraph.DiGraph):
def visualize_browser(graph: _nx.DiGraph, *, host: str = "127.0.0.1", port: int = 5000):
    """
    Visualizes Network graphs using a Flask app served to a desktop browser.

    Serves the Vis.js page at ``http://<host>:<port>/`` and blocks while the
    Flask server runs.

    Args:
        graph: NetworkX DiGraph to visualize.
        host: Interface the server binds to (default: loopback only).
        port: TCP port to serve on (default 5000). Pass another port if 5000
              is already in use on this machine (e.g. by the macOS AirPlay
              Receiver).

    Raises:
        RuntimeError: If the operating system is not darwin, win32 or linux.
    """
    if _sys.platform not in ['darwin', 'win32', 'linux', 'linux2']:
        err_msg = "Unknown/unsupported operating system for GUI visualizations."
        raise RuntimeError(err_msg)

    gui_html = generate_vis_html(graph)

    server = _flask.Flask(__name__)
    # The page is fully pre-rendered, so the view just returns it.
    server.add_url_rule("/", "index", lambda: gui_html)

    print('Serving visualization...\n')

    server.run(host=host, port=port)

    print('\nVisualization closed.')

Visualizes Network graphs using a Flask app served to a desktop browser.