Spaces: Running
aldan.creo committed
Commit • cdd672b • 0 Parent(s)
First commit
- .env.template +3 -0
- .gitattributes +21 -0
- .gitignore +3 -0
- README.md +10 -0
- app.py +370 -0
- explore.ipynb +536 -0
- institutions.csv +0 -0
- model/.data-00000-of-00001 +3 -0
- model/.index +3 -0
- model/checkpoint +3 -0
- model/model_metadata.ampkl +3 -0
- requirements.txt +131 -0
- test.csv +3 -0
- test.py +12 -0
- train.csv +3 -0
- train.py +215 -0
- universities.ttl +3 -0
- universities_large.ttl +3 -0
- universities_large_1200.ttl +3 -0
- universities_large_4200.ttl +3 -0
- universities_large_4300.ttl +3 -0
- valid.csv +3 -0
.env.template
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:876dccf69e76c840d490e31392e63a465919635c812c2113cbc9c445f8af616b
+size 9
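Note: most files in this commit are stored with Git LFS, so the diff shows only the three-line pointer (spec version, content hash, byte size) rather than the real content, which must be fetched with `git lfs pull`. A minimal sketch of reading the pointer metadata, assuming the file on disk is still in pointer form (i.e., before `git lfs pull`); `read_lfs_pointer` is a hypothetical helper, not part of this repo:

# Hypothetical helper: parse a Git LFS pointer file into a dict of its fields.
def read_lfs_pointer(path: str) -> dict:
    fields = {}
    with open(path) as f:
        for line in f:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    return fields

pointer = read_lfs_pointer(".env.template")
print(pointer["oid"], pointer["size"])  # e.g. sha256:876dcc... 9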
.gitattributes
ADDED
@@ -0,0 +1,21 @@
+universities.ttl filter=lfs diff=lfs merge=lfs -text
+model/checkpoint filter=lfs diff=lfs merge=lfs -text
+model/model_metadata.ampkl filter=lfs diff=lfs merge=lfs -text
+model/.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+model/.index filter=lfs diff=lfs merge=lfs -text
+model/ filter=lfs diff=lfs merge=lfs -text
+*.template filter=lfs diff=lfs merge=lfs -text
+*.gitignore filter=lfs diff=lfs merge=lfs -text
+*.ttl filter=lfs diff=lfs merge=lfs -text
+*.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
+*.index filter=lfs diff=lfs merge=lfs -text
+/model/checkpoint filter=lfs diff=lfs merge=lfs -text
+*.ampkl filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+universities_large_1200.ttl filter=lfs diff=lfs merge=lfs -text
+universities_large_4300.ttl filter=lfs diff=lfs merge=lfs -text
+universities_large_4200.ttl filter=lfs diff=lfs merge=lfs -text
+universities_large.ttl filter=lfs diff=lfs merge=lfs -text
+test.csv filter=lfs diff=lfs merge=lfs -text
+train.csv filter=lfs diff=lfs merge=lfs -text
+valid.csv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51864a12e4a74f37f6cd4399f8de2d35f8c3b60179ae94d03de408d9366411b6
+size 35
README.md
ADDED
@@ -0,0 +1,10 @@
+---
+title: Universities Explorer
+emoji: 🏢
+colorFrom: pink
+colorTo: blue
+sdk: gradio
+sdk_version: 4.8.0
+app_file: app.py
+pinned: true
+---
app.py
ADDED
@@ -0,0 +1,370 @@
+# %%
+import gradio as gr
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import rdflib
+import seaborn as sns
+import tensorflow as tf
+from adjustText import adjust_text
+from ampligraph.latent_features import ScoringBasedEmbeddingModel
+from ampligraph.utils import restore_model
+from sklearn.cluster import KMeans
+from sklearn.decomposition import PCA
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+# Start timer, count time to load graph
+start_time = tf.timestamp()
+
+g = rdflib.Graph()
+uri = "urn:acmcmc:unis:"
+unis = rdflib.Namespace(uri)
+g.bind("unis", unis)
+g.parse("universities.ttl", format="turtle")
+
+# End timer
+end_time = tf.timestamp()
+logger.info("Graph loaded in {} seconds".format(end_time - start_time))
+
+# model = restore_model("model.pkl")
+
+# Start timer, count time to load model
+start_time = tf.timestamp()
+model = ScoringBasedEmbeddingModel(k=150, eta=10, scoring_type="ComplEx")
+model.load_metadata("model/model")
+model.build_full_model()
+super(ScoringBasedEmbeddingModel, model).load_weights("model/")
+# End timer
+end_time = tf.timestamp()
+logger.info("Model loaded in {} seconds".format(end_time - start_time))
+
+
+def separate_concepts(concepts):
+    concept_list = concepts.split(",")
+    # Trim the strings
+    concept_list = [x.strip() for x in concept_list]
+    return concept_list
+
+
+def pca(embeddings):
+    pca = PCA(n_components=2)
+    pca.fit(embeddings)
+    entity_embeddings_pca = pca.transform(embeddings)
+    return entity_embeddings_pca
+
+
+def cluster(embeddings):
+    clustering_algorithm = KMeans(n_clusters=6, n_init=50, max_iter=500, random_state=0)
+    clusters = clustering_algorithm.fit_predict(embeddings)
+    return clusters
+
+
+def get_concept_name(concept_uri):
+    """
+    Get the name of the concept from the URI
+    """
+    results = g.query(
+        f"""SELECT ?name
+        WHERE {{
+        <{concept_uri}> <urn:acmcmc:unis:name> ?name .
+        }}"""
+    )
+    return pd.DataFrame(results)[0][0]
+
+
+def get_similarities_to_node(array_of_triples, model):
+    """
+    Calculate the similarity between the embeddings of a node and a list of other nodes
+    """
+    # Score the triples with the embedding model (higher score = more plausible link)
+    indexes = model.get_indexes(array_of_triples)
+    scores = model(indexes)
+    return scores
+
+
+def process_user_input_concept(concept_chooser):
+    """
+    The user input is the URI of the concept. Get the similarities between the concept and the institutions
+    """
+    all_ids_institutions = np.loadtxt(
+        "institutions.csv", delimiter=",", skiprows=1, dtype=str, quotechar='"'
+    )
+    # Remove duplicates based on the first column
+    all_ids_institutions = all_ids_institutions[
+        ~pd.DataFrame(all_ids_institutions).duplicated(0)
+    ]
+
+    chosen_concepts = separate_concepts(concept_chooser)
+    all_similarities = []
+    for concept in chosen_concepts:
+        s = all_ids_institutions[:, 0]
+        p = np.array(["urn:acmcmc:unis:institution_related_to_concept"] * len(s))
+        o = np.array([concept] * len(s))
+
+        array_of_triples = np.array([s, p, o]).T
+
+        scores = get_similarities_to_node(array_of_triples, model)
+        all_similarities.append(scores)
+
+    # Now, average the similarities across the chosen concepts
+    scores = np.mean(np.stack(all_similarities, axis=0), axis=0)
+
+    table_df = pd.DataFrame(
+        {
+            "institution": s,
+            "similarity": scores.flatten(),
+            "institution_name": all_ids_institutions[:, 1],
+            # "num_articles": all_ids_institutions[:, 2].astype(int),
+        }
+    )
+    # Sort by similarity
+    table_df = table_df.sort_values(by=["similarity"], ascending=False)
+    concept_names = [get_concept_name(concept_uri) for concept_uri in chosen_concepts]
+    return (
+        table_df,
+        gr.update(visible=True),
+        gr.update(visible=True),
+        gr.update(visible=True),
+        f'Concept names: {", ".join(concept_names)}',
+    )
+
+
+def calculate_embeddings_and_pca(table):
+    gr.Info("Performing PCA and clustering...")
+    # Perform PCA
+    embeddings_of_institutions = model.get_embeddings(
+        entities=np.array(table["institution"])
+    )
+
+    entity_embeddings_pca = pca(embeddings_of_institutions)
+
+    # Perform clustering
+    clusters = cluster(embeddings_of_institutions)
+
+    plot_df = pd.DataFrame(
+        {
+            "embedding1": entity_embeddings_pca[:, 0],
+            "embedding2": entity_embeddings_pca[:, 1],
+            "cluster": "cluster" + pd.Series(clusters).astype(str),
+        }
+    )
+
+    # Toast message
+    gr.Info("PCA and clustering done!")
+    return plot_df
+
+
+def click_on_institution(table, embeddings_var, evt: gr.SelectData):
+    institution_id = table["institution"][evt.index[0]]
+    try:
+        embeddings_df = embeddings_var["embeddings_df"]
+        plot_df = pd.DataFrame(
+            {
+                "institution": table["institution"].values,
+                "institution_name": table["institution_name"].values,
+                "embedding1": embeddings_df["embedding1"].values,
+                "embedding2": embeddings_df["embedding2"].values,
+                "cluster": embeddings_df["cluster"].values,
+                # "num_articles": table["num_articles"].values,
+            }
+        )
+        return plot_embeddings(plot_df, institution_id)
+    except KeyError:
+        # The embeddings have not been computed yet, so there is nothing to redraw
+        pass
+
+
+def click_on_show_plot(table):
+    embeddings_df = calculate_embeddings_and_pca(table)
+
+    plot_df = pd.DataFrame(
+        {
+            "institution": table["institution"].values,
+            "institution_name": table["institution_name"].values,
+            "embedding1": embeddings_df["embedding1"].values,
+            "embedding2": embeddings_df["embedding2"].values,
+            "cluster": embeddings_df["cluster"].values,
+            # "num_articles": table["num_articles"].values,
+        }
+    )
+    fig = plot_embeddings(plot_df, None)
+
+    return fig, {"embeddings_df": plot_df}
+
+
+def plot_embeddings(plot_df, institution_id):
+    fig = plt.figure(figsize=(12, 12))
+    np.random.seed(0)
+    # fig.title("{} embeddings".format(parameter).capitalize())
+    ax = sns.scatterplot(
+        data=plot_df,
+        x="embedding1",
+        y="embedding2",
+        hue="cluster",
+    )
+
+    row_of_institution = plot_df[plot_df["institution"] == institution_id]
+    if not row_of_institution.empty:
+        ax.text(
+            row_of_institution["embedding1"].values[0],
+            row_of_institution["embedding2"].values[0],
+            row_of_institution["institution_name"].values[0],
+            horizontalalignment="left",
+            size="medium",
+            color="black",
+            weight="normal",
+        )
+        # Also draw a point for the institution
+        ax.scatter(
+            row_of_institution["embedding1"],
+            row_of_institution["embedding2"],
+            color="black",
+            s=100,
+            marker="x",
+        )
+    # texts = []
+    # for i, point in plot_df.iterrows():
+    #     if point["institution"] == institution_id:
+    #         texts.append(
+    #             fig.text(
+    #                 point["embedding1"] + 0.02,
+    #                 point["embedding2"] + 0.01,
+    #                 str(point["institution_name"]),
+    #             )
+    #         )
+    # adjust_text(texts)
+    return fig
+
+
+def get_authors_of_institution(institutions_table, concept_chooser, evt: gr.SelectData):
+    """
+    Get the authors of an institution
+    """
+    number_of_row = evt.index[0]
+    institution = institutions_table["institution"][number_of_row]
+    concepts = separate_concepts(concept_chooser)
+    results_dfs = []
+    for concept in concepts:
+        # Create a dataframe of the authors and the number of articles they have written for each concept
+        result = g.query(
+            f"""SELECT ?author ?name (COUNT (?article) AS ?num_articles)
+            WHERE {{
+            ?author a <urn:acmcmc:unis:Author> .
+            ?author <urn:acmcmc:unis:name> ?name .
+            ?article <urn:acmcmc:unis:written_in_institution> <{institution}> .
+            ?article <urn:acmcmc:unis:has_author> ?author .
+            ?article <urn:acmcmc:unis:related_to_concept> <{concept}> .
+            }}
+            GROUP BY ?author ?name
+            ORDER BY DESC(COUNT (?article))
+            """
+        )
+        result_df = pd.DataFrame(result)
+        result_df.columns = ["author", "name", "num_articles"]
+        results_dfs.append(result_df)
+    # Now, aggregate the results into a single dataframe by summing the number of articles
+    results_df = pd.concat(results_dfs)
+    results_df = results_df.groupby(["author", "name"]).sum().reset_index()
+    # Sort by number of articles
+    results_df = results_df.sort_values(by=["num_articles"], ascending=False)
+    return results_df, gr.update(visible=True)
+
+
+# %%
+theme = gr.themes.Default(primary_hue="cyan", secondary_hue="fuchsia")
+
+with gr.Blocks(theme=theme) as demo:
+    embeddings_df = gr.State({})
+    # App title and description
+    title = gr.Markdown(
+        """
+        # Universities Explorer
+        This app allows you to explore the institutions most closely related to a concept.
+
+        It uses embeddings of institutions and concepts to calculate the similarity between them. The embedding model, [ComplEx](https://doi.org/10.48550/arXiv.1606.06357), was trained using the [AmpliGraph](https://github.com/Accenture/AmpliGraph) library. The data comes from the [OpenAlex](https://openalex.org/) dataset, which contains information about scientific articles, authors, institutions, and concepts.
+        """
+    )
+    with gr.Group() as institution_search:
+        concept_chooser = gr.Textbox(
+            label="Concept URI",
+            info="Using OpenAlex, find the URI of the concept you want to search for. For example, the URI of the concept 'Knowledge Graph' is https://openalex.org/C2987255567, while the URI of the concept 'Natural Language Processing' is https://openalex.org/C204321447. You can find the URI of a concept by searching for it on OpenAlex and copying the URL from the address bar. You can also search for multiple concepts by separating them with a comma.",
+            placeholder="https://openalex.org/C2987255567, https://openalex.org/C204321447",
+            value="https://openalex.org/C2987255567, https://openalex.org/C204321447",
+        )
+        concept_name_label = gr.Markdown("Concept name: ", visible=False)
+        # Table for name of institution and similarity to concept
+        btn_search_institutions = gr.Button("Search institutions", variant="primary")
+        table = gr.Dataframe(
+            interactive=False, visible=False, elem_classes="institutions", wrap=True
+        )
+        btn_search_institutions.click(
+            lambda: gr.update(visible=True), outputs=[table], queue=True
+        )
+
+        btn_plot_embeddings = gr.Button(
+            "Plot embeddings", variant="primary", visible=False, elem_classes="embeddings"
+        )
+        # Description of what plot embeddings does
+        plot_embeddings_info = gr.Markdown(
+            """
+            This button will plot the embeddings of the institutions related to the concept. The embeddings are calculated using the trained model and then reduced to 2 dimensions using PCA. The institutions are then clustered using KMeans.
+
+            Running this may take a while, as we need to calculate the embeddings for all institutions and then perform PCA and clustering.
+            """,
+            visible=False,
+        )
+        btn_search_institutions.click(
+            process_user_input_concept,
+            inputs=[concept_chooser],
+            outputs=[
+                table,
+                btn_plot_embeddings,
+                plot_embeddings_info,
+                concept_name_label,  # receives the visibility update
+                concept_name_label,  # receives the text value
+            ],
+            queue=True,
+        )
+        plot = gr.Plot(visible=False, elem_classes="embeddings")
+        btn_plot_embeddings.click(
+            lambda: gr.update(visible=True), outputs=[plot], queue=True
+        )
+        btn_plot_embeddings.click(
+            click_on_show_plot,
+            inputs=[table],
+            outputs=[plot, embeddings_df],
+            queue=True,
+        )
+
+    # When the user selects a row in the table, get the authors of that institution and display them in a dataframe
+    with gr.Group(visible=False, elem_classes="authors") as authors:
+        table_authors = gr.Dataframe(
+            interactive=False, label="Authors in institution writing about concept"
+        )
+    table.select(
+        get_authors_of_institution,
+        inputs=[table, concept_chooser],
+        outputs=[table_authors, authors],
+    )
+    table.select(
+        click_on_institution,
+        inputs=[table, embeddings_df],
+        outputs=[plot],
+    )
+
+    btn_clear = gr.ClearButton(components=[table, plot, table_authors])
+
+    # Author information
+    author_info = gr.Markdown(
+        """
+        This demo has been built by [Aldan Creo](https://acmc-website.web.app/).
+        """
+    )
+
+demo.queue()
+demo.launch()
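For a quick offline check of the scoring path above, the same load-and-score sequence can be run standalone. This is a sketch assuming model/ and institutions.csv have been pulled from LFS; note that importing app.py itself would also start the Gradio server, since demo.launch() runs at module level:

# Hedged standalone sketch of the scoring path used by process_user_input_concept.
import numpy as np
from ampligraph.latent_features import ScoringBasedEmbeddingModel

model = ScoringBasedEmbeddingModel(k=150, eta=10, scoring_type="ComplEx")
model.load_metadata("model/model")
model.build_full_model()
super(ScoringBasedEmbeddingModel, model).load_weights("model/")

# Score "institution related to Knowledge Graph" for every institution.
ids = np.loadtxt("institutions.csv", delimiter=",", skiprows=1, dtype=str, quotechar='"')
s = ids[:, 0]
p = np.array(["urn:acmcmc:unis:institution_related_to_concept"] * len(s))
o = np.array(["https://openalex.org/C2987255567"] * len(s))
scores = np.asarray(model(model.get_indexes(np.array([s, p, o]).T))).flatten()
print(s[np.argsort(scores)[::-1][:10]])  # top-10 institutions by ComplEx score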
explore.ipynb
ADDED
@@ -0,0 +1,536 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pyalex\n",
+    "import dotenv\n",
+    "import os\n",
+    "from pyalex import Works, Authors, Sources, Institutions, Concepts, Publishers, Funders\n",
+    "import pandas as pd\n",
+    "from sklearn.model_selection import train_test_split\n",
+    "import numpy as np\n",
+    "from ampligraph.evaluation import train_test_split_no_unseen\n",
+    "\n",
+    "dotenv.load_dotenv()\n",
+    "\n",
+    "pyalex.config.email = os.getenv(\"MY_EMAIL\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "knowledge_graphs = Concepts().search(\"knowledge graph\").count()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import rdflib\n",
+    "\n",
+    "g = rdflib.Graph()\n",
+    "uri = \"urn:acmcmc:unis:\"\n",
+    "unis = rdflib.Namespace(uri)\n",
+    "g.bind(\"unis\", unis)\n",
+    "# g.parse(\"universities_large_1200.ttl\", format=\"turtle\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def store_graph():\n",
+    "    g.serialize(destination='universities_large.ttl', format='turtle')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "articles = (\n",
+    "    Works()\n",
+    "    .search_filter(abstract=\"Large Language Model Knowledge Graph\")\n",
+    "    .filter(authorships={\"institutions\": {\"country_code\": \"US\", \"type\": \"education\"}})\n",
+    ")\n",
+    "articles = Works().filter(\n",
+    "    concepts={\"id\": \"C2987255567|C204321447|C41008148\"},\n",
+    "    # C2987255567: Knowledge Graph\n",
+    "    # C204321447: Natural Language Processing\n",
+    "    # C41008148 : Computer Science\n",
+    "    authorships={\"institutions\": {\"country_code\": \"US\", \"type\": \"education\"}},\n",
+    ").sort(publication_date=\"desc\")\n",
+    "print(f\"Found {articles.count()} articles. Fetching...\")\n",
+    "\n",
+    "if articles.count() > 1000:\n",
+    "    print(\"Too many articles. Loading from file.\")\n",
+    "    g.parse(\"universities_large_1200.ttl\", format=\"turtle\")\n",
+    "else:\n",
+    "    all_articles = []\n",
+    "    num_articles_concepts = {\n",
+    "        \"https://openalex.org/C2987255567\": 0,\n",
+    "        \"https://openalex.org/C204321447\": 0,\n",
+    "        \"https://openalex.org/C41008148\": 0,\n",
+    "    }\n",
+    "    # Go through all pages\n",
+    "    paginator = articles.paginate(per_page=200, n_max=1000000)\n",
+    "    for i, page in enumerate(paginator):\n",
+    "        print(f\"Processing page {i}\")\n",
+    "        if i > 0 and i % 100 == 0:\n",
+    "            store_graph()\n",
+    "        for article in page:\n",
+    "            all_articles.append(article)\n",
+    "            article_uri = rdflib.URIRef(article[\"id\"])\n",
+    "            g.add((article_uri, rdflib.RDF.type, unis.Article))\n",
+    "            g.add((article_uri, unis.title, rdflib.Literal(article[\"title\"])))\n",
+    "            # Related to is a list of ids\n",
+    "            for related_to in article[\"related_works\"]:\n",
+    "                g.add((article_uri, unis.related_to, rdflib.URIRef(related_to)))\n",
+    "            for reference in article[\"referenced_works\"]:\n",
+    "                g.add((article_uri, unis.references, rdflib.URIRef(reference)))\n",
+    "            # Authors is a list of dicts\n",
+    "            for author in article[\"authorships\"]:\n",
+    "                author_uri = rdflib.URIRef(author[\"author\"][\"id\"])\n",
+    "                g.add((author_uri, rdflib.RDF.type, unis.Author))\n",
+    "                g.add(\n",
+    "                    (\n",
+    "                        author_uri,\n",
+    "                        unis.name,\n",
+    "                        rdflib.Literal(author[\"author\"][\"display_name\"]),\n",
+    "                    )\n",
+    "                )\n",
+    "                g.add((article_uri, unis.has_author, author_uri))\n",
+    "                for institution in author[\"institutions\"]:\n",
+    "                    institution_uri = rdflib.URIRef(institution[\"id\"])\n",
+    "                    g.add((institution_uri, rdflib.RDF.type, unis.Institution))\n",
+    "                    # g.add((author_uri, unis.affiliated_to, institution_uri))  # Do not add this, because the author might be affiliated to multiple institutions at different times\n",
+    "                    g.add(\n",
+    "                        (\n",
+    "                            article_uri,\n",
+    "                            unis.written_in_institution,\n",
+    "                            institution_uri,\n",
+    "                        )\n",
+    "                    )\n",
+    "                    g.add(\n",
+    "                        (\n",
+    "                            institution_uri,\n",
+    "                            unis.country,\n",
+    "                            rdflib.Literal(institution[\"country_code\"]),\n",
+    "                        )\n",
+    "                    )\n",
+    "                    g.add(\n",
+    "                        (\n",
+    "                            institution_uri,\n",
+    "                            unis.name,\n",
+    "                            rdflib.Literal(institution[\"display_name\"]),\n",
+    "                        )\n",
+    "                    )\n",
+    "                    for parent_institution_id in institution[\"lineage\"]:\n",
+    "                        parent_institution_uri = rdflib.URIRef(parent_institution_id)\n",
+    "                        g.add(\n",
+    "                            (parent_institution_uri, rdflib.RDF.type, unis.Institution)\n",
+    "                        )\n",
+    "                        g.add(\n",
+    "                            (institution_uri, unis.is_part_of, parent_institution_uri)\n",
+    "                        )\n",
+    "                    # Concepts is a list of dicts\n",
+    "                    for concept in [c for c in article[\"concepts\"] if c[\"score\"] > 0.4]:\n",
+    "                        concept_uri = rdflib.URIRef(concept[\"id\"])\n",
+    "                        g.add((concept_uri, rdflib.RDF.type, unis.Concept))\n",
+    "                        g.add(\n",
+    "                            (\n",
+    "                                institution_uri,\n",
+    "                                unis.institution_related_to_concept,\n",
+    "                                concept_uri,\n",
+    "                            )\n",
+    "                        )\n",
+    "                        # Count the concepts\n",
+    "                        if concept[\"id\"] in num_articles_concepts:\n",
+    "                            num_articles_concepts[concept[\"id\"]] += 1\n",
+    "            # Concepts is a list of dicts\n",
+    "            for concept in [c for c in article[\"concepts\"] if c[\"score\"] > 0.4]:\n",
+    "                concept_uri = rdflib.URIRef(concept[\"id\"])\n",
+    "                g.add((concept_uri, rdflib.RDF.type, unis.Concept))\n",
+    "                g.add((article_uri, unis.related_to_concept, concept_uri))\n",
+    "                g.add((concept_uri, unis.name, rdflib.Literal(concept[\"display_name\"])))\n",
+    "    # print the numbers of articles per concept\n",
+    "    print(num_articles_concepts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Knowledge extraction rule: if we have institution I, a paper P, and P is related to concept C, then C is related to I\n",
+    "# Add triples to the graph for this rule\n",
+    "query_results = g.query(\n",
+    "    \"\"\"\n",
+    "    SELECT DISTINCT ?institution ?concept\n",
+    "    WHERE {\n",
+    "        ?institution a unis:Institution .\n",
+    "        ?article a unis:Article .\n",
+    "        ?concept a unis:Concept .\n",
+    "        ?article unis:written_in_institution ?institution .\n",
+    "        ?article unis:related_to_concept ?concept .\n",
+    "    }\n",
+    "    \"\"\"\n",
+    ")\n",
+    "# Print the number of results\n",
+    "print(f\"Found {len(query_results)} results for the rule.\")\n",
+    "for i, row in enumerate(query_results):\n",
+    "    if i % 1000 == 0:\n",
+    "        print(f\"Processing rule {i}\")\n",
+    "    g.add((row[0], unis.institution_related_to_concept, row[1]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "results = Works().search_filter(abstract=\"Large Language Model Knowledge Graph\").group_by(\n",
+    "    \"authorships.institutions.id\"\n",
+    ")\n",
+    "\n",
+    "print(f\"Found {results.count()} articles. Fetching...\")\n",
+    "\n",
+    "df = pd.DataFrame(results.get())\n",
+    "\n",
+    "display(df)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "store_graph()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#from rdflib.extras.external_graph_libs import rdflib_to_networkx_multidigraph\n",
+    "#import networkx as nx\n",
+    "#import matplotlib.pyplot as plt\n",
+    "#\n",
+    "#G = rdflib_to_networkx_multidigraph(g)\n",
+    "#\n",
+    "## Plot Networkx instance of RDF Graph\n",
+    "#pos = nx.spring_layout(G, scale=0.1)\n",
+    "#edge_labels = nx.get_edge_attributes(G, \"r\")\n",
+    "#nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)\n",
+    "#nx.draw(G, with_labels=True)\n",
+    "#\n",
+    "## if not in interactive mode for\n",
+    "#plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "# Get the triples from the graph to a numpy array\n",
+    "# Array of size (n_triples, 3)\n",
+    "# We just want the triples where the predicate is either:\n",
+    "# - related_to\n",
+    "# - has_author\n",
+    "# - written_in_institution\n",
+    "# - related_to_concept\n",
+    "# - references\n",
+    "# - is_part_of\n",
+    "triples_generator = list(g.triples((None, unis.related_to, None)))\n",
+    "triples_generator += list(g.triples((None, unis.has_author, None)))\n",
+    "triples_generator += list(g.triples((None, unis.written_in_institution, None)))\n",
+    "triples_generator += list(g.triples((None, unis.related_to_concept, None)))\n",
+    "triples_generator += list(g.triples((None, unis.institution_related_to_concept, None)))\n",
+    "triples_generator += list(g.triples((None, unis.references, None)))\n",
+    "triples_generator += list(g.triples((None, unis.is_part_of, None)))\n",
+    "triples = np.array(\n",
+    "    [(str(s), str(p), str(o)) for s, p, o in triples_generator]\n",
+    ")  # (subject, predicate, object) triples\n",
+    "\n",
+    "# Convert the objects to their string representation\n",
+    "# Split the triples into train, valid, and test sets (80%, 10%, 10%)\n",
+    "X_train, X_valid = train_test_split_no_unseen(np.array(triples), test_size=0.2)\n",
+    "X_valid, X_test = train_test_split_no_unseen(X_valid, test_size=0.5, allow_duplication=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Store the triples in a file\n",
+    "np.save(\"train.npy\", X_train)\n",
+    "np.save(\"valid.npy\", X_valid)\n",
+    "np.save(\"test.npy\", X_test)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Load the triples from the file\n",
+    "X_train = np.load(\"train.npy\")\n",
+    "X_valid = np.load(\"valid.npy\")\n",
+    "X_test = np.load(\"test.npy\")\n",
+    "\n",
+    "print(f\"Train size: {X_train.shape[0]}\")\n",
+    "print(f\"Valid size: {X_valid.shape[0]}\")\n",
+    "print(f\"Test size: {X_test.shape[0]}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Run the evaluation procedure on the test set (with filtering)\n",
+    "# To disable filtering: use_filter=None\n",
+    "# Usually, we corrupt subject and object sides separately and compute ranks\n",
+    "ranks = model.evaluate(X_test, use_filter=filter, corrupt_side=\"s,o\")\n",
+    "\n",
+    "# compute and print metrics:\n",
+    "mrr = mrr_score(ranks)\n",
+    "hits_10 = hits_at_n_score(ranks, n=10)\n",
+    "print(\"MRR: %f, Hits@10: %f\" % (mrr, hits_10))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Store the model\n",
+    "super(ScoringBasedEmbeddingModel, model).save_weights(\"model/\")\n",
+    "model.save_metadata(filedir='model')\n",
+    "#from ampligraph.utils import save_model\n",
+    "#save_model(model, model_name_path='model.pkl')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Generate the embeddings for entities and relations in the graph\n",
+    "# and store them in numpy arrays\n",
+    "all_ids_institutions = np.array(\n",
+    "    [\n",
+    "        (str(x), str(name), int(num_articles))\n",
+    "        for (x, name, num_articles) in g.query(\n",
+    "            \"\"\"SELECT DISTINCT ?s ?name (COUNT (?article) AS ?num_articles)\n",
+    "            WHERE {\n",
+    "            ?s a <urn:acmcmc:unis:Institution> .\n",
+    "            ?s <urn:acmcmc:unis:name> ?name .\n",
+    "            ?article <urn:acmcmc:unis:written_in_institution> ?s .\n",
+    "            ?article ?related_to <https://openalex.org/C204321447>\n",
+    "            }\n",
+    "            GROUP BY ?s ?name\n",
+    "            \"\"\"\n",
+    "        )\n",
+    "    ]\n",
+    ")\n",
+    "print(all_ids_institutions.shape)\n",
+    "print(all_ids_institutions[0])\n",
+    "entity_embeddings = model.get_embeddings(entities=all_ids_institutions[:, 0])\n",
+    "display(entity_embeddings.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# PCA\n",
+    "from sklearn.decomposition import PCA\n",
+    "pca = PCA(n_components=2)\n",
+    "pca.fit(entity_embeddings)\n",
+    "entity_embeddings_pca = pca.transform(entity_embeddings)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ampligraph.discovery import find_clusters\n",
+    "from sklearn.cluster import KMeans\n",
+    "\n",
+    "clustering_algorithm = KMeans(n_clusters=6, n_init=50, max_iter=500, random_state=0)\n",
+    "clusters = find_clusters(all_ids_institutions[:,0], model, clustering_algorithm, mode=\"e\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plot_df = pd.DataFrame(\n",
+    "    {\n",
+    "        \"institution\": all_ids_institutions[:, 0],\n",
+    "        \"institution_name\": all_ids_institutions[:, 1],\n",
+    "        \"embedding1\": entity_embeddings_pca[:, 0],\n",
+    "        \"embedding2\": entity_embeddings_pca[:, 1],\n",
+    "        \"cluster\": \"cluster\" + pd.Series(clusters).astype(str),\n",
+    "        \"num_articles\": all_ids_institutions[:, 2].astype(int),\n",
+    "    }\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "from adjustText import adjust_text\n",
+    "\n",
+    "\n",
+    "def plot_clusters(parameter):\n",
+    "    np.random.seed(0)\n",
+    "    plt.figure(figsize=(12, 12))\n",
+    "    plt.title(\"{} embeddings\".format(parameter).capitalize())\n",
+    "    ax = sns.scatterplot(\n",
+    "        data=plot_df,\n",
+    "        x=\"embedding1\",\n",
+    "        y=\"embedding2\",\n",
+    "        hue=parameter,\n",
+    "    )\n",
+    "    texts = []\n",
+    "    for i, point in plot_df.iterrows():\n",
+    "        if point[\"institution\"] in [\"https://openalex.org/I161318765\", 'https://openalex.org/I1174212', 'https://openalex.org/I95457486']:\n",
+    "            print(point)\n",
+    "            texts.append(\n",
+    "                plt.text(\n",
+    "                    point[\"embedding1\"] + 0.02,\n",
+    "                    point[\"embedding2\"] + 0.01,\n",
+    "                    str(point[\"institution_name\"]),\n",
+    "                )\n",
+    "            )\n",
+    "            # texts.append(\n",
+    "            #     plt.text(\n",
+    "            #         point[\"embedding1\"] + 0.02,\n",
+    "            #         point[\"embedding2\"] + 0.01,\n",
+    "            #         str(point[\"institutions\"]),\n",
+    "            #     )\n",
+    "            # )\n",
+    "    adjust_text(texts)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plot_clusters(\"num_articles\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ampligraph.discovery import discover_facts\n",
+    "\n",
+    "discover_facts(\n",
+    "    filter['test'],\n",
+    "    model,\n",
+    "    top_n=100,\n",
+    "    strategy=\"random_uniform\",\n",
+    "    max_candidates=100,\n",
+    "    target_rel=\"urn:acmcmc:unis:related_to_concept\",\n",
+    "    seed=0,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create a dataframe of the institutions and their names\n",
+    "import pandas as pd\n",
+    "query_results = g.query(\n",
+    "    \"\"\"\n",
+    "    SELECT DISTINCT ?institution ?name\n",
+    "    WHERE {\n",
+    "        ?institution a unis:Institution .\n",
+    "        ?institution unis:name ?name .\n",
+    "    }\n",
+    "    \"\"\"\n",
+    ")\n",
+    "institutions = pd.DataFrame(query_results, columns=[\"institution\", \"name\"])\n",
+    "institutions[\"institution\"] = institutions[\"institution\"].apply(lambda x: str(x))\n",
+    "institutions[\"name\"] = institutions[\"name\"].apply(lambda x: str(x))\n",
+    "# Store the dataframe\n",
+    "institutions.to_csv(\"institutions.csv\", index=False)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "universities-kge",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
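The materialization rule in explore.ipynb (an institution is related to a concept whenever one of its papers is) is implemented there as a SELECT query plus a Python loop. The same rule can be written as a single SPARQL CONSTRUCT, whose result iterates directly as triples — a sketch, assuming the same unis namespace binding as in the notebook:

import rdflib

g = rdflib.Graph()
unis = rdflib.Namespace("urn:acmcmc:unis:")
g.bind("unis", unis)
g.parse("universities_large.ttl", format="turtle")

# CONSTRUCT yields the inferred triples; add them back into the graph.
inferred = g.query(
    """
    CONSTRUCT { ?institution unis:institution_related_to_concept ?concept }
    WHERE {
        ?institution a unis:Institution .
        ?article a unis:Article .
        ?article unis:written_in_institution ?institution .
        ?article unis:related_to_concept ?concept .
    }
    """,
    initNs={"unis": unis},
)
for triple in inferred:
    g.add(triple)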
institutions.csv
ADDED
The diff for this file is too large to render.
See raw diff
model/.data-00000-of-00001
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5ded6f0bf7985926646dd021e03e008d0f8779f606e4010f0ab89cf8687e943
+size 87725277
model/.index
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3d9027e082ae75293bde304a2044fbd0549aa0bd1b43d3483c7c28b0ab7bc72b
+size 291
model/checkpoint
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba7ac3757b9a57bdd3e603acb528728d61a9479fe392a7f343330aad23f22c50
+size 59
model/model_metadata.ampkl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e7a052e205b870dba54d5a4b23c54f638d93e880c81b66e14ec1c6ae90d2cd33
+size 24656298
requirements.txt
ADDED
@@ -0,0 +1,131 @@
+absl-py==2.0.0
+adjustText==0.8
+aiofiles==23.2.1
+alabaster==0.7.13
+altair==5.2.0
+ampligraph==2.0.1
+annotated-types==0.6.0
+anyio==3.7.1
+astunparse==1.6.3
+attrs==23.1.0
+Babel==2.13.1
+beautifultable==1.1.0
+cachetools==5.3.2
+certifi==2023.11.17
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+contextlib2==21.6.0
+contourpy==1.2.0
+cycler==0.12.1
+docopt==0.6.2
+docutils==0.17.1
+fastapi==0.105.0
+ffmpy==0.3.1
+filelock==3.13.1
+flake8==6.1.0
+flatbuffers==23.5.26
+fonttools==4.46.0
+fsspec==2023.12.2
+gast==0.5.4
+google-auth==2.25.2
+google-auth-oauthlib==1.1.0
+google-pasta==0.2.0
+gradio==4.8.0
+gradio_client==0.7.1
+grpcio==1.60.0
+h11==0.14.0
+h5py==3.10.0
+httpcore==1.0.2
+httpx==0.25.2
+huggingface-hub==0.19.4
+idna==3.6
+imagesize==1.4.1
+importlib-resources==6.1.1
+iniconfig==2.0.0
+isodate==0.6.1
+Jinja2==3.1.2
+joblib==1.3.2
+jsonschema==4.20.0
+jsonschema-specifications==2023.11.2
+keras==2.15.0
+kiwisolver==1.4.5
+latexcodec==2.0.1
+libclang==16.0.6
+Markdown==3.5.1
+markdown-it-py==2.2.0
+MarkupSafe==2.1.3
+matplotlib==3.8.2
+mccabe==0.7.0
+mdit-py-plugins==0.3.5
+mdurl==0.1.2
+ml-dtypes==0.2.0
+myst-parser==0.18.0
+networkx==3.2.1
+numpy==1.26.2
+oauthlib==3.2.2
+opt-einsum==3.3.0
+orjson==3.9.10
+pandas==2.1.4
+Pillow==10.1.0
+pluggy==1.3.0
+protobuf==4.23.4
+pyalex==0.13
+pyasn1==0.5.1
+pyasn1-modules==0.3.0
+pybtex==0.24.0
+pybtex-docutils==1.0.3
+pycodestyle==2.11.1
+pydantic==2.5.2
+pydantic_core==2.14.5
+pydub==0.25.1
+pyflakes==3.1.0
+pyparsing==3.1.1
+pytest==7.4.3
+python-dotenv==1.0.0
+python-multipart==0.0.6
+pytz==2023.3.post1
+PyYAML==6.0.1
+rdflib==7.0.0
+referencing==0.32.0
+requests==2.31.0
+requests-oauthlib==1.3.1
+rich==13.7.0
+rpds-py==0.13.2
+rsa==4.9
+schema==0.7.5
+scikit-learn==1.3.2
+scipy==1.10.0
+seaborn==0.13.0
+semantic-version==2.10.0
+shellingham==1.5.4
+sniffio==1.3.0
+snowballstemmer==2.2.0
+SPARQLWrapper==2.0.0
+Sphinx==5.0.2
+sphinx-rtd-theme==1.0.0
+sphinxcontrib-applehelp==1.0.7
+sphinxcontrib-bibtex==2.4.2
+sphinxcontrib-devhelp==1.0.5
+sphinxcontrib-htmlhelp==2.0.4
+sphinxcontrib-jsmath==1.0.1
+sphinxcontrib-qthelp==1.0.6
+sphinxcontrib-serializinghtml==1.1.9
+starlette==0.27.0
+tensorboard==2.15.1
+tensorboard-data-server==0.7.2
+tensorflow==2.15.0
+tensorflow-estimator==2.15.0
+tensorflow-io-gcs-filesystem==0.34.0
+termcolor==2.4.0
+threadpoolctl==3.2.0
+tomlkit==0.12.0
+toolz==0.12.0
+tqdm==4.66.1
+typer==0.9.0
+tzdata==2023.3
+urllib3==2.1.0
+uvicorn==0.24.0.post1
+websockets==11.0.3
+Werkzeug==3.0.1
+wrapt==1.14.1
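Since every dependency above is pinned to an exact version, drift between the Space and a local environment is easy to detect. A small sketch that compares the pins against what is installed (it assumes requirements.txt contains only name==version lines, as here):

# Hedged sketch: report packages whose installed version differs from the pin.
from importlib.metadata import version, PackageNotFoundError

with open("requirements.txt") as f:
    for line in f:
        line = line.strip()
        if not line or "==" not in line:
            continue
        name, _, pinned = line.partition("==")
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != pinned:
            print(f"{name}: pinned {pinned}, installed {installed}")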
test.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4853fc51e34ffde1e7d2bfc0b463b41d57b163442e8bd2ad748e038d635bb140
+size 73705888
test.py
ADDED
@@ -0,0 +1,12 @@
+import gradio as gr
+
+with gr.Blocks() as demo:
+    t1 = gr.Label("Hello world!")
+    btn = gr.Button("Click me")
+    t2 = gr.Label("Hello world!", visible=False)
+
+    def update():
+        # Reveal t2 and set its value in a single update
+        return gr.update(value="abc", visible=True)
+
+    btn.click(update, inputs=[], outputs=[t2])
+
+# %%
+demo.launch()
train.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e3e08f7214b7ab53bc55eff7a07eddaff45202a1975bb3a526c4f7bc9e82f83d
+size 546683543
train.py
ADDED
@@ -0,0 +1,215 @@
+# %%
+# Set logging level to DEBUG
+import logging
+import os
+
+import dotenv
+import numpy as np
+import pandas as pd
+import pyalex
+import rdflib
+from ampligraph.datasets import (
+    GraphDataLoader,
+    SQLiteAdapter,
+    DataSourceIdentifier,
+)
+from ampligraph.datasets.graph_partitioner import NaiveGraphPartitioner, BucketGraphPartitioner
+from ampligraph.evaluation import train_test_split_no_unseen
+from ampligraph.latent_features import ScoringBasedEmbeddingModel
+from pyalex import Authors, Concepts, Funders, Institutions, Publishers, Sources, Works
+from sklearn.model_selection import train_test_split
+import tensorflow as tf
+from ampligraph.evaluation import hits_at_n_score, mrr_score
+from ampligraph.latent_features.loss_functions import get as get_loss
+from ampligraph.latent_features.regularizers import get as get_regularizer
+
+logging.basicConfig(level=logging.DEBUG)
+loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
+for logger in loggers:
+    logger.setLevel(logging.INFO)
+
+# Load the triples from the file
+X_train = np.load("train.npy")
+X_valid = np.load("valid.npy")
+X_test = np.load("test.npy")
+
+## Store as CSVs. There are commas in the names of some institutions, so we need to use a tab as the delimiter
+#np.savetxt("train.csv", X_train, delimiter="\t", fmt="%s")
+#np.savetxt("valid.csv", X_valid, delimiter="\t", fmt="%s")
+#np.savetxt("test.csv", X_test, delimiter="\t", fmt="%s")
+#
+#print(f"Train size: {X_train.shape[0]}")
+#print(f"Valid size: {X_valid.shape[0]}")
+#print(f"Test size: {X_test.shape[0]}")
+
+
+# Initialize a ComplEx neural embedding model: the embedding size is k,
+# eta specifies the number of corruptions to generate per each positive,
+# scoring_type determines the scoring function of the embedding model.
+partitioned_model = ScoringBasedEmbeddingModel(k=150, eta=10, scoring_type="ComplEx")
+
+# Optimizer, loss and regularizer definition
+optim = tf.keras.optimizers.Adam(learning_rate=1e-3)
+loss = get_loss("pairwise", {"margin": 0.5})
+regularizer = get_regularizer("LP", {"p": 2, "lambda": 1e-5})
+
+# Compilation of the model
+partitioned_model.compile(
+    optimizer=optim, loss=loss, entity_relation_regularizer=regularizer
+)
+
+# For evaluation, we can use a filter which would be used to filter out
+# positive statements created by the corruption procedure.
+# Here we define the filter set by concatenating all the positives
+
+filter = {"test": np.concatenate((X_train, X_valid, X_test))}
+
+# Early Stopping callback
+checkpoint = tf.keras.callbacks.EarlyStopping(
+    monitor="val_{}".format("hits10"),
+    min_delta=0,
+    patience=5,
+    verbose=1,
+    mode="max",
+    restore_best_weights=True,
+)
+
+
+######
+use_db = False
+if use_db:
+    AMPLIGRAPH_DATA_HOME = os.path.join(os.getcwd(), "data")  # + os.sep
+
+    from ampligraph.datasets.data_indexer import SQLite as SQLiteIndexer, DataIndexer
+
+    # Initialize GraphDataLoader from .csv file
+    sqlite_indexer = SQLiteIndexer(
+        data=None,
+        db_file="main_partition.db",
+        root_directory=AMPLIGRAPH_DATA_HOME,
+        name="main_partition",
+    )
+    indexer = DataIndexer(
+        X=None,
+        backend_type='sqlite',
+        backend=sqlite_indexer,
+    )
+    dataset_loader = GraphDataLoader(
+        "train.csv",
+        backend=SQLiteAdapter,
+        in_memory=False,
+        verbose=True,
+        root_directory=AMPLIGRAPH_DATA_HOME,
+        db_name="mydb.db",
+        use_indexer=indexer,
+    )
+    # adapter = SQLiteAdapter(
+    #     "database_25-12-2023_07-28-41_485047_PM_2a11fc49-2337-415e-8672-2bfa48a83745.db",
+    #     identifier=DataSourceIdentifier,
+    #     root_directory=AMPLIGRAPH_DATA_HOME,
+    # )
+    print("Graph data loader initialized")
+    # for elem in next(dataset_loader._get_batch_generator()):
+    #     print(elem)
+    #     break
+######
+else:
+    X_train = np.load("train.npy")
+    dataset_loader = GraphDataLoader(
+        X_train,
+        verbose=True,
+        use_indexer=True,
+        in_memory=True,
+    )
+    print(f'next: {next(dataset_loader)}')
+    print(f'next: {next(dataset_loader)}')
+    print(f'next: {next(dataset_loader)}')
+    #x = np.loadtxt(
+    #    "train.csv",
+    #    delimiter="\t",
+    #    dtype=str,
+    #)
+    #print(x[0])
+
+# Choose the partitioner - in this case we choose RandomEdges partitioner
+partition = False
+if partition:
+    print("Will start partitioning now")
+    graph_partitioner_train = NaiveGraphPartitioner(dataset_loader, k=6)
+    print("Graph partitioner initialized")
+    #indexer = (
+    #    partitioned_model.data_handler.get_mapper()
+    #)  # get the mapper from the trained model
+    # dataset_loader_test = GraphDataLoader(
+    #     data_source=X_test,
+    #     backend=SQLiteAdapter,  # type of backend to use
+    #     batch_size=400,  # batch size to use while iterating over this dataset
+    #     dataset_type="test",  # dataset type
+    #     use_indexer=indexer,  # mapper to map test concepts to the same indices used during training
+    #     verbose=True,
+    # )
+    # graph_partitioner_test = BucketGraphPartitioner(data=partitioner, k=3)
+
+print("Will start training now")
+# Fit the model on training and validation set
+partitioned_model.fit(
+    #graph_partitioner_train if partition else dataset_loader,
+    X_train,
+    batch_size=500,
+    epochs=45,  # Number of training epochs
+    validation_freq=20,  # Epochs between successive validation
+    validation_burn_in=100,  # Epoch to start validation
+    validation_data=X_test,  # Validation data
+    validation_filter=filter,  # Filter positives from validation corruptions
+    callbacks=[
+        checkpoint
+    ],  # Early stopping callback (more from tf.keras.callbacks are supported)
+    verbose=True,  # Enable stdout messages
+    #partitioning_k=7,  # Number of partitions to create
+)
+
+# %%
+# Store the model
+super(ScoringBasedEmbeddingModel, partitioned_model).save_weights("model/")
+partitioned_model.save_metadata(filedir="model")
+# from ampligraph.utils import save_model
+# save_model(partitioned_model, model_name_path='model.pkl')
+
+# %%
+# Create a dataframe of the institutions and their names
+import pandas as pd
+
+import rdflib
+
+g = rdflib.Graph()
+uri = "urn:acmcmc:unis:"
+unis = rdflib.Namespace(uri)
+g.bind("unis", unis)
+g.parse("universities_large_1200.ttl", format="turtle")
+
+query_results = g.query(
+    """
+    SELECT DISTINCT ?institution ?name
+    WHERE {
+        ?institution a unis:Institution .
+        ?institution unis:name ?name .
+    }
+    """
+)
+institutions = pd.DataFrame(query_results, columns=["institution", "name"])
+institutions["institution"] = institutions["institution"].apply(lambda x: str(x))
+institutions["name"] = institutions["name"].apply(lambda x: str(x))
+# Store the dataframe
+institutions.to_csv("institutions.csv", index=False)
+
+# %%
+# Run the evaluation procedure on the test set (with filtering)
+# To disable filtering: use_filter=None
+# Usually, we corrupt subject and object sides separately and compute ranks
+ranks = partitioned_model.evaluate(X_test, use_filter=filter, corrupt_side="s,o")
+
+# compute and print metrics:
+mrr = mrr_score(ranks)
+hits_10 = hits_at_n_score(ranks, n=10)
+print("MRR: %f, Hits@10: %f" % (mrr, hits_10))
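The model/ files added in this commit (checkpoint, .index, .data-00000-of-00001, plus AmpliGraph's model_metadata.ampkl) are what the save_weights("model/") and save_metadata(filedir="model") calls above produce: a standard TensorFlow checkpoint next to AmpliGraph's own metadata. A sketch of inspecting the checkpoint contents, assuming the standard TensorFlow checkpoint utilities:

import tensorflow as tf

# model/checkpoint records the latest prefix; .index/.data-* hold the variables.
for name, shape in tf.train.list_variables("model/"):
    print(name, shape)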
universities.ttl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb26d10e53b756c0a17940946cf06603826eb779847b5943cd35155e4257f636
+size 209243108
universities_large.ttl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e86756a309bf3e29aceb783f9d02cca057466c92862423501d13d9a08fd2ffa
+size 807256238
universities_large_1200.ttl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efd28c5a3d62c211b536439bad446658100ba11c26102478d97b3df5483b0dcb
+size 807269262
universities_large_4200.ttl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1c1fc0caa955c41cc0ffd1cd8a903b8845de5f0b463e3c1d65ec94cdc3d71e9c
+size 1757505860
universities_large_4300.ttl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cfadae62d79379f75bf86b7c30bd58753b65cdc97c69aafad796c60faaa84de4
+size 1818624628
valid.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e0a30e621a7859e186970e2b4bc81e2bc9ffc6ece265d373e95a26733e397314
+size 71276234