---
license: mit
widget:
- text: "Apoorv Umang Saxena| family name"
  example_title: "Family name prediction"
- text: "Apoorv Saxena| country"
  example_title: "Country prediction"
- text: "World War 2| followed by"
  example_title: "followed by"
---
This is a t5-small model trained from scratch on the WikiKG90Mv2 dataset. See https://github.com/apoorvumang/kgt5/ for more details on the method.
This model was trained on the tail entity prediction task, i.e., given a subject entity and a relation, predict the object entity. Input should be provided in the form "\<entity text\>| \<relation text\>".
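As a small illustration of the input format, the sketch below builds a query string from an entity and a relation. `format_query` is a hypothetical helper introduced here for illustration; it is not part of the KGT5 code.

```python
def format_query(entity_text: str, relation_text: str) -> str:
    """Build a tail-prediction query of the form "<entity text>| <relation text>"."""
    # The separator is a pipe directly after the entity text, followed by one space.
    return f"{entity_text.strip()}| {relation_text.strip()}"


query = format_query("Apoorv Saxena", "country")
print(query)  # Apoorv Saxena| country
```

The resulting string can be passed as the `input` argument of the prediction functions defined later in this card.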
We used the raw text titles and descriptions to get textual representations of entities and relations. These raw texts were obtained from the OGB dataset itself (dataset/wikikg90m-v2/mapping/entity.csv and relation.csv). The entity representation was set to the title, and the description was used to disambiguate when two entities had the same title. If disambiguation was still not possible, we used the Wikidata ID (e.g., Q123456).
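The disambiguation rule can be sketched roughly as follows. This is an illustration under assumptions, not the original preprocessing code: the function name `build_entity_texts`, the `(wikidata_id, title, description)` tuple layout, and the exact way the description or ID is joined to the title are all hypothetical.

```python
from collections import Counter


def build_entity_texts(entities):
    """Map each Wikidata ID to a unique text.

    entities: list of (wikidata_id, title, description) tuples (assumed layout).
    """
    titles = {qid: title for qid, title, _ in entities}
    title_counts = Counter(titles.values())

    # Step 1: use the title alone; add the description only when the title is shared.
    texts = {}
    for qid, title, description in entities:
        if title_counts[title] > 1 and description:
            texts[qid] = f"{title}, {description}"  # join format is an assumption
        else:
            texts[qid] = title

    # Step 2: anything still ambiguous falls back to the Wikidata ID.
    text_counts = Counter(texts.values())
    return {
        qid: text if text_counts[text] == 1 else f"{titles[qid]} {qid}"
        for qid, text in texts.items()
    }


example = [
    ("Q1", "Paris", "capital of France"),
    ("Q2", "Paris", "city in Texas"),
    ("Q3", "Berlin", "capital of Germany"),
    ("Q4", "Springfield", ""),
    ("Q5", "Springfield", ""),
]
print(build_entity_texts(example))
```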
We trained the model on WikiKG90Mv2 for approximately 1.5 epochs on 4x 1080Ti GPUs. Training for one epoch took approximately 5.5 days.
To evaluate the model, we sample 300 times from the decoder for each input (s, r) pair. We then remove predictions that do not map back to a valid entity and rank the remaining predictions by their log probabilities. Filtering is performed afterwards. We achieve 0.22 validation MRR (the full leaderboard is at https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2).
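The ranking and scoring steps can be sketched on toy data as follows. This is a simplified illustration, not the evaluation script used for the leaderboard: the function names are hypothetical, merging duplicate samples by their best score is an assumption, and "filtering" is interpreted here in the usual knowledge-graph sense of removing other known correct tails before computing the rank.

```python
def rank_predictions(samples, valid_entities):
    """Keep samples that are valid entities, merge duplicates, sort by log probability."""
    best = {}
    for text, log_prob in samples:
        if text in valid_entities and log_prob > best.get(text, float("-inf")):
            best[text] = log_prob
    return sorted(best, key=best.get, reverse=True)


def filtered_reciprocal_rank(ranked, target, known_tails):
    """Drop other known correct tails, then return 1/rank of the target (0 if missing)."""
    filtered = [e for e in ranked if e == target or e not in known_tails]
    if target not in filtered:
        return 0.0
    return 1.0 / (filtered.index(target) + 1)


# Toy decoder samples: (decoded text, sequence log probability)
samples = [("Germany", -0.5), ("France", -0.2), ("Frnce", -0.1),
           ("Germany", -0.5), ("Spain", -1.3)]
valid_entities = {"Germany", "France", "Spain"}

ranked = rank_predictions(samples, valid_entities)
print(ranked)  # ['France', 'Germany', 'Spain']
print(filtered_reciprocal_rank(ranked, "Germany", set()))  # 0.5
print(filtered_reciprocal_rank(ranked, "Germany", {"France", "Germany"}))  # 1.0
```

MRR is then the mean of these reciprocal ranks over all evaluated (s, r) pairs.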
You can try the following code in an IPython notebook to evaluate the pre-trained model. The full procedure (mapping entities to IDs, filtering, etc.) is not included here for the sake of simplicity, but it can be provided on request. Please contact Apoorv ([email protected]) for clarifications or details.
---------
```
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
```
```
import torch

def getScores(ids, scores, pad_token_id):
    """get sequence scores from model.generate output"""
    scores = torch.stack(scores, dim=1)
    log_probs = torch.log_softmax(scores, dim=2)
    # remove start token
    ids = ids[:,1:]
    # gather needed probs
    x = ids.unsqueeze(-1).expand(log_probs.shape)
    needed_logits = torch.gather(log_probs, 2, x)
    final_logits = needed_logits[:, :, 0]
    padded_mask = (ids == pad_token_id)
    final_logits[padded_mask] = 0
    final_scores = final_logits.sum(dim=-1)
    return final_scores.cpu().detach().numpy()

def topkSample(input, model, tokenizer,
               num_samples=5,
               num_beams=1,
               max_output_length=30):
    tokenized = tokenizer(input, return_tensors="pt")
    out = model.generate(**tokenized,
                         do_sample=True,
                         num_return_sequences = num_samples,
                         num_beams = num_beams,
                         eos_token_id = tokenizer.eos_token_id,
                         pad_token_id = tokenizer.pad_token_id,
                         output_scores = True,
                         return_dict_in_generate=True,
                         max_length=max_output_length,)
    out_tokens = out.sequences
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)

    pair_list = [(x[0], x[1]) for x in zip(out_str, out_scores)]
    sorted_pair_list = sorted(pair_list, key=lambda x:x[1], reverse=True)
    return sorted_pair_list

def greedyPredict(input, model, tokenizer):
    input_ids = tokenizer([input], return_tensors="pt").input_ids
    out_tokens = model.generate(input_ids)
    out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
    return out_str[0]
```
```
# an example from validation set that the model predicts correctly
# you can try your own examples here. what's your noble title?
input = "Sophie Valdemarsdottir| noble title"
out = topkSample(input, model, tokenizer, num_samples=5)
out
```
You can further load the list of entity aliases, keep only those predictions that are valid entities, and then create a reverse mapping from alias -> integer id to get the final predictions in the required format.
However, loading these aliases into memory as a dictionary requires a lot of RAM, and you need to download the aliases file (available at https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle; relation file: https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle).
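The reverse mapping can be sketched on a toy alias list as follows. The layout of the downloaded pickle file is not documented here, so this sketch assumes it unpickles to a list in which position i holds the alias of the entity with integer id i; check the actual file before relying on this. `to_entity_ids` is a hypothetical helper.

```python
# Toy stand-in for the unpickled alias list.
# Assumed layout: index i holds the alias of the entity with integer id i.
ent_alias_list = ["Germany", "France", "Spain"]

# Reverse mapping: alias -> integer id
alias_to_id = {alias: idx for idx, alias in enumerate(ent_alias_list)}


def to_entity_ids(predictions, alias_to_id):
    """Keep predictions whose text is a known alias and replace the text with its id."""
    return [(alias_to_id[text], score)
            for text, score in predictions
            if text in alias_to_id]


# Same shape as the output of topkSample: (decoded text, score), best first
predictions = [("France", -0.2), ("Frnce", -0.4), ("Spain", -1.3)]
print(to_entity_ids(predictions, alias_to_id))  # [(1, -0.2), (2, -1.3)]
```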
The submitted validation/test results were obtained by sampling 300 times for each input, then applying the above procedure, followed by filtering known entities. The final MRR can vary slightly because of the sampling (we found that although beam search gives deterministic output, its results are inferior to sampling a large number of times).
```
# download valid.txt. you can also try same url with test.txt. however test does not contain the correct tails
!wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
```
```
fname = 'valid.txt'
valid_lines = []
with open(fname) as f:
    for line in f:
        valid_lines.append(line.rstrip())
print(valid_lines[0])
```
```
from tqdm.auto import tqdm

# try unfiltered hits@k. this is approximation since model can sample same seq multiple times
# you should run this on gpu if you want to evaluate on all points with 300 samples each
k = 1
count_at_k = 0
max_predictions = k
max_points = 1000
for line in tqdm(valid_lines[:max_points]):
    input, target = line.split('\t')
    model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
    prediction_strings = [x[0] for x in model_output]
    if target in prediction_strings:
        count_at_k += 1
print('Hits at {0} unfiltered: {1}'.format(k, count_at_k/max_points))
```