gizemsarsinlar committed on
Commit
07f8442
1 Parent(s): cd7b378

Update app.py

Files changed (1):
  1. app.py +29 -6
app.py CHANGED
@@ -42,13 +42,36 @@ def initialize_model():
 # Initialize the model and processor
 model, processor, device = initialize_model()
 
-def run_example(task_prompt, image, text_input=None):
-    prompt = task_prompt if text_input is None else task_prompt + text_input
-    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
-    with torch.inference_mode():
-        generated_ids = model.generate(**inputs, max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3)
+# def run_example(task_prompt, image, text_input=None):
+#     prompt = task_prompt if text_input is None else task_prompt + text_input
+#     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
+#     with torch.inference_mode():
+#         generated_ids = model.generate(**inputs, max_new_tokens=1024, early_stopping=False, do_sample=False, num_beams=3)
+#     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+#     return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))
+
+def run_example(task_prompt, text_input=None):
+    if text_input is None:
+        prompt = task_prompt
+    else:
+        prompt = task_prompt + text_input
+    inputs = processor(text=prompt, images=image, return_tensors="pt")
+    generated_ids = model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
+    )
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    return processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.size[0], image.size[1]))
+    parsed_answer = processor.post_process_generation(
+        generated_text,
+        task=task_prompt,
+        image_size=(image.width, image.height)
+    )
+
+    return parsed_answer
 
 def fig_to_pil(fig):
     buf = io.BytesIO()
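For context, a minimal sketch of how the refactored run_example might be exercised. Note that the new signature drops the image parameter, so the function resolves image from the enclosing module scope; the file name, the Florence-2-style task tokens ("<OD>", "<CAPTION_TO_PHRASE_GROUNDING>"), and the example text below are illustrative assumptions, not part of the commit.

# Minimal usage sketch, assuming the globals defined in app.py above
# (model, processor) and a Florence-2-style processor; the task tokens
# and file name are illustrative assumptions.
from PIL import Image

# The refactored run_example reads `image` from module scope instead of
# taking it as an argument, so bind it before calling.
image = Image.open("example.jpg").convert("RGB")

# Task token only: the prompt is used as-is.
result = run_example("<OD>")
print(result)

# Task token plus text input: the two strings are concatenated.
result = run_example("<CAPTION_TO_PHRASE_GROUNDING>", text_input="a green car")
print(result)

Two design changes are visible in the diff: generate is now called with explicit input_ids and pixel_values keywords rather than **inputs, and the parsed result comes from post_process_generation keyed on the same task token, which is why callers pass a task token rather than a free-form prompt. Note also that the new version no longer moves inputs to device, which matters if initialize_model() placed the model on a GPU.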