fl399 ybelkada commited on
Commit
26b9e85
1 Parent(s): 6d7cc85

Update app.py (#1)

Browse files

- Update app.py (7ef96b73fc06e2713e59a7ef0ba34838d1dd624b)


Co-authored-by: Younes Belkada <[email protected]>

Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -148,12 +148,12 @@ def evaluate(
148
 
149
 
150
  ## deplot models
151
- model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot")
152
  processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
153
 
154
  def process_document(image, question):
155
  # image = Image.open(image)
156
- inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt")
157
  predictions = model_deplot.generate(**inputs, max_new_tokens=512)
158
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
159
 
 
148
 
149
 
150
  ## deplot models
151
+ model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
152
  processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
153
 
154
  def process_document(image, question):
155
  # image = Image.open(image)
156
+ inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(torch.bfloat16, 0)
157
  predictions = model_deplot.generate(**inputs, max_new_tokens=512)
158
  table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
159