# File size: 555 Bytes
# 0b69648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
from transformers import FlaxRobertaModel, RobertaTokenizerFast
from datasets import load_dataset
import jax

# Smoke test: take one text sample from OSCAR, tokenize it, and run a single
# forward pass through a Flax RoBERTa model.

# Request the English deduplicated OSCAR train split in streaming mode.
dataset = load_dataset(
    "oscar",
    "unshuffled_deduplicated_en",
    split="train",
    streaming=True,
)

# Only the "text" field of the first streamed record is used as dummy input.
dummy_input = next(iter(dataset))["text"]

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# Tokenize to NumPy arrays, then keep at most the first 10 ids along axis 1.
input_ids = tokenizer(dummy_input, return_tensors="np")["input_ids"][:, :10]

# NOTE(review): the tokenizer is loaded from "roberta-base" while the weights
# come from "julien-c/dummy-unknown" -- confirm the vocabularies are compatible.
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
z = model(input_ids=input_ids)