chenzizhao committed on
Commit 87b7a45
1 Parent(s): 2f56479
Files changed (4):
  1. adapter.py +7 -6
  2. app.py +37 -48
  3. config_generator.py +2 -2
  4. utils.py +2 -2
adapter.py CHANGED
@@ -6,6 +6,7 @@ from typing import List, Set, Tuple, TypeVar
 
 import torch
 from PIL import Image
+from transformers import Idefics2Processor, PreTrainedTokenizer
 
 from utils import device, nested_apply, sorted_list
 
@@ -55,12 +56,12 @@ class IdeficsAdapter:
         .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
     SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))
 
-    def __init__(self, image_folder: str, processor) -> None:
+    def __init__(self, image_folder: str, processor: Idefics2Processor) -> None:
         self.t_max_length = 2048
         self.image_folder = Path(image_folder)
         self.image_cache = {}
         self.processor = processor
-        self.tokenizer = self.processor.tokenizer
+        self.tokenizer: PreTrainedTokenizer = self.processor.tokenizer  # type: ignore
 
     def get_image(self, im_name: N) -> Image.Image:
         if im_name not in self.image_cache:
@@ -68,10 +69,10 @@ class IdeficsAdapter:
                 self.image_folder.joinpath(im_name))
         return self.image_cache[im_name]
 
-    def unhash(self, context: List[N], c: str):
+    def unhash(self, context: List[N], c: str) -> N:
         return AlphabeticNameHash(tuple(context)).unhash(c)
 
-    def valid_hash(self, context: List[N], c: str):
+    def valid_hash(self, context: List[N], c: str) -> bool:
         return AlphabeticNameHash(tuple(context)).valid_hash(c)
 
     def parse(self, context: List[N], decoded_out: str,
@@ -121,13 +122,13 @@ class IdeficsAdapter:
         xs = re.search(select_pattern, last_answer)
         if xs is not None:
             xs = xs.group()
-        selections = set(xs.split(" ")[1:]) if xs else set()
+        selections: Set[N] = set(xs.split(" ")[1:]) if xs else set()
 
         deselect_pattern = r"^deselect( [A-J])+"
         xs = re.search(deselect_pattern, last_answer)
         if xs is not None:
             xs = xs.group()
-        deselections = set(xs.split(" ")[1:]) if xs else set()
+        deselections: Set[N] = set(xs.split(" ")[1:]) if xs else set()
 
         return selections, deselections
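Note on the typing change above: the processor's tokenizer attribute is not precisely typed upstream, which is why the new PreTrainedTokenizer annotation still carries a # type: ignore. A minimal sketch of the pattern, assuming only the checkpoint name used in this repo (variable names here are hypothetical):

    from transformers import AutoProcessor, Idefics2Processor, PreTrainedTokenizer

    processor: Idefics2Processor = AutoProcessor.from_pretrained(
        "HuggingFaceM4/idefics2-8b")
    # The attribute exists at runtime; the annotation restores static
    # checking and completion for downstream calls such as .decode().
    tokenizer: PreTrainedTokenizer = processor.tokenizer  # type: ignore
    ids = tokenizer("select A B", return_tensors="pt").input_ids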
 
app.py CHANGED
@@ -3,13 +3,13 @@ import logging
 import os
 from typing import Any, Dict, List
 
-import gradio as gr  # type: ignore
+import gradio as gr
 import PIL.Image as Image
 import PIL.ImageOps as ImageOps
-import spaces  # type: ignore
+import spaces
 import torch
-from peft import PeftModel  # type: ignore
-from transformers import AutoProcessor  # type: ignore
+from peft import PeftModel
+from transformers import AutoProcessor
 from transformers import Idefics2ForConditionalGeneration, Idefics2Processor
 
 from adapter import IdeficsAdapter
@@ -18,15 +18,6 @@ from utils import device, nested_to_device, sorted_list
 import copy
 
 ### Constants
-css="""
-.radio-group .wrap {
-    display: grid;
-    grid-template-columns: repeat(5, 1fr);
-    grid-template-rows: repeat(5, 1fr);
-    width: 100%;
-    height: 100%
-}
-"""
 IMG_DIR = "tangram_pngs"
 
 
@@ -56,18 +47,18 @@ def get_model_response(  # predict
 
     new_chats = chats + [chat]
     currently_selected = previous_selected[-1] if len(previous_selected) > 0 else []
-    model_input: Dict[str, Any] = adapter.compose(  # type: ignore
+    model_input: Dict[str, Any] = adapter.compose(
         image_paths, new_chats, previous_selected, True, False)
-    model_input = nested_to_device(model_input)  # type: ignore
+    model_input = nested_to_device(model_input)
 
     with torch.inference_mode(), torch.autocast(device_type=device().type,
                                                 dtype=torch.bfloat16):
-        model_output = model.generate(**model_input, **GEN_KWS)  # type: ignore
+        model_output = model.generate(**model_input, **GEN_KWS)
 
-    decoded_out: str = adapter.tokenizer.decode(  # type: ignore
+    decoded_out: str = adapter.tokenizer.decode(
         model_output.sequences[0], skip_special_tokens=True)
     model_clicks = adapter.parse(
-        image_paths, decoded_out, currently_selected)  # type: ignore
+        image_paths, decoded_out, currently_selected)
 
     if len(model_clicks) == 0:
         logging.warning("empty clicks by model")
@@ -87,10 +78,9 @@
 def get_model() -> PeftModel:
     model_id = 'lil-lab/respect'
     checkpoint = "HuggingFaceM4/idefics2-8b"
-    model = Idefics2ForConditionalGeneration.from_pretrained(  # type: ignore
-        checkpoint, torch_dtype=torch.bfloat16,
-    )
-    peft_model = PeftModel.from_pretrained(  # type: ignore
+    model = Idefics2ForConditionalGeneration.from_pretrained(
+        checkpoint, torch_dtype=torch.bfloat16,)
+    peft_model = PeftModel.from_pretrained(
         model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")
 
     # Add other adapter - hack to avoid conflict
@@ -105,10 +95,10 @@ def get_model() -> PeftModel:
 
 def get_processor() -> Idefics2Processor:
     checkpoint = "HuggingFaceM4/idefics2-8b"
-    processor = AutoProcessor.from_pretrained(  # type: ignore
+    processor = AutoProcessor.from_pretrained(
         checkpoint, do_image_splitting=False,
         size={"longest_edge": 224, "shortest_edge": 224})
-    return processor  # type: ignore
+    return processor
 
 def get_adapter() -> IdeficsAdapter:
     processor = get_processor()
@@ -147,7 +137,6 @@ class GameState:
         changes = self.selected_accum[-1] if len(self.selected_accum) > 0 else []
 
         tangram_list = self._display_context(context, targets, changes, selected)
-        # return [(img, f"Image {i+1}") for i, img in enumerate(tangram_list)]
        return tangram_list
 
     @staticmethod
@@ -234,30 +223,29 @@ def create_app_inner():
         To start a game, first select whether you wish to play against our \
         initial trained model ("Initial System") or \
         our model at the end of continual learning ("Final System") \
-        and press the "Start Game" button. \
-        You will take on a "speaker" role at each round. \
-        Your goal is to describe this image (via a message in the textbox) \
-        so that the model can guess what it is.'
-    )
-
-    gr.Markdown("Targets have black borders. Correctly selected targets have green borders. Incorrectly selected targets have red borders. Actions are marked with yellow dot.")
-
-    gr.Markdown("The listener cannot see boxes or colors and the order is different.")
+        and press the "Start Game" button.')
+
+    gr.Markdown(
+        'You will take on a "speaker" role at each round. \
+        Your goal is to describe this image (via a message in the textbox) \
+        so that the model can guess what it is.\
+        Targets have black borders. \
+        Correctly selected targets have green borders. \
+        Incorrectly selected targets have red borders. \
+        Actions are marked with yellow dot. \
+        The listener cannot see boxes or colors and the order is different.')
 
     gr.Markdown(
         '### Press "Send" to submit your action to proceed to the next turn. \
-        You have 10 turns in total.'
-    )
+        You have 10 turns in total.')
 
     with gr.Row():
         model_iteration = gr.Radio(["Initial System", "Final System"],
                                    label="Model Iteration",
                                    value="Final System")
         start_btn = gr.Button("Start Game")
-
-    with gr.Row():
-        current_turn = gr.Textbox(label="TURN")
-        success = gr.Textbox(label="Success")
+        status = gr.Textbox(label="Status", interactive=False, show_label=False,
+                            text_align="center", value="Please start a game.")
 
     with gr.Row():
         image_output = gr.Gallery(
@@ -268,9 +256,9 @@ def create_app_inner():
 
     with gr.Row():
         conversation_output = gr.Textbox(label="Interaction History")
-        user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
-
-    send_btn = gr.Button("Send", interactive=True)
+        with gr.Column():
+            user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
+            send_btn = gr.Button("Send", interactive=True)
 
     ### globals
     model = get_model()
@@ -280,12 +268,13 @@ def create_app_inner():
     ### callbacks
     def output_from_state(state: GameState):
         has_ended = state.has_ended()
-        success = "success" if state.has_successfully_ended() else "failure"
+        success = "Success" if state.has_successfully_ended() else "Failure"
+        status = f"{success} (Turn {state.turn}/10) - Start another game?" \
+            if has_ended else f"Turn {state.turn+1}/10"
         return (
             state.markup_images(),  # image_output
             state.serialize_conversation(),  # conversation_output
-            f"{state.turn+1}/10",  # current_turn
-            success if has_ended else "n/a",  # success
+            status,  # status
             gr.update(interactive=not has_ended, value=""),  # user_input
             gr.update(interactive=not has_ended),  # send_btn
             gr.update(interactive=has_ended),  # model_iteration
@@ -309,7 +298,7 @@ def create_app_inner():
     start_btn.click(
         on_start_interaction,
         inputs=[model_iteration],
-        outputs=[image_output, conversation_output, current_turn, success,
+        outputs=[image_output, conversation_output, status,
                  user_input, send_btn, model_iteration, game_state],
         queue=False
     )
@@ -317,14 +306,14 @@ def create_app_inner():
     send_btn.click(
         on_send_message,
         inputs=[user_input, game_state],
-        outputs=[image_output, conversation_output, current_turn, success,
+        outputs=[image_output, conversation_output, status,
                  user_input, send_btn, model_iteration, game_state],
         queue=True
     )
 
 
 def create_app():
-    with gr.Blocks(css=css) as app:
+    with gr.Blocks(theme='saq1b/gradio-theme') as app:
         create_app_inner()
     return app
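The UI change above folds the separate current_turn and success textboxes into a single status textbox, so each callback now returns one status string, and both outputs lists are updated to match. Gradio pairs a callback's return values with its output components positionally; a self-contained sketch of that contract (component names hypothetical, not the app's full layout):

    import gradio as gr

    with gr.Blocks() as demo:
        status = gr.Textbox(label="Status", interactive=False)
        user_input = gr.Textbox(label="Message")
        send_btn = gr.Button("Send")

        def on_send(message: str):
            # One return value per entry in outputs, in the same order.
            return "Turn 1/10", gr.update(value="")

        send_btn.click(on_send, inputs=[user_input],
                       outputs=[status, user_input])

    demo.launch()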
 
config_generator.py CHANGED
@@ -35,8 +35,8 @@ def generate_game_config() -> GameConfig:
     return config
 
 @functools.cache
-def _get_data(restricted_dataset: bool=False):
-    if not restricted_dataset:
+def _get_data(hb_split: bool=True):
+    if not hb_split:
         # 1013 images
         paths = os.listdir(EMPTY_DATA_PATH)
     else:
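Renaming the flag and flipping its default interacts with the functools.cache decorator above: results are memoized per argument key, and a call that relies on the default is cached under a different key than an explicit keyword call. A small illustration with a hypothetical function:

    import functools

    @functools.cache
    def load(hb_split: bool = True) -> str:
        print("loading for", hb_split)  # runs once per distinct cache key
        return "hb" if hb_split else "empty"

    load()               # computes; cached under the zero-argument key
    load()               # cache hit
    load(hb_split=True)  # computes again: the keyword call is a separate key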
utils.py CHANGED
@@ -10,11 +10,11 @@ def sorted_list(s: Set[str]) -> List[str]:
 def device():
     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-def nested_to_device(s):  # type: ignore
+def nested_to_device(s):
     # s is either a tensor or a dictionary
     if isinstance(s, torch.Tensor):
         return s.to(device())
-    return {k: v.to(device()) for k, v in s.items()}  # type: ignore
+    return {k: v.to(device()) for k, v in s.items()}
 
 def nested_apply(h, s):
     # h is an unary function, s is one of N, tuple of N, list of N, or set of N
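For reference, a short usage sketch of nested_to_device after this cleanup: it moves either a bare tensor or every tensor value of a mapping, which is how app.py ships the processor's batch to the GPU (sample data hypothetical):

    import torch
    from utils import nested_to_device

    batch = {"input_ids": torch.zeros(1, 8, dtype=torch.long),
             "attention_mask": torch.ones(1, 8, dtype=torch.long)}
    batch = nested_to_device(batch)           # each value moved via .to(device())
    x = nested_to_device(torch.randn(1, 10))  # a single tensor also works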