nielsr HF staff committed on
Commit
3381383
1 Parent(s): e9848d6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from transformers import ViltProcessor, ViltForNaturalLanguageVisualReasoning
3
  import torch
4
 
5
  # NLVR2 example images
@@ -8,16 +8,17 @@ torch.hub.download_url_to_file('https://lil.nlp.cornell.edu/nlvr/exs/ex0_1.jpg',
8
  torch.hub.download_url_to_file('https://lil.nlp.cornell.edu/nlvr/exs/acorns_1.jpg', 'image3.jpg')
9
  torch.hub.download_url_to_file('https://lil.nlp.cornell.edu/nlvr/exs/acorns_6.jpg', 'image4.jpg')
10
 
11
- processor = ViltProcessor.from_pretrained("nielsr/vilt-b32-finetuned-nlvr2")
12
- model = ViltForNaturalLanguageVisualReasoning.from_pretrained("nielsr/vilt-b32-finetuned-nlvr2")
13
 
14
  def predict(image1, image2, text):
15
- encoding_1 = processor(image1, text, return_tensors="pt")
16
- encoding_2 = processor(image2, text, return_tensors="pt")
 
17
 
18
  # forward pass
19
  with torch.no_grad():
20
- outputs = model(input_ids=encoding_1.input_ids, pixel_values=encoding_1.pixel_values, pixel_values_2=encoding_2.pixel_values)
21
 
22
  logits = outputs.logits
23
  probs = torch.nn.functional.softmax(logits, dim=1)
 
1
  import gradio as gr
2
+ from transformers import ViltProcessor, ViltForImagesAndTextClassification
3
  import torch
4
 
5
  # NLVR2 example images
 
8
  torch.hub.download_url_to_file('https://lil.nlp.cornell.edu/nlvr/exs/acorns_1.jpg', 'image3.jpg')
9
  torch.hub.download_url_to_file('https://lil.nlp.cornell.edu/nlvr/exs/acorns_6.jpg', 'image4.jpg')
10
 
11
+ processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
12
+ model = ViltForImagesAndTextClassification.from_pretrained("dandelin/vilt-b32-finetuned-nlvr2")
13
 
14
  def predict(image1, image2, text):
15
+ # prepare inputs
16
+ encoding = processor([image1, image2], text, return_tensors="pt")
17
+ pixel_values = encoding.pixel_values.unsqueeze(0)
18
 
19
  # forward pass
20
  with torch.no_grad():
21
+ outputs = model(input_ids=encoding.input_ids, pixel_values=encoding.pixel_values.unsqueeze(0))
22
 
23
  logits = outputs.logits
24
  probs = torch.nn.functional.softmax(logits, dim=1)