Li committed on
Commit ab6ff71
1 Parent(s): a230c75

update app.py

Files changed (1):
  app.py +27 -12
app.py CHANGED
@@ -6,7 +6,7 @@ from PIL import Image
 
 from open_flamingo.train.distributed import init_distributed_device, world_info_from_env
 import string
-
+import cv2
 
 
 import gradio as gr
@@ -44,14 +44,13 @@ def generate(
     idx,
     image,
     text,
-    tsvfile,
     vis_embed_size=256,
     rank=0,
     world_size=1,
 ):
     if image is None:
         raise gr.Error("Please upload an image.")
-    flamingo.eval().cuda()
+    flamingo.eval()
     loc_token_ids = []
     for i in range(1000):
         loc_token_ids.append(int(tokenizer(f"<loc_{i}>", add_special_tokens=False)["input_ids"][-1]))
@@ -70,7 +69,12 @@ def generate(
     height = image.height
     image = image.resize((224, 224))
     batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
-    prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}"]
+    if idx == 1:
+        prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}"]
+        bad_words_ids = bad_words_ids
+    else:
+        prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
+        bad_words_ids = None
     encodings = tokenizer(
         prompt,
         padding="longest",
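The two prompt branches above differ only in whether the query is prefixed with the `<|#obj#|>` marker after the block of image placeholder tokens (grounding, `idx == 1`) or appended directly after the image block (captioning/VQA). A minimal sketch of the two templates, using a stand-in pad token and a tiny `vis_embed_size` for readability; the real app uses `tokenizer.pad_token` and `vis_embed_size=256`:

```python
# Sketch only: illustrates the two prompt templates built in the hunk above.
pad_token = "<PAD>"      # stand-in for tokenizer.pad_token
vis_embed_size = 4       # stand-in for 256

def build_prompt(text: str, grounding: bool) -> str:
    image_block = f"<|#image#|>{pad_token * vis_embed_size}<|#endofimage#|>"
    if grounding:  # idx == 1: the query is wrapped with the <|#obj#|> marker
        return f"{image_block}<|#obj#|>{text.rstrip('.')}"
    # captioning / VQA: the query follows the image block directly
    return f"{image_block}{text.rstrip('.')}"

print(build_prompt("a red car.", grounding=True))
# <|#image#|><PAD><PAD><PAD><PAD><|#endofimage#|><|#obj#|>a red car
print(build_prompt("a red car.", grounding=False))
# <|#image#|><PAD><PAD><PAD><PAD><|#endofimage#|>a red car
```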
@@ -85,13 +89,13 @@ def generate(
     image_nums = [1] * len(input_ids)
     outputs = get_outputs(
         model=flamingo,
-        batch_images=batch_images.cuda(),
-        attention_mask=attention_mask.cuda(),
+        batch_images=batch_images,
+        attention_mask=attention_mask,
         max_generation_length=5,
         min_generation_length=4,
         num_beams=1,
         length_penalty=1.0,
-        input_ids=input_ids.cuda(),
+        input_ids=input_ids,
         bad_words_ids=bad_words_ids,
         image_start_index_list=image_start_index_list,
         image_nums=image_nums,
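Dropping the hard-coded `.cuda()` calls here (and from `flamingo.eval()` earlier) lets the demo run on hosts without a GPU. A minimal device-agnostic sketch of the same idea, with toy stand-ins for the model and inputs rather than the app's actual objects:

```python
import torch

# Use an accelerator when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy stand-ins for flamingo / input_ids etc.; the point is to place the model
# and its inputs on the same device once instead of sprinkling .cuda() calls.
model = torch.nn.Linear(4, 2).to(device).eval()
inputs = torch.randn(1, 4, device=device)

with torch.no_grad():
    outputs = model(inputs)
print(outputs.device)  # cuda:0 if a GPU is available, otherwise cpu
```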
@@ -106,12 +110,23 @@ def generate(
     # tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
     # tqdm.write(f"prompt: {prompt}")
 
+    if len(box) == 4:
+        img = cv2.cvtColor(np.array(image_ori), cv2.COLOR_RGB2BGR)
+        out = cv2.rectangle(img, (int(box[0] * width / 1000), int(box[1] * height / 1000)),
+                            (int(box[2] * width / 1000), int(box[3] * height / 1000)), color=(255, 0, 255), thickness=2)
+        out = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
+        out_image = Image.fromarray(out)
+    # else:
+    #     tqdm.write(f"output: {tokenizer.batch_decode(outputs)}")
+    #     tqdm.write(f"prompt: {prompt}")
+
     gen_text = tokenizer.batch_decode(outputs)
-    return (
-        f"Output:{gen_text}"
-        if idx != 2
-        else f"Question: {text.strip()} Answer: {gen_text}"
-    )
+    if idx == 1:
+        return f"Output:{gen_text}", out_image
+    elif idx == 2:
+        return (f"Question: {text.strip()} Answer: {gen_text}")
+    else:
+        return (f"Output:{gen_text}")
 
 
 with gr.Blocks() as demo:
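For context on the drawing code added above: the model's `<loc_k>` tokens index a 0-999 coordinate grid, so the four `box` values are rescaled by `width / 1000` and `height / 1000` before cv2 draws the rectangle on the original-resolution image. A self-contained sketch of that flow; the token-parsing step and the x1, y1, x2, y2 ordering are assumptions, since `box` is built outside this hunk:

```python
import re

import cv2
import numpy as np
from PIL import Image


def draw_box_from_loc_tokens(gen_text: str, image: Image.Image) -> Image.Image:
    """Parse four <loc_k> tokens (0-999 grid, assumed x1, y1, x2, y2 order) and
    draw the box on the original image, mirroring the width/1000, height/1000
    scaling used in the hunk above."""
    box = [int(m) for m in re.findall(r"<loc_(\d+)>", gen_text)][:4]
    if len(box) != 4:
        return image  # nothing to draw
    width, height = image.size
    x1, y1 = int(box[0] * width / 1000), int(box[1] * height / 1000)
    x2, y2 = int(box[2] * width / 1000), int(box[3] * height / 1000)
    img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    out = cv2.rectangle(img, (x1, y1), (x2, y2), color=(255, 0, 255), thickness=2)
    return Image.fromarray(cv2.cvtColor(out, cv2.COLOR_BGR2RGB))


# Example (hypothetical generation output and image file):
# draw_box_from_loc_tokens("<loc_100><loc_200><loc_600><loc_800>", Image.open("cat.jpg"))
```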
 