Zaherrr commited on
Commit
5ce695c
1 Parent(s): 6528e50

Upload 2 files

Browse files
Files changed (2) hide show
  1. graph_visualization.py +61 -0
  2. main.py +84 -0
graph_visualization.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pyvis.network import Network
3
+
4
+
5
+ def create_graph(nodes, edges, physics_enabled=True):
6
+ net = Network(
7
+ notebook=True,
8
+ height="100vh",
9
+ width="100vw",
10
+ bgcolor="#222222",
11
+ font_color="white",
12
+ cdn_resources="remote",
13
+ )
14
+
15
+ for node in nodes:
16
+ net.add_node(
17
+ node["id"],
18
+ label=node["label"],
19
+ title=node["label"],
20
+ color="blue" if node["label"] == "OOP" else "green",
21
+ )
22
+
23
+ for edge in edges:
24
+ net.add_edge(edge["source"], edge["target"], title=edge["type"])
25
+
26
+ net.force_atlas_2based(
27
+ gravity=-50,
28
+ central_gravity=0.01,
29
+ spring_length=100,
30
+ spring_strength=0.08,
31
+ damping=0.4,
32
+ )
33
+
34
+ options = {
35
+ "nodes": {"physics": physics_enabled},
36
+ "edges": {"smooth": True},
37
+ "interaction": {"hover": True, "zoomView": True},
38
+ "physics": {
39
+ "enabled": physics_enabled,
40
+ "stabilization": {"enabled": True, "iterations": 200},
41
+ },
42
+ }
43
+
44
+ net.set_options(json.dumps(options))
45
+ return net
46
+
47
+
48
+ def visualize_graph(json_data, physics_enabled=True):
49
+ if isinstance(json_data, str):
50
+ data = json.loads(json_data)
51
+ else:
52
+ data = json_data
53
+ nodes = data["nodes"]
54
+ edges = data["edges"]
55
+ net = create_graph(nodes, edges, physics_enabled)
56
+ html = net.generate_html()
57
+ html = html.replace("'", '"')
58
+ html = html.replace(
59
+ '<div id="mynetwork"', '<div id="mynetwork" style="height: 100vh; width: 100%;"'
60
+ )
61
+ return f"""<iframe style="width: 100%; height: 100vh; border: none; margin: 0; padding: 0;" srcdoc='{html}'></iframe>"""
main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_from_disk
3
+ from PIL import Image
4
+ import io
5
+ import base64
6
+ import json
7
+
8
+ from graph_visualization import visualize_graph
9
+
10
+ # Load the dataset
11
+ dataset = load_from_disk(
12
+ "/home/zaher/Projects/ZAKA_CAPSTONE_PROJECT/notebooks/OOP_KG_Dataset"
13
+ )
14
+
15
+
16
+ def reshape_json_data_to_fit_visualize_graph(graph_data):
17
+
18
+ nodes = graph_data["nodes"]
19
+ edges = graph_data["edges"]
20
+
21
+ transformed_nodes = [
22
+ {"id": nodes["id"][idx], "label": nodes["label"][idx]}
23
+ for idx in range(len(nodes["id"]))
24
+ ]
25
+
26
+ transformed_edges = [
27
+ {"source": edges["source"][idx], "target": edges["target"][idx], "type": "->"}
28
+ for idx in range(len(edges["source"]))
29
+ ]
30
+
31
+ # print(f"transformed nodes = {transformed_nodes}")
32
+
33
+ graph_data = {"nodes": transformed_nodes, "edges": transformed_edges}
34
+
35
+ return graph_data
36
+
37
+
38
+ def display_example(index):
39
+ example = dataset[index]
40
+ # print("This is the example: ")
41
+ # print(example)
42
+ # Get the image
43
+ img = example["image"]
44
+
45
+ # Prepare the graph data
46
+ graph_data = {"nodes": example["nodes"], "edges": example["edges"]}
47
+
48
+ # # Convert graph_data to JSON string
49
+ # json_data = json.dumps(graph_data)
50
+ transformed_graph_data = reshape_json_data_to_fit_visualize_graph(graph_data)
51
+
52
+ # print(json_data)
53
+ # Generate the graph visualization
54
+ graph_html = visualize_graph(transformed_graph_data)
55
+
56
+ return img, graph_html
57
+
58
+
59
+ def create_interface():
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("# Knowledge Graph Visualizer")
62
+
63
+ with gr.Row():
64
+ index_slider = gr.Slider(
65
+ minimum=0, maximum=len(dataset) - 1, step=1, label="Example Index"
66
+ )
67
+
68
+ with gr.Row():
69
+ image_output = gr.Image(type="pil", label="Image")
70
+ graph_output = gr.HTML(label="Knowledge Graph")
71
+
72
+ index_slider.change(
73
+ fn=display_example,
74
+ inputs=[index_slider],
75
+ outputs=[image_output, graph_output],
76
+ )
77
+
78
+ return demo
79
+
80
+
81
+ # Create and launch the interface
82
+ if __name__ == "__main__":
83
+ demo = create_interface()
84
+ demo.launch(debug=True)