momergul committed on
Commit 14eba99
1 Parent(s): 554adbb
Files changed (3)
  1. app.py +153 -89
  2. joint_inference.py +1 -6
  3. models.py +2 -2
app.py CHANGED
@@ -21,9 +21,9 @@ css="""
 """
 
 def initialize_game() -> List[List[str]]:
-    context_dicts = [generate_complete_game() for _ in range(2)]
+    context_dicts = [generate_complete_game() for _ in range(4)]
 
-    roles = ["speaker"] * 3 + ["listener"] * 3
+    roles = ["listener"] * 3 + ["speaker"] * 3 + ["listener"] * 3 + ["speaker"] * 3
     speaker_images = []
     listener_images = []
     targets = []
@@ -71,6 +71,7 @@ def get_model_response(
 @spaces.GPU(duration=20)
 def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask, label, image_paths, processor, img_dir, index_to_token, adapter_name):
     model.model.set_adapter(adapter_name)
+    print(adapter_name)
     model = model.cuda()
     with torch.no_grad():
         captions, _, _, _, _ = model.generate(
@@ -85,6 +86,7 @@ def get_speaker_response(model, images, input_tokens, attn_mask, image_attn_mask
 def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_attn_mask, index_to_token,
                           s_input_tokens, s_attn_mask, s_image_attn_mask, s_target_mask, s_target_label, image_paths, adapter_name):
     model.model.set_adapter(adapter_name)
+    print(adapter_name)
     model = model.cuda()
     with torch.no_grad():
         _, _, joint_log_probs = model.comprehension_side([
@@ -95,71 +97,118 @@ def get_listener_response(model, images, l_input_tokens, l_attn_mask, l_image_at
         response = image_paths[target_idx]
     return response
 
-def interaction(model, processor, index_to_token, model_iteration: str) -> Tuple[List[str], List[str]]:
-    image_role_pairs = initialize_game()
-    conversation = []
-    turn = 0
-    num_correct = 0
-    human_role = None
-    adapter_name = "initial" if model_iteration == "Initial System" else "final"
-    internal_model = model
-
-    for speaker_image, listener_image, target_image, model_role in image_role_pairs:
-        acc_message = f"{num_correct}/{turn}"
-        if model_role == "speaker":
-            human_role = "Listener"
-            turn += 1
-            turn_message = f"{turn}/6"
-            human_context = listener_image
-            model_context = speaker_image
-            target_idx = human_context.index(target_image)
-
-            conversation.extend([
-                f"TURN: {turn}/6",
-                f"Guess the target image given the speaker's description. ",
-            ])
-            model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, target_image=target_image)
-            conversation.append(f"Model: {model_message}")
-            conversation.append("You: The target is Image ")
-            user_message = yield human_context, conversation, human_role, turn_message, acc_message
-
-            conversation[-1] += f"{user_message}"
-            if int(user_message) == target_idx + 1:
-                conversation.append("Correct!\n")
-                num_correct += 1
-            else:
-                conversation.append(f"Incorrect!\n")
-        else:
-            # listener
-            human_role = "Speaker"
-            turn += 1
-            turn_message = f"{turn}/6"
-            human_context = speaker_image
-            model_context = listener_image
-            target_idx = human_context.index(target_image)
-
-            conversation.extend([
-                f"TURN: {turn}/6",
-                f"Generate a description for the target image. Your target is Image {target_idx + 1}",
-            ])
-
-            user_message = yield human_context, conversation, human_role, turn_message, acc_message
-            conversation.append(f"You: {user_message}")
-            model_message = get_model_response(internal_model, adapter_name, processor, index_to_token, model_role, model_context, user_message=user_message)
-            model_idx = human_context.index(model_message)
-
-            if int(model_idx) == int(target_idx):
-                conversation.append("The model guessed correctly!\n")
-                num_correct += 1
-            else:
-                conversation.append(f"The model guessed incorrectly.\n")
+def initialize_interaction(model_iteration):
+    # initialize the overall history
+    new_history = {
+        'adapter_name' : 'initial' if model_iteration == "Initial System" else "final",
+        'image_role_pairs' : initialize_game(),
+        'conversation' : [],
+        'turn' : 0,
+        'num_correct' : 0,
+    }
+
+    # Initialize the first turn (always a listener)
+    turn = new_history['turn']
+    image_role_pairs = new_history['image_role_pairs']
+    speaker_image, listener_image, target_image, _ = image_role_pairs[turn]
+    target_idx = speaker_image.index(target_image)
+    new_history['conversation'].extend([
+        f"TURN: {turn + 1}/12",
+        f"Generate a description for the target image. Your target is Image {target_idx + 1}"
+    ])
+
+    return new_history
+
+def progress_game(user_message, model, processor, index_to_token, current_state):
+    # First get the game state
+    turn = current_state['turn']
+    image_role_pairs = current_state['image_role_pairs']
+    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
+    human_role = "Speaker" if model_role == "listener" else "Listener"
+
+    # Next, move on with current turn
+    if model_role == "listener":
+        human_context = speaker_image
+        model_context = listener_image
+
+        # If model is a listener, the human must have sent a message
+        current_state['conversation'].append(f"You: {user_message}")
+        model_message = get_model_response(
+            model, current_state['adapter_name'], processor, index_to_token, model_role,
+            model_context, user_message=user_message
+        )
+        model_idx = human_context.index(model_message)
+        target_idx = human_context.index(target_image)
+
+        if int(model_idx) == int(target_idx):
+            current_state['conversation'].append("The model guessed correctly!\n")
+            current_state['num_correct'] += 1
+        else:
+            current_state['conversation'].append(f"The model guessed incorrectly.\n")
+    else:
+        human_context = listener_image
+        model_context = speaker_image
+
+        # If model is a speaker, the human must have made a guess
+        target_idx = human_context.index(target_image)
+        current_state['conversation'][-1] += f"{user_message}"
+        if int(user_message) == target_idx + 1:
+            current_state['conversation'].append("Correct!\n")
+            current_state['num_correct'] += 1
+        else:
+            current_state['conversation'].append(f"Incorrect!\n")
+
+    # We move on to the next turn
+    current_state['turn'] += 1
+    acc_message = f"{current_state['num_correct']}/{current_state['turn']}"
+    turn_message = f"{current_state['turn'] + 1}/12"
+    if current_state['turn'] == len(image_role_pairs):
+        current_state['conversation'].append('The game is over!')
+        return human_context, current_state['conversation'], human_role, turn_message, acc_message, {}
+
+    speaker_image, listener_image, target_image, model_role = image_role_pairs[current_state['turn']]
+    human_role = "Listener" if model_role == "speaker" else "Speaker"
+    if model_role == "speaker":
+        human_context = listener_image
+        model_context = speaker_image
+
+        current_state['conversation'].extend([
+            f"TURN: {current_state['turn'] + 1}/12",
+            f"Guess the target image given the speaker's description. ",
+        ])
+        model_message = get_model_response(model, current_state['adapter_name'], processor, index_to_token,
+                                           model_role, model_context, target_image=target_image)
+        current_state['conversation'].append(f"Model: {model_message}")
+        current_state['conversation'].append("You: The target is Image ")
+    else:
+        human_context = speaker_image
+        model_context = listener_image
+        target_idx = human_context.index(target_image)
 
-    acc_message = f"{num_correct}/{turn}"
-    conversation.append("The game is over!")
-    yield human_context, conversation, human_role, turn_message, acc_message
+        current_state['conversation'].extend([
+            f"TURN: {current_state['turn'] + 1}/12",
+            f"Generate a description for the target image. Your target is Image {target_idx + 1}",
+        ])
+
+    return human_context, current_state['conversation'], human_role, turn_message, acc_message, current_state
+
+def get_current_images(current_history):
+    turn = current_history['turn']
+    image_role_pairs = current_history['image_role_pairs']
+    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
+    human_context = listener_image if model_role == "speaker" else speaker_image
+    return human_context
+
+def get_human_role(current_history):
+    turn = current_history['turn']
+    image_role_pairs = current_history['image_role_pairs']
+    speaker_image, listener_image, target_image, model_role = image_role_pairs[turn]
+    return "Listener" if model_role == "speaker" else "Speaker"
 
 def create_app():
     with gr.Blocks(css=css) as app:
+        game_history = gr.State(value={})
+
         gr.Markdown("# Tangram Reference Game")
         gr.Markdown(
             '### You will be playing a sequence of reference games against a model. To start a game, first select whether ' +\
@@ -207,51 +256,66 @@ def create_app():
             interactive=False,
         )
 
-        send_btn = gr.Button("Send")
-
-        interaction_generator = None
+        send_btn = gr.Button("Send", interactive=False)
         model = get_model()
         processor = get_processor()
        index_to_token = get_index_to_token()
 
-        print("Heyo!")
        def start_interaction(model_iteration):
+            # Initialize the interaction
            if model_iteration is None:
                return [], "Please select a model iteration.", "", "", "", gr.update(interactive=False), \
-                    gr.update(interactive=False), gr.update(interactive=False)
+                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True), {}
+            current_history = initialize_interaction(model_iteration)
+
+            # Unpack the relevant items
+            images = get_current_images(current_history)
+            conversation = current_history["conversation"]
+            role = get_human_role(current_history)
+            human_listener = role == "Listener"
+
+            current_turn = current_history['turn'] + 1
+            turn_msg = f"{current_turn}/12"
+            acc_msg = "0/0"
+            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn_msg, acc_msg, \
+                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True), gr.update(interactive=False), current_history
 
-            nonlocal interaction_generator
+        def send_message(message, radio_choice, current_state):
            nonlocal model
            nonlocal processor
            nonlocal index_to_token
-            interaction_generator = interaction(model, processor, index_to_token, model_iteration)
-            images, conversation, role, turn, acc_message = next(interaction_generator)
-            human_listener = role == "Listener"
-            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, acc_message, \
-                gr.update(interactive=not human_listener), gr.update(interactive=human_listener), gr.update(interactive=True)
 
-        def send_message(message, radio_choice):
-            nonlocal interaction_generator
-            if interaction_generator is None:
-                return [], "Please start the interaction first.", "", gr.update(interactive=False), gr.update(interactive=False, value=None)
-
-            try:
-                user_output = message if radio_choice is None else radio_choice
-                images, conversation, role, turn, acc_message = interaction_generator.send(user_output)
-                human_listener = role == "Listener"
-                return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, acc_message, \
-                    gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), gr.update(interactive=True)
-            except StopIteration:
-                return [], conversation_output.value, current_role.value, current_turn.value, accuracy.value, gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=False)
+            # Game ended
+            if current_state['turn'] == len(current_state['image_role_pairs']):
+                return [], conversation_output.value, current_role.value, current_turn.value, accuracy.value, gr.update(interactive=False), \
+                    gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, value=None), {}
+
+            # Regular game progress
+            user_output = message if radio_choice is None else radio_choice
+            images, conversation, role, turn, acc_message, current_state = progress_game(user_output, model, processor, index_to_token, current_state)
+            human_listener = role == "Listener"
+            return [(f"tangram_pngs/{img}", f"Image {i+1}") for i, img in enumerate(images)], "\n".join(conversation), role, turn, \
+                acc_message, gr.update(interactive=not human_listener, value=""), gr.update(interactive=human_listener, value=None), \
+                gr.update(interactive=True), gr.update(interactive=False), current_state
 
        start_btn.click(
            start_interaction,
            inputs=[model_iteration],
-            outputs=[image_output, conversation_output, current_role, current_turn, accuracy, user_input, radio_buttons, send_btn]
+            outputs=[
+                image_output, conversation_output, current_role, current_turn, accuracy,
+                user_input, radio_buttons, send_btn, model_iteration, game_history],
+            queue=False
        )
-        send_btn.click(send_message, inputs=[user_input, radio_buttons], outputs=[image_output, conversation_output, current_role, current_turn, accuracy, user_input, radio_buttons, send_btn])
+        send_btn.click(
+            send_message,
+            inputs=[user_input, radio_buttons, game_history],
+            outputs=[image_output, conversation_output, current_role, current_turn, accuracy, user_input,
+                radio_buttons, send_btn, model_iteration, game_history],
+            queue=True
+        )
 
     return app
 
 app = create_app()
+app.queue()
 app.launch()

joint_inference.py CHANGED
@@ -346,7 +346,6 @@ class IdeficsJointInferenceModel(nn.Module):
         speaker = self.get_speaker()
         generation_config = GenerationConfig(
             max_new_tokens=max_steps,
-            min_new_tokens=1,
             do_sample=True,
             temperature=temperature,
             top_k=top_k, top_p=top_p,
@@ -429,6 +428,7 @@ class IdeficsJointInferenceModel(nn.Module):
         speaker = self.get_speaker()
         generation_config = GenerationConfig(
             max_new_tokens=max_steps,
+            min_new_tokens=1,
             do_sample=True,
             temperature=temperature,
             top_k=top_k, top_p=top_p,
@@ -438,11 +438,6 @@ class IdeficsJointInferenceModel(nn.Module):
             return_dict_in_generate=True
         )
 
-        print(torch.any(torch.isnan(s_input_tokens)))
-        print(torch.any(torch.isnan(s_attn_mask)))
-        print(torch.any(torch.isnan(images)))
-        print(torch.any(torch.isnan(s_image_attn_mask)))
-
         outputs = speaker.generate(
             input_ids=s_input_tokens,
             attention_mask=s_attn_mask,

models.py CHANGED
@@ -11,7 +11,7 @@ def get_model():
     # Initialize the model
     repo = 'lil-lab/cogen'
     checkpoint = "HuggingFaceM4/idefics2-8b"
-    model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).cuda()
+    model = Idefics2ForConditionalGeneration.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
 
     # Add LoRA adapters
     target_modules=r'(.*(vision_model|modality_projection|perceiver_resampler).*(out_proj|fc1|fc2|down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)|(.*(k_proj|q_proj|v_proj).*$)'
@@ -39,7 +39,7 @@ def get_model():
     )
     model.add_adapter('final', lora_config)
     model.load_adapter(repo, "final", revision="r3_full")
-    model = IdeficsJointInferenceModel(0.5, 0, model=model).cuda()
+    model = IdeficsJointInferenceModel(0.5, 0, model=model)
     model.eval()
 
     return model