# Commit 7a5e46b (author: Remsky):
# Multiplot support, bokeh and plotly, multiple graph layout support.
import plotly.graph_objects as go
import networkx as nx
import numpy as np
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx
def create_graph(entities, relationships):
    """Build an undirected NetworkX graph from entities and relationships.

    Args:
        entities: Mapping of entity id -> dict. The ``'value'`` and ``'type'``
            keys are read (each defaulting to ``'Unknown'``) to build a
            ``"value (type)"`` display string stored as the node's ``label``
            attribute.
        relationships: Iterable of ``(source, relation, target)`` triples.
            Each one becomes an edge whose ``label`` attribute is ``relation``.

    Returns:
        nx.Graph: The assembled graph. Every node carries a ``label``
        attribute, including endpoints that appear only in ``relationships``
        (labelled ``"<id> (Unknown)"``).

    Note:
        ``nx.Graph`` keeps at most one edge per node pair, so if several
        relationships connect the same two entities only the last relation
        label is kept.
    """
    G = nx.Graph()
    for entity_id, entity_data in entities.items():
        value = entity_data.get('value', 'Unknown')
        entity_type = entity_data.get('type', 'Unknown')
        G.add_node(entity_id, label=f"{value} ({entity_type})")
    for source, relation, target in relationships:
        # add_edge would silently create undeclared endpoints WITHOUT a
        # 'label' attribute, which the plotting code looks up unconditionally.
        for endpoint in (source, target):
            if endpoint not in G:
                G.add_node(endpoint, label=f"{endpoint} (Unknown)")
        G.add_edge(source, target, label=relation)
    return G
def improved_spectral_layout(G, scale=1, noise=0.1, seed=None):
    """Spectral layout with Gaussian jitter so coincident nodes separate.

    ``nx.spectral_layout`` can place several nodes at (nearly) the same
    coordinates. Independent normal noise is added to each coordinate, and
    the jittered layout is then multiplied by ``scale``.

    Args:
        G: NetworkX graph to lay out.
        scale: Factor applied to every coordinate after the jitter is added.
        noise: Standard deviation of the Gaussian jitter (default 0.1).
        seed: Optional seed for a dedicated ``numpy.random.Generator`` so the
            jitter is reproducible. With ``None`` (the default) the global
            ``numpy.random`` state is used.

    Returns:
        dict: Mapping of node -> ``(x, y)`` tuple.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    pos = nx.spectral_layout(G)
    # Draw order (x-noise, then y-noise, node by node) matches the previous
    # two-pass implementation, so seeded global-state results are unchanged.
    return {
        node: ((x + rng.normal(0, noise)) * scale,
               (y + rng.normal(0, noise)) * scale)
        for node, (x, y) in pos.items()
    }
def create_bokeh_plot(G, layout_type='spring'):
    """Render ``G`` as an interactive Bokeh plot with node and edge labels.

    Args:
        G: NetworkX graph. Nodes and edges are expected to carry a ``label``
            attribute (as produced by ``create_graph``). Nodes without one are
            labelled with ``str(node)``, and edges without one get no text.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``.
            Unrecognised values fall back to ``'spring'``.

    Returns:
        bokeh.models.Plot: A 600x600 plot with hover, tap and box-select tools.
    """
    plot = Plot(width=600, height=600,
                x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
    plot.title.text = "Knowledge Graph Interaction"

    # NOTE(review): neither hover tool restricts its `renderers`, so both may
    # trigger on the same glyph (e.g. a node also showing its label as
    # "Relation"). Confirm which tooltips are intended for nodes vs. edges.
    node_hover = HoverTool(tooltips=[("Entity", "@label")])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")])
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    # Layout algorithms by name. Unknown names fall back to 'spring'.
    layouts = {
        'spring': lambda g: nx.spring_layout(g, k=0.5, iterations=50),
        'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, k=0.5, iterations=50),
        'circular': nx.circular_layout,
        'random': nx.random_layout,
        'spectral': improved_spectral_layout,
        'shell': nx.shell_layout,
    }
    pos = layouts.get(layout_type, layouts['spring'])(G)

    # `pos` is a ready-made position dict, which from_networkx uses as-is.
    # Layout keyword arguments (scale/center) only apply when a layout
    # *function* is passed, so none are given here.
    graph_renderer = from_networkx(G, pos)

    node_renderer = graph_renderer.node_renderer
    node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    edge_renderer = graph_renderer.edge_renderer
    edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()
    plot.renderers.append(graph_renderer)

    # Node labels: coordinates are read from `pos` in G.nodes() order so each
    # label matches its node, and an empty graph yields empty columns instead
    # of a failed zip() unpack.
    nodes = list(G.nodes())
    node_label_source = ColumnDataSource({
        'x': [float(pos[node][0]) for node in nodes],
        'y': [float(pos[node][1]) for node in nodes],
        'label': [str(G.nodes[node].get('label', node)) for node in nodes],
    })
    plot.add_layout(LabelSet(x='x', y='y', text='label', source=node_label_source,
                             background_fill_color='white', text_font_size='8pt',
                             background_fill_alpha=0.7))

    # Edge labels, positioned at the midpoint of each edge.
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, label in G.edges(data='label'):
        (x0, y0), (x1, y1) = pos[start_node], pos[end_node]
        edge_x.append((float(x0) + float(x1)) / 2)
        edge_y.append((float(y0) + float(y1)) / 2)
        edge_text.append('' if label is None else str(label))
    edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_text})
    plot.add_layout(LabelSet(x='x', y='y', text='label', source=edge_label_source,
                             background_fill_color='white', text_font_size='8pt',
                             background_fill_alpha=0.7))
    return plot
def create_plotly_plot(G, layout_type='spring'):
    """Render ``G`` as a Plotly figure with node and edge labels.

    Nodes are drawn as markers coloured by their number of neighbours, with a
    colour bar. Each node's ``label`` is shown above it, and each edge's
    ``label`` is shown at the edge midpoint.

    Args:
        G: NetworkX graph. Nodes and edges are expected to carry a ``label``
            attribute (as produced by ``create_graph``). Nodes without one are
            labelled with ``str(node)``, and edges without one get no text.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``.
            Unrecognised values fall back to ``'spring'``.

    Returns:
        plotly.graph_objects.Figure: An 800x600 figure with equal x/y scaling.
    """
    # Layout algorithms by name. Unknown names fall back to 'spring'.
    layouts = {
        'spring': lambda g: nx.spring_layout(g, k=0.5, iterations=50),
        'fruchterman_reingold': lambda g: nx.fruchterman_reingold_layout(g, k=0.5, iterations=50),
        'circular': nx.circular_layout,
        'random': nx.random_layout,
        'spectral': improved_spectral_layout,
        'shell': nx.shell_layout,
    }
    pos = layouts.get(layout_type, layouts['spring'])(G)

    # Collect coordinates in plain lists and build each trace once. Growing
    # trace properties with `+=` re-copies and re-validates the whole array
    # on every iteration.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for u, v, label in G.edges(data="label"):
        x0, y0 = (float(c) for c in pos[u])
        x1, y1 = (float(c) for c in pos[v])
        # The trailing None breaks the polyline so edges are not joined.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append("" if label is None else str(label))

    nodes = list(G.nodes())
    node_x = [float(pos[node][0]) for node in nodes]
    node_y = [float(pos[node][1]) for node in nodes]
    node_text = [str(G.nodes[node].get("label", node)) for node in nodes]
    # Number of distinct neighbours (same as len(list(G.neighbors(node)))).
    node_color = [len(G[node]) for node in nodes]

    # Edge lines carry no hover text, so hover is disabled on them.
    edge_trace = go.Scatter(x=edge_x, y=edge_y, mode="lines",
                            line=dict(width=1, color="#888"), hoverinfo="none")
    node_trace = go.Scatter(
        x=node_x, y=node_y, mode="markers+text", hoverinfo="text",
        text=node_text, textposition="top center",
        marker=dict(
            showscale=True, colorscale="Viridis", reversescale=True,
            color=node_color, size=15, line=dict(width=2),
            # title.side replaces the deprecated `titleside`, which
            # plotly.py 6 no longer accepts.
            colorbar=dict(thickness=15, xanchor="left",
                          title=dict(text="Node Connections", side="right")),
        ),
    )
    # One text-only trace for all edge labels, instead of one trace per edge.
    edge_label_trace = go.Scatter(
        x=label_x, y=label_y, mode="text", text=label_text,
        textposition="middle center", textfont=dict(size=8),
        hoverinfo="none", showlegend=False,
    )

    return go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # title.font replaces the deprecated `titlefont_size`.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False, hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            # One anchor is enough for a 1:1 aspect ratio. A mutual x<->y
            # anchor loop is redundant and plotly.js ignores one side anyway.
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False,
                       scaleanchor="x", scaleratio=1),
            width=800, height=600,
            newshape=dict(line=dict(color="#009900")),
        ),
    )