File size: 741 Bytes
7ba2844
 
71ebf3b
 
 
7ba2844
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#Source: https://huggingface.co/facebook/rag-token-nq#usage

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, logging

logging.set_verbosity_error()

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt") 

generated = model.generate(input_ids=input_dict["input_ids"]) 
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0]) 

# should give michael phelps => sounds reasonable