johnsamuel commited on
Commit
7ba2844
1 Parent(s): 2819fb8
Files changed (1) hide show
  1. app.py +14 -0
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Source: https://huggingface.co/facebook/rag-token-nq#usage
2
+
3
+ from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
4
+
5
+ tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
6
+ retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
7
+ model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
8
+
9
+ input_dict = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
10
+
11
+ generated = model.generate(input_ids=input_dict["input_ids"])
12
+ print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
13
+
14
+ # should give michael phelps => sounds reasonable