widget:
- text: "Q12345678| follows"
  example_title: "follows"
---
This is a t5-small model trained from scratch on the WikiKG90Mv2 dataset. Please see https://github.com/apoorvumang/transformer-kgc/ for more details on the method.

This model was trained on the tail entity prediction task, i.e. given the subject entity and relation, predict the object entity. Input should be provided in the form of "\<entity text\>| \<relation text\>".

We used the raw text titles and descriptions to get textual representations of entities and relations. These raw texts were obtained from the ogb dataset itself (dataset/wikikg90m-v2/mapping/entity.csv and relation.csv). The entity representation was set to the title, and the description was used to disambiguate when two entities had the same title. If disambiguation was still not possible, we used the Wikidata ID (e.g. Q123456).
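
As an illustration, the disambiguation logic described above could look roughly like the sketch below. This is a hypothetical reconstruction, not the actual preprocessing code (which lives in the GitHub repo linked above); the function name and arguments are assumptions.

```
# hypothetical sketch of the entity-text construction described above
def entity_text(title, description, wikidata_id, title_is_duplicated):
    if not title_is_duplicated:
        return title
    if description:  # disambiguate duplicate titles using the description
        return title + ' ' + description
    return wikidata_id  # last resort: the Wikidata ID, e.g. Q123456
```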
We trained the model on WikiKG90Mv2 for approximately 1.5 epochs on 4x1080Ti GPUs. The training time for one epoch was approximately 5.5 days.

To evaluate the model, we sample 300 times from the decoder for each input (s, r) pair. We then remove predictions that do not map back to a valid entity and rank the remaining predictions by their log probabilities. Filtering was performed subsequently.

You can try the following code in an IPython notebook to evaluate the pre-trained model. The full procedure of mapping entities to ids, filtering, etc. is not included here for the sake of simplicity, but can be provided on request. Please contact Apoorv ([email protected]) for clarifications/details.
```
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
```
```
import torch

def getScores(ids, scores, pad_token_id):
    """get sequence log-probability scores from model.generate output"""
    scores = torch.stack(scores, dim=1)
    log_probs = torch.log_softmax(scores, dim=2)
    # remove start token
    ids = ids[:, 1:]
    # gather needed probs
    x = ids.unsqueeze(-1).expand(log_probs.shape)
    needed_logits = torch.gather(log_probs, 2, x)
    final_logits = needed_logits[:, :, 0]
    # zero out positions that correspond to padding
    padded_mask = (ids == pad_token_id)
    final_logits[padded_mask] = 0
    # sequence score = sum of token log-probs
    final_scores = final_logits.sum(dim=-1)
    return final_scores.cpu().detach().numpy()

def topkSample(input, model, tokenizer,
               num_samples=5,
               num_beams=1,
               max_output_length=30):
    """sample num_samples outputs and return (string, score) pairs, best first"""
    tokenized = tokenizer(input, return_tensors="pt")
    out = model.generate(**tokenized,
                         do_sample=True,
                         num_return_sequences=num_samples,
                         num_beams=num_beams,
                         eos_token_id=tokenizer.eos_token_id,
                         pad_token_id=tokenizer.pad_token_id,
                         output_scores=True,
                         return_dict_in_generate=True,
                         max_length=max_output_length,)
    out_tokens = out.sequences
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)

    pair_list = [(x[0], x[1]) for x in zip(out_str, out_scores)]
    sorted_pair_list = sorted(pair_list, key=lambda x: x[1], reverse=True)
    return sorted_pair_list

def greedyPredict(input, model, tokenizer):
    """greedy-decode a single prediction for the given input string"""
    input_ids = tokenizer([input], return_tensors="pt").input_ids
    out_tokens = model.generate(input_ids)
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    return out_str[0]
```
```
# an example from the validation set that the model predicts correctly
# you can try your own examples here. what's your noble title?
input = "Sophie Valdemarsdottir| noble title"
out = topkSample(input, model, tokenizer, num_samples=5)
out
```
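
You can also get a single greedy prediction using the `greedyPredict` helper defined above. As noted below, deterministic decoding tends to give results inferior to sampling many times, so this is mainly a quick sanity check:

```
# greedy decoding gives one deterministic prediction
greedyPredict("Sophie Valdemarsdottir| noble title", model, tokenizer)
```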
You can further load the list of entity aliases, filter the predictions to those that are valid entities, and then create a reverse mapping from alias -> integer id to get the final predictions in the required format.

However, loading these aliases in memory as a dictionary requires a lot of RAM, and you need to download the aliases file (made available here).
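
A minimal sketch of that filtering step, assuming the aliases file has one tab-separated alias/entity-id pair per line (the filename and format here are assumptions, not the actual file layout):

```
# hypothetical sketch: build alias -> integer id map and keep valid predictions
alias2id = {}
with open('entity_aliases.txt') as f:  # assumed filename and format
    for line in f:
        alias, eid = line.rstrip('\n').split('\t')
        alias2id[alias] = int(eid)

preds = topkSample(input, model, tokenizer, num_samples=300)
# drop predictions that are not valid entities, then map to ids
final_ids = [alias2id[s] for s, score in preds if s in alias2id]
```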
The submitted validation/test results were obtained by sampling 300 times for each input, applying the above procedure, and then filtering to known entities. The final MRR can vary slightly because of this sampling; we found that although beam search gives deterministic output, the results are inferior to sampling a large number of times.
```
# download valid.txt. you can also try the same url with test.txt; however, test.txt does not contain the correct tails
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
```
```
fname = 'valid.txt'
valid_lines = []
with open(fname) as f:
    for line in f:
        valid_lines.append(line.rstrip())
print(valid_lines[0])
```
```
from tqdm.auto import tqdm

# try unfiltered hits@k. this is an approximation, since the model can sample the same sequence multiple times
# you should run this on a gpu if you want to evaluate all points with 300 samples each
k = 1
count_at_k = 0
max_predictions = k
max_points = 1000
for line in tqdm(valid_lines[:max_points]):
    input, target = line.split('\t')
    model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
    prediction_strings = [x[0] for x in model_output]
    if target in prediction_strings:
        count_at_k += 1
print('Hits at {0} unfiltered: {1}'.format(k, count_at_k / max_points))
```
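
The loop above reports hits@k. Since `topkSample` already returns predictions sorted by score, an unfiltered MRR estimate can be computed the same way; a sketch (an approximation, not the exact submitted evaluation):

```
# sketch: approximate unfiltered MRR over the same subset
reciprocal_ranks = []
for line in tqdm(valid_lines[:max_points]):
    input, target = line.split('\t')
    model_output = topkSample(input, model, tokenizer, num_samples=10)
    prediction_strings = [x[0] for x in model_output]
    rr = 0.0
    if target in prediction_strings:
        rr = 1.0 / (prediction_strings.index(target) + 1)
    reciprocal_ranks.append(rr)
print('MRR unfiltered: {0}'.format(sum(reciprocal_ranks) / len(reciprocal_ranks)))
```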