Spaces:
Runtime error
Runtime error
"""Visualize the sense vectors of a word by printing its top- and bottom-scoring vocabulary tokens."""
import torch | |
import argparse | |
import transformers | |
def visualize_word(word, tokenizer, vecs, lm_head, count=20, contents=None):
    """Print the highest- and lowest-scoring vocabulary tokens for each sense.

    Each row of ``contents`` is scored against the vocabulary as
    ``row @ lm_head.t()``. The ``count`` best and ``count`` worst tokens are
    printed with their scores, in decreasing order for the positive list and
    increasing order for the negative list.

    Args:
        word: Word to look up. It is only printed and tokenized when
            ``contents`` is None. Only its FIRST token id is used, so a word
            that splits into several sub-word tokens is represented by its
            first piece.
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence, and providing ``decode(token_id)``.
        vecs: Indexed by token id. ``vecs[token_id]`` must be a 2-D
            (num_senses, dim) tensor.
        lm_head: Tensor such that ``sense @ lm_head.t()`` yields one score per
            vocabulary item (a (vocab, dim) matrix).
        count: Number of top and bottom tokens to print per sense. It is
            clamped to the vocabulary size.
        contents: Optional precomputed (num_senses, dim) sense matrix. When
            given, ``word`` and ``vecs`` are not used.

    Returns:
        The sense matrix that was visualized.

    Raises:
        ValueError: If ``contents`` is None and ``word`` tokenizes to no ids
            (e.g. an empty string).
    """
    if contents is None:
        print(word)
        token_ids = tokenizer(word)['input_ids']
        if not token_ids:
            raise ValueError('word {!r} produced no tokens'.format(word))
        contents = vecs[token_ids[0]]  # e.g. torch.Size([16, 768])

    for i in range(contents.shape[0]):
        print('~~~~~~~~~~~~~~~~~~~~~~~{}~~~~~~~~~~~~~~~~~~~~~~~~'.format(i))
        # (vocab,) e.g. [768] @ [768, 50257] -> [50257]
        logits = contents[i, :] @ lm_head.t()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        # Never index past the vocabulary, even when count is larger.
        k = min(count, sorted_indices.shape[0])
        print('~~~Positive~~~')
        for j in range(k):
            print(tokenizer.decode(sorted_indices[j]), '\t',
                  '{:.2f}'.format(sorted_logits[j].item()))
        print('~~~Negative~~~')
        for j in range(1, k + 1):
            print(tokenizer.decode(sorted_indices[-j]), '\t',
                  '{:.2f}'.format(sorted_logits[-j].item()))
    return contents
if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse visualize_word)
    # does not parse sys.argv or block on input().
    argp = argparse.ArgumentParser(description='Visualize some sense vectors')
    argp.add_argument('vecs_path',
                      help='torch.load-able tensor indexed by token id')
    argp.add_argument('lm_head_path',
                      help='torch.load-able lm_head weight tensor')
    args = argp.parse_args()
    # Load tokenizer and parameters. map_location='cpu' lets tensors saved on
    # a GPU load on a CPU-only machine. It also keeps both tensors on the same
    # device for the matmul inside visualize_word.
    # NOTE(review): torch.load can unpickle arbitrary objects depending on the
    # torch version/settings; only pass trusted files.
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    vecs = torch.load(args.vecs_path, map_location='cpu')
    lm_head = torch.load(args.lm_head_path, map_location='cpu')
    visualize_word(input('Enter a word:'), tokenizer, vecs, lm_head, count=5)