Apoorv Umang commited on
Commit
16de72d
2 Parent(s): 7b5d79b 72d8967

Merge branch 'main' of https://huggingface.co/apoorvumang/kgt5-wikikg90mv2 into main

Browse files
Files changed (1) hide show
  1. README.md +107 -1
README.md CHANGED
@@ -10,4 +10,110 @@ widget:
10
  - text: "Q12345678| follows"
11
  example_title: "follows"
12
  ---
13
- This is a t5-small model trained on WikiKG90Mv2 dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  - text: "Q12345678| follows"
11
  example_title: "follows"
12
  ---
13
+ This is a t5-small model trained from scratch on WikiKG90Mv2 dataset. Please see https://github.com/apoorvumang/transformer-kgc/ for more details on the method.
14
+
15
+ This model was trained on the tail entity prediction task ie. given subject entity and relation, predict the object entity. Input should be provided in the form of "\<entity text\>| \<relation text\>".
16
+
17
+ We used the raw text title and descriptions to get entity and relation textual representations. These raw texts were obtained from ogb dataset itself (dataset/wikikg90m-v2/mapping/entity.csv and relation.csv). Entity representation was set to the title, and description was used to disambiguate if 2 entities had the same title. If still no disambiguation was possible, we used the wikidata ID (eg. Q123456).
18
+
19
+ We trained the model on WikiKG90Mv2 for approx 1.5 epochs on 4x1080Ti GPUs. The training time for 1 epoch was approx 5.5 days.
20
+
21
+ To evaluate the model, we sample 300 times from the decoder for each input (s,r) pair. We then remove predictions which do not map back to a valid entity, and then rank the predictions by their log probabilities. Filtering was performed subsequently.
22
+
23
+ You can try the following code in an ipython notebook to evaluate the pre-trained model. The full procedure of mapping entity to ids, filtering etc. is not included here for sake of simplicity but can be provided on request if needed. Please contact Apoorv ([email protected]) for clarifications/details.
24
+ ---------
25
+ ```
26
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
27
+ tokenizer = AutoTokenizer.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
28
+ model = AutoModelForSeq2SeqLM.from_pretrained("apoorvumang/kgt5-wikikg90mv2")
29
+ ```
30
+ ```
31
+ import torch
32
+
33
+ def getScores(ids, scores, pad_token_id):
34
+ """get sequence scores from model.generate output"""
35
+ scores = torch.stack(scores, dim=1)
36
+ log_probs = torch.log_softmax(scores, dim=2)
37
+ # remove start token
38
+ ids = ids[:,1:]
39
+ # gather needed probs
40
+ x = ids.unsqueeze(-1).expand(log_probs.shape)
41
+ needed_logits = torch.gather(log_probs, 2, x)
42
+ final_logits = needed_logits[:, :, 0]
43
+ padded_mask = (ids == pad_token_id)
44
+ final_logits[padded_mask] = 0
45
+ final_scores = final_logits.sum(dim=-1)
46
+ return final_scores.cpu().detach().numpy()
47
+
48
+ def topkSample(input, model, tokenizer,
49
+ num_samples=5,
50
+ num_beams=1,
51
+ max_output_length=30):
52
+ tokenized = tokenizer(input, return_tensors="pt")
53
+ out = model.generate(**tokenized,
54
+ do_sample=True,
55
+ num_return_sequences = num_samples,
56
+ num_beams = num_beams,
57
+ eos_token_id = tokenizer.eos_token_id,
58
+ pad_token_id = tokenizer.pad_token_id,
59
+ output_scores = True,
60
+ return_dict_in_generate=True,
61
+ max_length=max_output_length,)
62
+ out_tokens = out.sequences
63
+ out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
64
+ out_scores = getScores(out_tokens, out.scores, tokenizer.pad_token_id)
65
+
66
+ pair_list = [(x[0], x[1]) for x in zip(out_str, out_scores)]
67
+ sorted_pair_list = sorted(pair_list, key=lambda x:x[1], reverse=True)
68
+ return sorted_pair_list
69
+
70
+ def greedyPredict(input, model, tokenizer):
71
+ input_ids = tokenizer([input], return_tensors="pt").input_ids
72
+ out_tokens = model.generate(input_ids)
73
+ out_str = tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
74
+ return out_str[0]
75
+ ```
76
+
77
+ ```
78
+ # an example from validation set that the model predicts correctly
79
+ # you can try your own examples here. what's your noble title?
80
+ input = "Sophie Valdemarsdottir| noble title"
81
+ out = topkSample(input, model, tokenizer, num_samples=5)
82
+ out
83
+ ```
84
+
85
+ You can further load the list of entity aliases, then filter only those predictions which are valid entities then create a reverse mapping from alias -> integer id to get final predictions in required format.
86
+
87
+ However, loading these aliases in memory as a dictionary requires a lot of RAM + you need to download the aliases file (made available here)
88
+
89
+ The submitted validation/test results for were obtained by sampling 300 times for each input, then applying above procedure, followed by filtering known entities. The final MRR can vary slightly due to this sampling nature (we found that although beam search gives deterministic output, the results are inferior to sampling large number of times).
90
+
91
+ ```
92
+ # download valid.txt. you can also try same url with test.txt. however test does not contain the correct tails
93
+ !wget https://storage.googleapis.com/kgt5-wikikg90mv2/valid.txt
94
+ ```
95
+ ```
96
+ fname = 'valid.txt'
97
+ valid_lines = []
98
+ f = open(fname)
99
+ for line in f:
100
+ valid_lines.append(line.rstrip())
101
+ f.close()
102
+ print(valid_lines[0])
103
+ ```
104
+ ```
105
+ from tqdm.auto import tqdm
106
+ # try unfiltered hits@k. this is approximation since model can sample same seq multiple times
107
+ # you should run this on gpu if you want to evaluate on all points with 300 samples each
108
+ k = 1
109
+ count_at_k = 0
110
+ max_predictions = k
111
+ max_points = 1000
112
+ for line in tqdm(valid_lines[:max_points]):
113
+ input, target = line.split('\t')
114
+ model_output = topkSample(input, model, tokenizer, num_samples=max_predictions)
115
+ prediction_strings = [x[0] for x in model_output]
116
+ if target in prediction_strings:
117
+ count_at_k += 1
118
+ print('Hits at {0} unfiltered: {1}'.format(k, count_at_k/max_points))
119
+ ```