Spaces: Running on Zero
chenzizhao committed
Commit • 87b7a45
1 Parent(s): 2f56479
cosmetics
Files changed:
- adapter.py +7 -6
- app.py +37 -48
- config_generator.py +2 -2
- utils.py +2 -2
adapter.py
CHANGED

@@ -6,6 +6,7 @@ from typing import List, Set, Tuple, TypeVar
 
 import torch
 from PIL import Image
+from transformers import Idefics2Processor, PreTrainedTokenizer
 
 from utils import device, nested_apply, sorted_list
 
@@ -55,12 +56,12 @@ class IdeficsAdapter:
         .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool)
     SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS))
 
-    def __init__(self, image_folder: str, processor) -> None:
+    def __init__(self, image_folder: str, processor: Idefics2Processor) -> None:
         self.t_max_length = 2048
         self.image_folder = Path(image_folder)
         self.image_cache = {}
         self.processor = processor
-        self.tokenizer = self.processor.tokenizer
+        self.tokenizer: PreTrainedTokenizer = self.processor.tokenizer  # type: ignore
 
     def get_image(self, im_name: N) -> Image.Image:
         if im_name not in self.image_cache:
@@ -68,10 +69,10 @@ class IdeficsAdapter:
             self.image_folder.joinpath(im_name))
         return self.image_cache[im_name]
 
-    def unhash(self, context: List[N], c: str):
+    def unhash(self, context: List[N], c: str) -> N:
         return AlphabeticNameHash(tuple(context)).unhash(c)
 
-    def valid_hash(self, context: List[N], c: str):
+    def valid_hash(self, context: List[N], c: str) -> bool:
         return AlphabeticNameHash(tuple(context)).valid_hash(c)
 
     def parse(self, context: List[N], decoded_out: str,
@@ -121,13 +122,13 @@ class IdeficsAdapter:
         xs = re.search(select_pattern, last_answer)
         if xs is not None:
             xs = xs.group()
-        selections = set(xs.split(" ")[1:]) if xs else set()
+        selections: Set[N] = set(xs.split(" ")[1:]) if xs else set()
 
         deselect_pattern = r"^deselect( [A-J])+"
         xs = re.search(deselect_pattern, last_answer)
         if xs is not None:
             xs = xs.group()
-        deselections = set(xs.split(" ")[1:]) if xs else set()
+        deselections: Set[N] = set(xs.split(" ")[1:]) if xs else set()
 
         return selections, deselections
 
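The select/deselect parsing in the last hunk is easy to sanity-check in isolation. A minimal sketch, assuming the select pattern mirrors the deselect pattern shown in the diff and that last_answer holds the model's final utterance (both the pattern and the sample output are hypothetical here):

import re

last_answer = "select A C"  # hypothetical model output
select_pattern = r"^select( [A-J])+"  # assumed symmetric to the deselect pattern above
xs = re.search(select_pattern, last_answer)
if xs is not None:
    xs = xs.group()
# drop the leading verb; the single letters are the hashed image names
selections = set(xs.split(" ")[1:]) if xs else set()
print(selections)  # {'A', 'C'}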
app.py
CHANGED

@@ -3,13 +3,13 @@ import logging
 import os
 from typing import Any, Dict, List
 
-import gradio as gr
+import gradio as gr
 import PIL.Image as Image
 import PIL.ImageOps as ImageOps
-import spaces
+import spaces
 import torch
-from peft import PeftModel
-from transformers import AutoProcessor
+from peft import PeftModel
+from transformers import AutoProcessor
 from transformers import Idefics2ForConditionalGeneration, Idefics2Processor
 
 from adapter import IdeficsAdapter
@@ -18,15 +18,6 @@ from utils import device, nested_to_device, sorted_list
 import copy
 
 ### Constants
-css="""
-.radio-group .wrap {
-    display: grid;
-    grid-template-columns: repeat(5, 1fr);
-    grid-template-rows: repeat(5, 1fr);
-    width: 100%;
-    height: 100%
-}
-"""
 IMG_DIR = "tangram_pngs"
 
 
@@ -56,18 +47,18 @@ def get_model_response( # predict
 
     new_chats = chats + [chat]
     currently_selected = previous_selected[-1] if len(previous_selected) > 0 else []
-    model_input: Dict[str, Any] = adapter.compose(
+    model_input: Dict[str, Any] = adapter.compose(
        image_paths, new_chats, previous_selected, True, False)
-    model_input = nested_to_device(model_input)
+    model_input = nested_to_device(model_input)
 
     with torch.inference_mode(), torch.autocast(device_type=device().type,
                                                 dtype=torch.bfloat16):
-        model_output = model.generate(**model_input, **GEN_KWS)
+        model_output = model.generate(**model_input, **GEN_KWS)
 
-    decoded_out: str = adapter.tokenizer.decode(
+    decoded_out: str = adapter.tokenizer.decode(
         model_output.sequences[0], skip_special_tokens=True)
     model_clicks = adapter.parse(
-        image_paths, decoded_out, currently_selected)
+        image_paths, decoded_out, currently_selected)
 
     if len(model_clicks) == 0:
         logging.warning("empty clicks by model")
@@ -87,10 +78,9 @@ def get_model() -> PeftModel:
 def get_model() -> PeftModel:
     model_id = 'lil-lab/respect'
     checkpoint = "HuggingFaceM4/idefics2-8b"
-    model = Idefics2ForConditionalGeneration.from_pretrained(
-        checkpoint, torch_dtype=torch.bfloat16,
-        …
-    peft_model = PeftModel.from_pretrained(  # type: ignore
+    model = Idefics2ForConditionalGeneration.from_pretrained(
+        checkpoint, torch_dtype=torch.bfloat16,)
+    peft_model = PeftModel.from_pretrained(
         model, model_id, adapter_name="r6_bp", is_trainable=False, revision="r6_bp")
 
     # Add other adapter - hack to avoid conflict
@@ -105,10 +95,10 @@ def get_model() -> PeftModel:
 
 def get_processor() -> Idefics2Processor:
     checkpoint = "HuggingFaceM4/idefics2-8b"
-    processor = AutoProcessor.from_pretrained(
+    processor = AutoProcessor.from_pretrained(
         checkpoint, do_image_splitting=False,
         size={"longest_edge": 224, "shortest_edge": 224})
-    return processor
+    return processor
 
 def get_adapter() -> IdeficsAdapter:
     processor = get_processor()
@@ -147,7 +137,6 @@ class GameState:
         changes = self.selected_accum[-1] if len(self.selected_accum) > 0 else []
 
         tangram_list = self._display_context(context, targets, changes, selected)
-        # return [(img, f"Image {i+1}") for i, img in enumerate(tangram_list)]
         return tangram_list
 
     @staticmethod
@@ -234,30 +223,29 @@ def create_app_inner():
         To start a game, first select whether you wish to play against our \
         initial trained model ("Initial System") or \
         our model at the end of continual learning ("Final System") \
-        and press the "Start Game" button.
-        You will take on a "speaker" role at each round. \
-        Your goal is to describe this image (via a message in the textbox) \
-        so that the model can guess what it is.'
-    )
+        and press the "Start Game" button.')
 
-    gr.Markdown(
-        …
-        …
+    gr.Markdown(
+        'You will take on a "speaker" role at each round. \
+        Your goal is to describe this image (via a message in the textbox) \
+        so that the model can guess what it is.\
+        Targets have black borders. \
+        Correctly selected targets have green borders. \
+        Incorrectly selected targets have red borders. \
+        Actions are marked with yellow dot. \
+        The listener cannot see boxes or colors and the order is different.')
 
     gr.Markdown(
         '### Press "Send" to submit your action to proceed to the next turn. \
-        You have 10 turns in total.'
-    )
+        You have 10 turns in total.')
 
     with gr.Row():
         model_iteration = gr.Radio(["Initial System", "Final System"],
                                    label="Model Iteration",
                                    value="Final System")
         start_btn = gr.Button("Start Game")
-        …
-        …
-        current_turn = gr.Textbox(label="TURN")
-        success = gr.Textbox(label="Success")
+        status = gr.Textbox(label="Status", interactive=False, show_label=False,
+                            text_align="center", value="Please start a game.")
 
     with gr.Row():
         image_output = gr.Gallery(
@@ -268,9 +256,9 @@ def create_app_inner():
 
     with gr.Row():
         conversation_output = gr.Textbox(label="Interaction History")
-        …
-        …
-        …
+        with gr.Column():
+            user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
+            send_btn = gr.Button("Send", interactive=True)
 
     ### globals
     model = get_model()
@@ -280,12 +268,13 @@ def create_app_inner():
     ### callbacks
     def output_from_state(state: GameState):
         has_ended = state.has_ended()
-        success = "…
+        success = "Success" if state.has_successfully_ended() else "Failure"
+        status = f"{success} (Turn {state.turn}/10) - Start another game?" \
+            if has_ended else f"Turn {state.turn+1}/10"
         return (
             state.markup_images(), # image_output
             state.serialize_conversation(), # conversation_output
-            …
-            success if has_ended else "n/a", # success
+            status, # status
            gr.update(interactive=not has_ended, value=""), # user_input
            gr.update(interactive=not has_ended), # send_btn
            gr.update(interactive=has_ended), # model_iteration
@@ -309,7 +298,7 @@ def create_app_inner():
     start_btn.click(
         on_start_interaction,
         inputs=[model_iteration],
-        outputs=[image_output, conversation_output,
+        outputs=[image_output, conversation_output, status,
                  user_input, send_btn, model_iteration, game_state],
         queue=False
     )
@@ -317,14 +306,14 @@ def create_app_inner():
     send_btn.click(
         on_send_message,
         inputs=[user_input, game_state],
-        outputs=[image_output, conversation_output,
+        outputs=[image_output, conversation_output, status,
                  user_input, send_btn, model_iteration, game_state],
         queue=True
     )
 
 
 def create_app():
-    with gr.Blocks(…
+    with gr.Blocks(theme='saq1b/gradio-theme') as app:
         create_app_inner()
     return app
 
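The single new status textbox replaces the old current_turn/success pair, so every callback now returns one status string for it. A minimal sketch of that wiring in isolation, with the component kwargs taken from the diff and a hypothetical stand-in callback:

import gradio as gr

def on_send_message(message):
    # hypothetical stand-in: the real callback also returns the gallery,
    # conversation history, and button updates alongside the status string
    return "Turn 2/10"

with gr.Blocks() as demo:
    status = gr.Textbox(label="Status", interactive=False, show_label=False,
                        text_align="center", value="Please start a game.")
    user_input = gr.Textbox(label="Your Message as Speaker", interactive=True)
    send_btn = gr.Button("Send", interactive=True)
    send_btn.click(on_send_message, inputs=[user_input], outputs=[status])

if __name__ == "__main__":
    demo.launch()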
config_generator.py
CHANGED

@@ -35,8 +35,8 @@ def generate_game_config() -> GameConfig:
     return config
 
 @functools.cache
-def _get_data(…
-    if not …
+def _get_data(hb_split: bool=True):
+    if not hb_split:
         # 1013 images
         paths = os.listdir(EMPTY_DATA_PATH)
     else:
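One subtlety with the new signature: functools.cache keys on the exact arguments of each call, so _get_data() and _get_data(True) occupy separate cache slots even though they execute the same branch. A small sketch of that behavior, using a hypothetical stand-in function:

import functools

@functools.cache
def _get_data_sketch(hb_split: bool = True):
    print("computed with", hb_split)
    return "hb split" if hb_split else "all 1013 images"

_get_data_sketch()       # computed with True
_get_data_sketch()       # cache hit: nothing printed
_get_data_sketch(True)   # computed again: key (True,) differs from key ()
_get_data_sketch(False)  # computed with False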
utils.py
CHANGED

@@ -10,11 +10,11 @@ def sorted_list(s: Set[str]) -> List[str]:
 def device():
     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-def nested_to_device(s):
+def nested_to_device(s):
     # s is either a tensor or a dictionary
     if isinstance(s, torch.Tensor):
         return s.to(device())
-    return {k: v.to(device()) for k, v in s.items()}
+    return {k: v.to(device()) for k, v in s.items()}
 
 def nested_apply(h, s):
     # h is an unary function, s is one of N, tuple of N, list of N, or set of N
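For reference, nested_to_device handles exactly the two shapes named in its comment: a bare tensor or a flat dict of tensors, as produced by adapter.compose in app.py. A self-contained usage sketch (the batch contents are made up):

import torch

def device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def nested_to_device(s):
    # s is either a tensor or a dictionary of tensors (as in utils.py above)
    if isinstance(s, torch.Tensor):
        return s.to(device())
    return {k: v.to(device()) for k, v in s.items()}

batch = {"input_ids": torch.tensor([[1, 2, 3]]),
         "attention_mask": torch.ones(1, 3, dtype=torch.long)}
batch = nested_to_device(batch)  # every value now lives on device()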