File size: 6,742 Bytes
4289090
 
7a5e46b
b26b502
 
 
 
 
7a5e46b
b26b502
 
7a5e46b
 
b26b502
 
7a5e46b
 
 
 
 
 
 
 
 
 
 
 
 
b26b502
 
 
 
 
 
 
7a5e46b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b26b502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a5e46b
b26b502
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a5e46b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4289090
 
 
 
 
 
 
 
 
 
7a5e46b
 
4289090
 
 
 
 
7a5e46b
4289090
 
7a5e46b
 
 
 
 
 
 
 
 
 
4289090
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import plotly.graph_objects as go
import networkx as nx
import numpy as np
from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges, 
                          Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
from bokeh.palettes import Spectral4
from bokeh.plotting import from_networkx

def create_graph(entities, relationships):
    """Build an undirected networkx graph from entities and relationships.

    Args:
        entities: Mapping of entity id to a dict; its ``'value'`` and
            ``'type'`` entries (each defaulting to ``'Unknown'``) are combined
            into the node's ``label`` attribute as ``"value (type)"``.
        relationships: Iterable of ``(source, relation, target)`` triples; each
            becomes an edge ``source -- target`` whose ``label`` attribute is
            ``relation``.

    Returns:
        The populated ``nx.Graph``.
    """
    graph = nx.Graph()

    # Bulk-insert nodes as (node, attribute-dict) pairs.
    graph.add_nodes_from(
        (
            node_id,
            {'label': f"{info.get('value', 'Unknown')} ({info.get('type', 'Unknown')})"},
        )
        for node_id, info in entities.items()
    )

    # Bulk-insert edges as (u, v, attribute-dict) triples; note the triple
    # order in `relationships` is (source, relation, target).
    graph.add_edges_from(
        (src, dst, {'label': rel}) for src, rel, dst in relationships
    )

    return graph

def improved_spectral_layout(G, scale=1, noise=0.1, seed=None):
    """Spectral layout with Gaussian jitter so coincident nodes separate.

    Args:
        G: Graph handed to ``nx.spectral_layout``.
        scale: Factor applied to every coordinate after the jitter is added.
        noise: Standard deviation of the zero-mean Gaussian jitter added to
            each coordinate before scaling (``0`` disables the jitter).
        seed: When ``None`` the jitter is drawn from NumPy's global random
            state (so ``np.random.seed`` still governs it); otherwise a
            dedicated generator seeded with this value is used, which makes
            the jitter reproducible.

    Returns:
        Dict mapping each node to an ``(x, y)`` tuple.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    base_pos = nx.spectral_layout(G)
    # One pass: jitter first, then scale. The x offset is drawn before the
    # y offset for each node, in layout iteration order.
    return {
        node: ((x + rng.normal(0, noise)) * scale,
               (y + rng.normal(0, noise)) * scale)
        for node, (x, y) in base_pos.items()
    }

def create_bokeh_plot(G, layout_type='spring'):
    """Render a graph as an interactive Bokeh plot with node and edge labels.

    Args:
        G: networkx graph. Node and edge ``label`` attributes are used as
            label text; a node without one is labelled with ``str(node)`` and
            an edge without one gets an empty label.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``.
            Any other value falls back to the spring layout.

    Returns:
        A 600x600 ``Plot`` holding the graph renderer plus two ``LabelSet``
        renderers (node labels and edge-midpoint labels).
    """
    plot = Plot(width=600, height=600,
                x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
    plot.title.text = "Knowledge Graph Interaction"

    # NOTE(review): neither hover tool is restricted to specific renderers,
    # and the inspection policy set below is node-based, so the "Relation"
    # tooltip may end up showing node data rather than edge data -- confirm
    # in a browser before relying on it.
    node_hover = HoverTool(tooltips=[("Entity", "@label")])
    edge_hover = HoverTool(tooltips=[("Relation", "@label")])
    plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())

    # Dispatch table of lazily evaluated layouts; unknown names use 'spring'.
    layout_builders = {
        'spring': lambda: nx.spring_layout(G, k=0.5, iterations=50),
        'fruchterman_reingold': lambda: nx.fruchterman_reingold_layout(G, k=0.5, iterations=50),
        'circular': lambda: nx.circular_layout(G),
        'random': lambda: nx.random_layout(G),
        'spectral': lambda: improved_spectral_layout(G),
        'shell': lambda: nx.shell_layout(G),
    }
    pos = layout_builders.get(layout_type, layout_builders['spring'])()

    # A precomputed position dict is supplied, so no layout keyword arguments
    # are passed: they are only forwarded when a layout *function* is given.
    graph_renderer = from_networkx(G, pos)

    graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
    graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
    graph_renderer.node_renderer.hover_glyph = Scatter(size=15, fill_color=Spectral4[1])

    graph_renderer.edge_renderer.glyph = MultiLine(line_color="#000", line_alpha=0.9, line_width=3)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=4)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=3)

    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = NodesAndLinkedEdges()

    plot.renderers.append(graph_renderer)

    # Add node labels. Coordinates are looked up per node so that each label
    # is paired with its own position regardless of dict ordering, and so an
    # empty graph simply yields empty columns instead of an unpacking error.
    node_x, node_y, node_text = [], [], []
    for node, attrs in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Nodes created implicitly by add_edge carry no 'label' attribute.
        node_text.append(attrs.get('label', str(node)))

    source = ColumnDataSource({'x': node_x, 'y': node_y, 'label': node_text})
    labels = LabelSet(x='x', y='y', text='label', source=source, background_fill_color='white',
                      text_font_size='8pt', background_fill_alpha=0.7)
    plot.renderers.append(labels)

    # Add edge labels at the midpoint of each edge.
    edge_x, edge_y, edge_text = [], [], []
    for start_node, end_node, label in G.edges(data='label', default=''):
        start_x, start_y = pos[start_node]
        end_x, end_y = pos[end_node]
        edge_x.append((start_x + end_x) / 2)
        edge_y.append((start_y + end_y) / 2)
        edge_text.append(label)

    edge_label_source = ColumnDataSource({'x': edge_x, 'y': edge_y, 'label': edge_text})
    edge_label_set = LabelSet(x='x', y='y', text='label', source=edge_label_source,
                              background_fill_color='white', text_font_size='8pt',
                              background_fill_alpha=0.7)
    plot.renderers.append(edge_label_set)

    return plot

def create_plotly_plot(G, layout_type='spring'):
    """Render a graph as a Plotly figure with node and edge labels.

    Args:
        G: networkx graph. Node and edge ``label`` attributes are used as
            text; a node without one is labelled with ``str(node)`` and an
            edge without one gets an empty label.
        layout_type: One of ``'spring'``, ``'fruchterman_reingold'``,
            ``'circular'``, ``'random'``, ``'spectral'`` or ``'shell'``.
            Any other value falls back to the spring layout.

    Returns:
        An 800x600 ``go.Figure`` with three traces: edge lines, node markers
        (coloured by neighbour count) and edge-midpoint text labels.
    """
    # Dispatch table of lazily evaluated layouts; unknown names use 'spring'.
    layout_builders = {
        'spring': lambda: nx.spring_layout(G, k=0.5, iterations=50),
        'fruchterman_reingold': lambda: nx.fruchterman_reingold_layout(G, k=0.5, iterations=50),
        'circular': lambda: nx.circular_layout(G),
        'random': lambda: nx.random_layout(G),
        'spectral': lambda: improved_spectral_layout(G),
        'shell': lambda: nx.shell_layout(G),
    }
    pos = layout_builders.get(layout_type, layout_builders['spring'])()

    # Collect coordinates in plain lists and build each trace once, instead of
    # repeatedly re-assigning (and re-validating) growing tuples on the traces.
    edge_x, edge_y = [], []
    label_x, label_y, label_text = [], [], []
    for start, end, relation in G.edges(data="label", default=""):
        x0, y0 = pos[start]
        x1, y1 = pos[end]
        # None breaks the line so consecutive edges are not joined together.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]
        label_x.append((x0 + x1) / 2)
        label_y.append((y0 + y1) / 2)
        label_text.append(relation)

    node_x, node_y, node_text, node_degree = [], [], [], []
    for node, attrs in G.nodes(data=True):
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)
        # Nodes created implicitly by add_edge carry no "label" attribute.
        node_text.append(attrs.get("label", str(node)))
        # Number of distinct neighbours drives the marker colour.
        node_degree.append(len(G[node]))

    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=1, color="#888"),
                            hoverinfo="text", mode="lines", text=[])
    node_trace = go.Scatter(
        x=node_x, y=node_y, mode="markers+text", hoverinfo="text",
        marker=dict(showscale=True, colorscale="Viridis", reversescale=True,
                    color=node_degree, size=15,
                    # Nested title form: the flat `titleside` shortcut is
                    # deprecated and rejected by newer plotly releases.
                    colorbar=dict(thickness=15, xanchor="left",
                                  title=dict(text="Node Connections", side="right")),
                    line_width=2),
        text=node_text, textposition="top center")
    # All edge labels share one text trace rather than one trace per edge.
    edge_label_trace = go.Scatter(x=label_x, y=label_y, mode="text", text=label_text,
                                  textposition="middle center", hoverinfo="none",
                                  showlegend=False, textfont=dict(size=8))

    fig = go.Figure(
        data=[edge_trace, node_trace, edge_label_trace],
        layout=go.Layout(
            # Nested title form replaces the deprecated `titlefont_size`.
            title=dict(text="Knowledge Graph", font=dict(size=16)),
            showlegend=False, hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40), annotations=[],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            width=800, height=600))

    # Lock the aspect ratio so layout distances are not distorted.
    fig.update_layout(newshape=dict(line_color="#009900"),
                      xaxis=dict(scaleanchor="y", scaleratio=1),
                      yaxis=dict(scaleanchor="x", scaleratio=1))

    return fig