Remsky commited on
Commit
7a5e46b
β€’
1 Parent(s): c906105

Multiplot support, bokeh and plotly, multiple graph layout support.

Browse files
Files changed (2) hide show
  1. app.py +65 -43
  2. lib/visualize.py +71 -118
app.py CHANGED
@@ -1,52 +1,66 @@
1
-
2
  import random
3
 
4
  import gradio as gr
5
- import spaces
6
 
7
  from lib.graph_extract import triplextract, parse_triples
8
- from lib.visualize import create_bokeh_plot #, create_plotly_plot
9
  from lib.samples import snippets
10
 
11
  WORD_LIMIT = 300
12
 
13
- def process_text(text, entity_types, predicates):
14
  if not text:
15
- return None, "Please enter some text."
16
 
17
  words = text.split()
18
  if len(words) > WORD_LIMIT:
19
- return None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
20
 
21
  entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
22
  predicates = [p.strip() for p in predicates.split(",") if p.strip()]
23
 
24
  if not entity_types:
25
- return None, "Please enter at least one entity type."
26
  if not predicates:
27
- return None, "Please enter at least one predicate."
28
 
29
  try:
30
  prediction = triplextract(text, entity_types, predicates)
31
  if prediction.startswith("Error"):
32
- return None, prediction
33
 
34
  entities, relationships = parse_triples(prediction)
35
 
36
  if not entities and not relationships:
37
- return (
38
- None,
39
- "No entities or relationships found. Try different text or check your input.",
40
- )
41
-
42
- fig = create_bokeh_plot(entities, relationships)
43
- return (
44
- fig,
45
- f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}",
46
- )
 
47
  except Exception as e:
48
- print(f"Error in process_text: {e}")
49
- return None, f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def update_inputs(sample_name):
52
  sample = snippets[sample_name]
@@ -60,34 +74,42 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
60
 
61
  with gr.Row():
62
  with gr.Column(scale=1):
63
- sample_dropdown = gr.Dropdown(
64
- choices=list(snippets.keys()),
65
- label="Select Sample",
66
- value=default_sample_name
67
- )
68
- input_text = gr.Textbox(
69
- label="Input Text",
70
- lines=5,
71
- value=default_sample.text_input
72
- )
73
  entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
74
  predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
75
- submit_btn = gr.Button("Extract Knowledge Graph")
 
 
 
76
  with gr.Column(scale=2):
77
  output_graph = gr.Plot(label="Knowledge Graph")
78
  error_message = gr.Textbox(label="Textual Output")
79
 
80
- sample_dropdown.change(
81
- update_inputs,
82
- inputs=[sample_dropdown],
83
- outputs=[input_text, entity_types, predicates]
84
- )
 
 
 
 
 
85
 
86
- submit_btn.click(
87
- process_text,
88
- inputs=[input_text, entity_types, predicates],
89
- outputs=[output_graph, error_message],
90
- )
 
 
 
 
 
 
 
 
91
 
92
  if __name__ == "__main__":
93
- demo.launch()
 
 
1
  import random
2
 
3
  import gradio as gr
4
+ import networkx as nx
5
 
6
  from lib.graph_extract import triplextract, parse_triples
7
+ from lib.visualize import create_graph, create_bokeh_plot, create_plotly_plot
8
  from lib.samples import snippets
9
 
10
  WORD_LIMIT = 300
11
 
12
+ def process_text(text, entity_types, predicates, layout_type, visualization_type):
13
  if not text:
14
+ return None, None, "Please enter some text."
15
 
16
  words = text.split()
17
  if len(words) > WORD_LIMIT:
18
+ return None, None, f"Please limit your input to {WORD_LIMIT} words. Current word count: {len(words)}"
19
 
20
  entity_types = [et.strip() for et in entity_types.split(",") if et.strip()]
21
  predicates = [p.strip() for p in predicates.split(",") if p.strip()]
22
 
23
  if not entity_types:
24
+ return None, None, "Please enter at least one entity type."
25
  if not predicates:
26
+ return None, None, "Please enter at least one predicate."
27
 
28
  try:
29
  prediction = triplextract(text, entity_types, predicates)
30
  if prediction.startswith("Error"):
31
+ return None, None, prediction
32
 
33
  entities, relationships = parse_triples(prediction)
34
 
35
  if not entities and not relationships:
36
+ return None, None, "No entities or relationships found. Try different text or check your input."
37
+
38
+ G = create_graph(entities, relationships)
39
+
40
+ if visualization_type == 'Bokeh':
41
+ fig = create_bokeh_plot(G, layout_type)
42
+ else:
43
+ fig = create_plotly_plot(G, layout_type)
44
+
45
+ output_text = f"Entities: {entities}\nRelationships: {relationships}\n\nRaw output:\n{prediction}"
46
+ return G, fig, output_text
47
  except Exception as e:
48
+ print(f"Error in process_text: {str(e)}")
49
+ return None, None, f"An error occurred: {str(e)}"
50
+
51
+ def update_graph(G, layout_type, visualization_type):
52
+ if G is None:
53
+ return None, "Please process text first."
54
+
55
+ try:
56
+ if visualization_type == 'Bokeh':
57
+ fig = create_bokeh_plot(G, layout_type)
58
+ else:
59
+ fig = create_plotly_plot(G, layout_type)
60
+ return fig, ""
61
+ except Exception as e:
62
+ print(f"Error in update_graph: {e}")
63
+ return None, f"An error occurred while updating the graph: {str(e)}"
64
 
65
  def update_inputs(sample_name):
66
  sample = snippets[sample_name]
 
74
 
75
  with gr.Row():
76
  with gr.Column(scale=1):
77
+ sample_dropdown = gr.Dropdown(choices=list(snippets.keys()), label="Select Sample", value=default_sample_name)
78
+ input_text = gr.Textbox(label="Input Text", lines=5, value=default_sample.text_input)
 
 
 
 
 
 
 
 
79
  entity_types = gr.Textbox(label="Entity Types", value=default_sample.entity_types)
80
  predicates = gr.Textbox(label="Predicates", value=default_sample.predicates)
81
+ layout_type = gr.Dropdown(choices=['spring', 'fruchterman_reingold', 'circular', 'random', 'spectral', 'shell'],
82
+ label="Layout Type", value='spring')
83
+ visualization_type = gr.Radio(choices=['Bokeh', 'Plotly'], label="Visualization Type", value='Bokeh')
84
+ process_btn = gr.Button("Process Text")
85
  with gr.Column(scale=2):
86
  output_graph = gr.Plot(label="Knowledge Graph")
87
  error_message = gr.Textbox(label="Textual Output")
88
 
89
+ graph_state = gr.State(None)
90
+
91
+ def process_and_update(text, entity_types, predicates, layout_type, visualization_type):
92
+ G, fig, output = process_text(text, entity_types, predicates, layout_type, visualization_type)
93
+ return G, fig, output
94
+
95
+ def update_graph_wrapper(G, layout_type, visualization_type):
96
+ if G is not None:
97
+ fig, _ = update_graph(G, layout_type, visualization_type)
98
+ return fig
99
 
100
+ sample_dropdown.change(update_inputs, inputs=[sample_dropdown], outputs=[input_text, entity_types, predicates])
101
+
102
+ process_btn.click(process_and_update,
103
+ inputs=[input_text, entity_types, predicates, layout_type, visualization_type],
104
+ outputs=[graph_state, output_graph, error_message])
105
+
106
+ layout_type.change(update_graph_wrapper,
107
+ inputs=[graph_state, layout_type, visualization_type],
108
+ outputs=[output_graph])
109
+
110
+ visualization_type.change(update_graph_wrapper,
111
+ inputs=[graph_state, layout_type, visualization_type],
112
+ outputs=[output_graph])
113
 
114
  if __name__ == "__main__":
115
+ demo.launch(share=True)
lib/visualize.py CHANGED
@@ -1,30 +1,55 @@
1
  import plotly.graph_objects as go
2
  import networkx as nx
3
-
4
- import networkx as nx
5
  from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
6
  Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
7
  from bokeh.palettes import Spectral4
8
  from bokeh.plotting import from_networkx
9
 
10
- def create_bokeh_plot(entities, relationships):
11
- # Create a NetworkX graph
12
  G = nx.Graph()
13
  for entity_id, entity_data in entities.items():
14
- G.add_node(entity_id, label=f"{entity_data['value']} ({entity_data['type']})")
 
15
  for source, relation, target in relationships:
16
  G.add_edge(source, target, label=relation)
17
-
18
- plot = Plot(width=600, height=600, # Increased size for better visibility
 
 
 
 
 
 
 
 
 
 
 
19
  x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
20
  plot.title.text = "Knowledge Graph Interaction"
21
 
22
- # Use tooltips to show node and edge labels on hover
23
  node_hover = HoverTool(tooltips=[("Entity", "@label")])
24
  edge_hover = HoverTool(tooltips=[("Relation", "@label")])
25
  plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
26
 
27
- graph_renderer = from_networkx(G, nx.spring_layout, scale=1,k=0.5, iterations=50, center=(0, 0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
30
  graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
@@ -48,9 +73,7 @@ def create_bokeh_plot(entities, relationships):
48
  plot.renderers.append(labels)
49
 
50
  # Add edge labels
51
- edge_x = []
52
- edge_y = []
53
- edge_labels = []
54
  for (start_node, end_node, label) in G.edges(data='label'):
55
  start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
56
  end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
@@ -65,69 +88,30 @@ def create_bokeh_plot(entities, relationships):
65
  plot.renderers.append(edge_labels)
66
 
67
  return plot
68
-
69
- # def create_bokeh_plot(entities, relationships):
70
- # # Create a NetworkX graph
71
- # G = nx.Graph()
72
- # for entity_id, entity_data in entities.items():
73
- # G.add_node(entity_id, **entity_data)
74
- # for source, relation, target in relationships:
75
- # G.add_edge(source, target)
76
-
77
- # # Create a Bokeh plot
78
- # plot = figure(title="Knowledge Graph", x_range=(-1.1,1.1), y_range=(-1.1,1.1),
79
- # width=400, height=400, tools="pan,wheel_zoom,box_zoom,reset")
80
 
81
- # # Create graph renderer
82
- # graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0,0))
83
-
84
- # # Add graph renderer to plot
85
- # plot.renderers.append(graph_renderer)
86
-
87
- # return plot
88
-
89
- def create_plotly_plot(entities, relationships):
90
- G = nx.DiGraph() # Use DiGraph for directed edges
91
-
92
- for entity_id, entity_data in entities.items():
93
- G.add_node(entity_id, **entity_data)
94
-
95
- for source, relation, target in relationships:
96
- G.add_edge(source, target, relation=relation)
97
-
98
- pos = nx.spring_layout(G, k=0.5, iterations=50) # Adjust layout parameters
99
-
100
- edge_trace = go.Scatter(
101
- x=[],
102
- y=[],
103
- line=dict(width=1, color="#888"),
104
- hoverinfo="text",
105
- mode="lines",
106
- text=[],
107
- )
108
-
109
- node_trace = go.Scatter(
110
- x=[],
111
- y=[],
112
- mode="markers+text",
113
- hoverinfo="text",
114
- marker=dict(
115
- showscale=True,
116
- colorscale="Viridis",
117
- reversescale=True,
118
- color=[],
119
- size=15,
120
- colorbar=dict(
121
- thickness=15,
122
- title="Node Connections",
123
- xanchor="left",
124
- titleside="right",
125
- ),
126
- line_width=2,
127
- ),
128
- text=[],
129
- textposition="top center",
130
- )
131
 
132
  edge_labels = []
133
 
@@ -137,57 +121,26 @@ def create_plotly_plot(entities, relationships):
137
  edge_trace["x"] += (x0, x1, None)
138
  edge_trace["y"] += (y0, y1, None)
139
 
140
- # Calculate midpoint for edge label
141
  mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
142
- edge_labels.append(
143
- go.Scatter(
144
- x=[mid_x],
145
- y=[mid_y],
146
- mode="text",
147
- text=[G.edges[edge]["relation"]],
148
- textposition="middle center",
149
- hoverinfo="none",
150
- showlegend=False,
151
- textfont=dict(size=8),
152
- )
153
- )
154
 
155
  for node in G.nodes():
156
  x, y = pos[node]
157
  node_trace["x"] += (x,)
158
  node_trace["y"] += (y,)
159
- node_info = f"{entities[node]['value']} ({entities[node]['type']})"
160
- node_trace["text"] += (node_info,)
161
  node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
162
 
163
- fig = go.Figure(
164
- data=[edge_trace, node_trace] + edge_labels,
165
- layout=go.Layout(
166
- title="Knowledge Graph",
167
- titlefont_size=16,
168
- showlegend=False,
169
- hovermode="closest",
170
- margin=dict(b=20, l=5, r=5, t=40),
171
- annotations=[],
172
- xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
173
- yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
174
- width=800,
175
- height=600,
176
- ),
177
- )
178
-
179
- # Enable dragging of nodes
180
- fig.update_layout(
181
- newshape=dict(line_color="#009900"),
182
- # Enable zoom
183
- xaxis=dict(
184
- scaleanchor="y",
185
- scaleratio=1,
186
- ),
187
- yaxis=dict(
188
- scaleanchor="x",
189
- scaleratio=1,
190
- ),
191
- )
192
 
193
  return fig
 
1
  import plotly.graph_objects as go
2
  import networkx as nx
3
+ import numpy as np
 
4
  from bokeh.models import (BoxSelectTool, HoverTool, MultiLine, NodesAndLinkedEdges,
5
  Plot, Range1d, Scatter, TapTool, LabelSet, ColumnDataSource)
6
  from bokeh.palettes import Spectral4
7
  from bokeh.plotting import from_networkx
8
 
9
+ def create_graph(entities, relationships):
 
10
  G = nx.Graph()
11
  for entity_id, entity_data in entities.items():
12
+ G.add_node(entity_id, label=f"{entity_data.get('value', 'Unknown')} ({entity_data.get('type', 'Unknown')})")
13
+
14
  for source, relation, target in relationships:
15
  G.add_edge(source, target, label=relation)
16
+
17
+ return G
18
+
19
+ def improved_spectral_layout(G, scale=1):
20
+ pos = nx.spectral_layout(G)
21
+ # Add some random noise to prevent overlapping
22
+ pos = {node: (x + np.random.normal(0, 0.1), y + np.random.normal(0, 0.1)) for node, (x, y) in pos.items()}
23
+ # Scale the layout
24
+ pos = {node: (x * scale, y * scale) for node, (x, y) in pos.items()}
25
+ return pos
26
+
27
+ def create_bokeh_plot(G, layout_type='spring'):
28
+ plot = Plot(width=600, height=600,
29
  x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2))
30
  plot.title.text = "Knowledge Graph Interaction"
31
 
 
32
  node_hover = HoverTool(tooltips=[("Entity", "@label")])
33
  edge_hover = HoverTool(tooltips=[("Relation", "@label")])
34
  plot.add_tools(node_hover, edge_hover, TapTool(), BoxSelectTool())
35
 
36
+ # Create layout based on layout_type
37
+ if layout_type == 'spring':
38
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
39
+ elif layout_type == 'fruchterman_reingold':
40
+ pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
41
+ elif layout_type == 'circular':
42
+ pos = nx.circular_layout(G)
43
+ elif layout_type == 'random':
44
+ pos = nx.random_layout(G)
45
+ elif layout_type == 'spectral':
46
+ pos = improved_spectral_layout(G)
47
+ elif layout_type == 'shell':
48
+ pos = nx.shell_layout(G)
49
+ else:
50
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
51
+
52
+ graph_renderer = from_networkx(G, pos, scale=1, center=(0, 0))
53
 
54
  graph_renderer.node_renderer.glyph = Scatter(size=15, fill_color=Spectral4[0])
55
  graph_renderer.node_renderer.selection_glyph = Scatter(size=15, fill_color=Spectral4[2])
 
73
  plot.renderers.append(labels)
74
 
75
  # Add edge labels
76
+ edge_x, edge_y, edge_labels = [], [], []
 
 
77
  for (start_node, end_node, label) in G.edges(data='label'):
78
  start_x, start_y = graph_renderer.layout_provider.graph_layout[start_node]
79
  end_x, end_y = graph_renderer.layout_provider.graph_layout[end_node]
 
88
  plot.renderers.append(edge_labels)
89
 
90
  return plot
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ def create_plotly_plot(G, layout_type='spring'):
93
+ # Create layout based on layout_type
94
+ if layout_type == 'spring':
95
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
96
+ elif layout_type == 'fruchterman_reingold':
97
+ pos = nx.fruchterman_reingold_layout(G, k=0.5, iterations=50)
98
+ elif layout_type == 'circular':
99
+ pos = nx.circular_layout(G)
100
+ elif layout_type == 'random':
101
+ pos = nx.random_layout(G)
102
+ elif layout_type == 'spectral':
103
+ pos = improved_spectral_layout(G)
104
+ elif layout_type == 'shell':
105
+ pos = nx.shell_layout(G)
106
+ else:
107
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
108
+
109
+ edge_trace = go.Scatter(x=[], y=[], line=dict(width=1, color="#888"), hoverinfo="text", mode="lines", text=[])
110
+ node_trace = go.Scatter(x=[], y=[], mode="markers+text", hoverinfo="text",
111
+ marker=dict(showscale=True, colorscale="Viridis", reversescale=True, color=[], size=15,
112
+ colorbar=dict(thickness=15, title="Node Connections", xanchor="left", titleside="right"),
113
+ line_width=2),
114
+ text=[], textposition="top center")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  edge_labels = []
117
 
 
121
  edge_trace["x"] += (x0, x1, None)
122
  edge_trace["y"] += (y0, y1, None)
123
 
 
124
  mid_x, mid_y = (x0 + x1) / 2, (y0 + y1) / 2
125
+ edge_labels.append(go.Scatter(x=[mid_x], y=[mid_y], mode="text", text=[G.edges[edge]["label"]],
126
+ textposition="middle center", hoverinfo="none", showlegend=False, textfont=dict(size=8)))
 
 
 
 
 
 
 
 
 
 
127
 
128
  for node in G.nodes():
129
  x, y = pos[node]
130
  node_trace["x"] += (x,)
131
  node_trace["y"] += (y,)
132
+ node_trace["text"] += (G.nodes[node]["label"],)
 
133
  node_trace["marker"]["color"] += (len(list(G.neighbors(node))),)
134
 
135
+ fig = go.Figure(data=[edge_trace, node_trace] + edge_labels,
136
+ layout=go.Layout(title="Knowledge Graph", titlefont_size=16, showlegend=False, hovermode="closest",
137
+ margin=dict(b=20, l=5, r=5, t=40), annotations=[],
138
+ xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
139
+ yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
140
+ width=800, height=600))
141
+
142
+ fig.update_layout(newshape=dict(line_color="#009900"),
143
+ xaxis=dict(scaleanchor="y", scaleratio=1),
144
+ yaxis=dict(scaleanchor="x", scaleratio=1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  return fig