# Spaces: Running on Zero
import logging | |
import re | |
from functools import cache | |
from pathlib import Path | |
from typing import List, Set, Tuple, TypeVar | |
import torch | |
from PIL import Image | |
from transformers import Idefics2Processor, PreTrainedTokenizer | |
from utils import device, nested_apply, sorted_list | |
# Accepted answer grammar: "select X [Y ...]", "deselect X [Y ...]" or
# "deselect X [Y ...] select Z [W ...]", where each name is a single
# upper-case letter separated by whitespace.
RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$'  # noqa
# Generic type of an image name (e.g. "page4-249.png"). This is a plain,
# unconstrained TypeVar used in type annotations, not a NewType of str.
N = TypeVar('N')
ALPHABET = 'ABCDEFGHIJ'  # we only have 10 images
# Token ids the listener is allowed to emit: the letters A-J, '</s>' (2),
# '<end_of_utterance>' (32002), 'select', and the two tokens 'des' + 'elect'
# that together spell 'deselect'. MINI_DECODER maps these ids back to text.
LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413,
                   420, 475, 5339, 634, 17960, 32002]
MINI_DECODER = {
    384: 'D',
    # id 2 ('</s>') is in LEGAL_TOKEN_IDS but has no entry in this decoder.
    32002: '<end_of_utterance>',
    420: 'G', 17960: 'elect',
    330: 'A', 365: 'B', 334: 'C', 5339: 'select', 401: 'F', 475: 'J',
    634: 'des', 315: 'I', 413: 'E', 382: 'H'}
class AlphabeticNameHash:
    """Two-way mapping between the items of a context and the letters of ALPHABET.

    The item at position k of ``context`` is labelled ``ALPHABET[k]``; a
    context longer than ``ALPHABET`` raises IndexError on construction.
    NOTE(review): if ``context`` contains duplicates, ``hash`` returns the
    letter of the last occurrence, while every occurrence's letter unhashes
    back to that same item.
    """

    def __init__(self, context: List[N]) -> None:
        forward = {}
        for position, item in enumerate(context):
            forward[item] = ALPHABET[position]
        self._forward_map = forward
        self._backward_map = {
            ALPHABET[position]: item for position, item in enumerate(context)
        }

    def hash(self, im: N) -> str:
        """Return the letter assigned to ``im`` (KeyError if it is not in the context)."""
        letter = self._forward_map[im]
        return letter

    def unhash(self, i: str) -> N:
        """Return the context item labelled by letter ``i`` (KeyError if unassigned)."""
        item = self._backward_map[i]
        return item

    def valid_hash(self, i: str) -> bool:
        """Return True when ``i`` is a letter assigned to some context item."""
        return i in self._backward_map.keys()
class IdeficsAdapter:
    """Bridge between the selection game and an Idefics2 processor.

    Builds chat-template prompts (plus the context images) from the game
    history, and parses the model's "select ... / deselect ..." answers back
    into image names. Images are referred to by letters via
    AlphabeticNameHash.
    """

    PAD_TOKEN_ID = 0
    LABEL_MASK_ID = 32001  # idefics2: image_token_id
    # Number of token ids covered by LEGAL_TOKEN_MASK / SUPPRESS_TOKEN_IDS
    # (presumably the idefics2 vocab size incl. added tokens - TODO confirm).
    VOCAB_SIZE = 32003
    LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS
    # Bool mask of length VOCAB_SIZE that is True exactly at LEGAL_TOKEN_IDS.
    LEGAL_TOKEN_MASK = (
        torch.zeros(VOCAB_SIZE, requires_grad=False)
        .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1)
        .to(device=device(), dtype=torch.bool)
    )
    # Every other token id, sorted so the list order is deterministic.
    SUPPRESS_TOKEN_IDS = sorted(set(range(VOCAB_SIZE)) - set(LEGAL_TOKEN_IDS))

    def __init__(self, image_folder: str, processor: Idefics2Processor) -> None:
        self.t_max_length = 2048  # max_length passed to the processor (truncation=True)
        self.image_folder = Path(image_folder)
        self.image_cache = {}  # image name -> PIL image, filled by get_image
        self.processor = processor
        self.tokenizer: PreTrainedTokenizer = self.processor.tokenizer  # type: ignore

    def get_image(self, im_name: N) -> Image.Image:
        """Return the image ``image_folder/im_name``, opening it on first use and caching it."""
        if im_name not in self.image_cache:
            self.image_cache[im_name] = Image.open(
                self.image_folder.joinpath(im_name))
        return self.image_cache[im_name]

    def unhash(self, context: List[N], c: str) -> N:
        """Return the image in ``context`` labelled by letter ``c``."""
        return AlphabeticNameHash(tuple(context)).unhash(c)

    def valid_hash(self, context: List[N], c: str) -> bool:
        """Return True if letter ``c`` labels some image in ``context``."""
        return AlphabeticNameHash(tuple(context)).valid_hash(c)

    def parse(self, context: List[N], decoded_out: str,
              currently_selected: List[N]) -> List[N]:
        """Turn a decoded model answer into the list of images to click.

        Selections of already-selected images and deselections of images that
        are not currently selected are dropped (logged at DEBUG level).
        Letters that do not label an image in ``context`` are ignored. The
        result is ordered by letter so it is reproducible across processes.

        Args:
            context: image names in context order (defines the letters).
            decoded_out: decoded model output.
            currently_selected: image names selected before this answer.

        Returns:
            Image names whose selection state should be toggled.
        """
        h = AlphabeticNameHash(tuple(context))
        logging.debug(f"{context=}")
        logging.debug(f"{decoded_out=}")
        selection, deselection = self.parse_raw(decoded_out)
        hashed_currently_selected = {h.hash(n) for n in currently_selected}
        # You can only deselect what is selected...
        desel_to_remove = deselection - hashed_currently_selected
        if desel_to_remove:
            logging.debug(f"warn! {desel_to_remove=}")
        deselection = deselection - desel_to_remove
        # ...and only select what is not.
        sel_to_remove = selection & hashed_currently_selected
        if sel_to_remove:
            logging.debug(f"warn! {sel_to_remove=}")
        selection = selection - sel_to_remove
        logging.debug("post strict cleaning")
        logging.debug(f"{selection=}")
        logging.debug(f"{deselection=}")
        model_clicks = selection | deselection
        logging.debug(f"{model_clicks=}")
        # Iterate in sorted order: set-of-str order depends on hash randomization.
        model_clicks_png = [h.unhash(n)
                            for n in sorted(model_clicks) if h.valid_hash(n)]
        logging.debug(f"{model_clicks_png=}")
        return model_clicks_png

    @staticmethod
    def parse_raw(text: str) -> Tuple[Set[str], Set[str]]:
        """Extract the (selections, deselections) letter sets from a raw answer.

        If ``text`` contains ':' on its last line, the answer is what follows
        the first ':' on that line (e.g. "Assistant: select A B"); otherwise
        the whole stripped text is used. An answer that does not match
        RE_PATTERN yields two empty sets.

        This is a staticmethod because it uses no instance state and ``parse``
        invokes it as ``self.parse_raw(...)``.
        """
        last_answer = text.strip()
        if ":" in text:
            # Without re.MULTILINE, `$` only matches at the end of the text (or
            # before a trailing newline), so this matches from the first ':' on
            # the last line. It finds nothing when only earlier lines contain
            # ':' - keep the whole stripped text then instead of indexing [].
            matches = re.findall(r":.*$", text)
            if matches:
                last_answer = matches[0].removeprefix(":").strip()
        if re.search(RE_PATTERN, last_answer) is None:
            print(f"{last_answer=}")
            print("did not pass regex")
            return set(), set()
        # "select X ..." not preceded by "de", running to the end of the answer.
        match = re.search(r"(?<!de)select( [A-J])+$", last_answer)
        selections: Set[str] = set(match.group().split(" ")[1:]) if match else set()
        # "deselect X ..." at the start of the answer.
        match = re.search(r"^deselect( [A-J])+", last_answer)
        deselections: Set[str] = set(match.group().split(" ")[1:]) if match else set()
        return selections, deselections

    def compose(self, context, chats, previous_selected, hash_images, padding):
        """Build processor inputs for the next (unanswered) listener turn.

        ``previous_selected[i]`` is the selection after turn i. ``chats`` must
        have one more entry than ``previous_selected``; the extra entry is the
        current turn.

        Raises:
            AssertionError: if ``chats`` has the wrong length.
        """
        select_accum, deselect_accum, _ = self.unfold_select_deselect(
            previous_selected)
        # The current turn has no answer yet.
        select_accum = select_accum + [[]]
        deselect_accum = deselect_accum + [[]]
        previous_selected = [[]] + previous_selected  # old states pre click
        if not (len(chats) == len(select_accum) == len(deselect_accum)
                == len(previous_selected)):
            raise AssertionError(
                f"expected {len(select_accum)} chats, got {len(chats)}")
        messages, images = self.build_processor_input(
            context, chats, select_accum, deselect_accum, previous_selected,
            hash_images, omit_last_answer=True, sort_names=True,
            omit_context=False, chat_feedback=None)
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True)
        prompt = prompt.strip()
        logging.debug(prompt)
        # Keep consistent with train_script
        inputs = self.processor(
            text=prompt, images=images,
            padding=padding, truncation=True, max_length=self.t_max_length,
            return_tensors="pt")
        return inputs

    def build_processor_input(self, image_pngs: List[N], chats: List[str],
                              select_accum: List[List[N]],
                              deselect_accum: List[List[N]],
                              pre_click_selected_accum: List[List[N]],
                              hash_image: bool, omit_last_answer: bool,
                              sort_names: bool, omit_context: bool,
                              chat_feedback: str, ):
        """Build chat-template messages and the matching images for the processor.

        For each turn the conversation gets a system message with the
        selection state before the turn, a user message with the chat, and an
        assistant message describing the clicks (select/deselect).

        Args:
            image_pngs: image names in context order.
            chats: one speaker utterance per turn.
            select_accum: per-turn newly selected names.
            deselect_accum: per-turn newly deselected names.
            pre_click_selected_accum: per-turn selection before the clicks.
            hash_image: refer to images by letter instead of by name.
            omit_last_answer: drop the final assistant message.
            sort_names: sort names within each message.
            omit_context: skip the leading system message listing the images.
            chat_feedback: if not None, appended as a final user message.

        Returns:
            (messages, images); ``images`` is empty when ``omit_context`` is True.

        Raises:
            AssertionError: if the per-turn lists differ in length.
        """
        def _text_content(text): return {"type": "text", "text": text}
        def _image_content(): return {"type": "image"}
        def _user_prompt(content): return {"role": "user", "content": content}
        def _assistant_prompt(content): return {
            "role": "assistant", "content": content}
        def _system_prompt(content): return {
            "role": "system", "content": content}

        def _current_state(selected: List[N]):
            if len(selected) == 0:
                return 'none is selected'
            return f'{" ".join(selected)} currently selected'

        def _listener_action(select: List[N], deselect: List[N]):
            if len(select) == 0 and len(deselect) == 0:
                return 'nothing'
            if len(select) == 0:
                return f'deselect {" ".join(deselect)}'
            if len(deselect) == 0:
                return f'select {" ".join(select)}'
            return f'deselect {" ".join(deselect)} select {" ".join(select)}'

        # Letters when hashing, otherwise the names unchanged. (Not the builtin
        # `id`, which would replace names with integer object ids.)
        if hash_image:
            func = AlphabeticNameHash(tuple(image_pngs)).hash
        else:
            def func(name):
                return name
        context, select_accum, deselect_accum, pre_click_selected_accum = nested_apply(
            func, (image_pngs, select_accum, deselect_accum, pre_click_selected_accum))
        prompt = []
        images = []
        if not omit_context:
            images = [self.get_image(im) for im in image_pngs]
            images_and_names_content = []
            for im_name in context:
                images_and_names_content.append(_image_content())
                images_and_names_content.append(_text_content(im_name))
            prompt.append(_system_prompt(images_and_names_content))
        if not (len(chats) == len(select_accum) == len(deselect_accum)
                == len(pre_click_selected_accum)):
            logging.error(f"{chats=}")
            logging.error(f"{select_accum=}")
            logging.error(f"{deselect_accum=}")
            logging.error(f"{pre_click_selected_accum=}")
            # Explicit raise (not `assert`) so the check survives `python -O`.
            raise AssertionError(
                "chats, select_accum, deselect_accum and "
                "pre_click_selected_accum must have equal lengths")
        for chat, select, deselect, pre_click_selected in zip(
                chats, select_accum, deselect_accum, pre_click_selected_accum):
            if sort_names:
                select = sorted(select)
                deselect = sorted(deselect)
                pre_click_selected = sorted(pre_click_selected)
            prompt.append(_system_prompt(
                [_text_content(_current_state(pre_click_selected))]))
            prompt.append(_user_prompt([_text_content(chat)]))
            prompt.append(_assistant_prompt(
                [_text_content(_listener_action(select, deselect))]))
        if omit_last_answer:
            # idefics2 has processor.apply_chat_template(messages, add_generation_prompt=True) instead
            prompt.pop(-1)
        if chat_feedback is not None:
            prompt.append(_user_prompt([_text_content(chat_feedback)]))
        return prompt, images

    def unfold_select_deselect(self, previous_selected: List[List[N]]
                               ) -> Tuple[List[List[N]], List[List[N]], List[List[N]]]:
        """Convert per-turn selection states into per-turn clicks.

        Args:
            previous_selected: ``previous_selected[i]`` is the selection AFTER
                turn i (the state before turn 0 is empty).

        Returns:
            (selected, deselected, clicks): per turn, the newly selected
            names, the newly deselected names, and their union, each as a
            ``sorted_list``.
        """
        selected: List[List[str]] = []  # turn-wise selection
        deselected: List[List[str]] = []  # turn-wise deselection
        clicks: List[List[str]] = []  # turn-wise selection | deselection
        prev_selected = set()
        for turn_selected in previous_selected:
            curr_selected = set(turn_selected)
            newly_selected = curr_selected - prev_selected
            newly_deselected = prev_selected - curr_selected
            selected.append(sorted_list(newly_selected))
            deselected.append(sorted_list(newly_deselected))
            clicks.append(sorted_list(newly_selected | newly_deselected))
            prev_selected = curr_selected
        return selected, deselected, clicks