import logging import re from functools import cache from pathlib import Path from typing import List, Set, Tuple, TypeVar import torch from PIL import Image from transformers import Idefics2Processor, PreTrainedTokenizer from utils import device, nested_apply, sorted_list RE_PATTERN = r'^(deselect\s[A-Z](?:\s[A-Z])*(?:\sselect\s[A-Z](?:\s[A-Z])*)?|select\s[A-Z](?:\s[A-Z])*)$' # noqa # Name type, newtype of str. e.g. "page4-249.png" N = TypeVar('N') ALPHABET = 'ABCDEFGHIJ' # we only have 10 images LEGAL_TOKEN_IDS = [2, 315, 330, 334, 365, 382, 384, 401, 413, 420, 475, 5339, 634, 17960, 32002] # A - J and and <\s> and 'select' and 'deselect' MINI_DECODER = { 384: 'D', # 2: '', 32002: '', 420: 'G', 17960: 'elect', 330: 'A', 365: 'B', 334: 'C', 5339: 'select', 401: 'F', 475: 'J', 634: 'des', 315: 'I', 413: 'E', 382: 'H'} class AlphabeticNameHash: @cache def __init__(self, context: List[N]) -> None: self._forward_map = {im: ALPHABET[i] for i, im in enumerate(context)} self._backward_map = {ALPHABET[i]: im for i, im in enumerate(context)} def hash(self, im: N) -> str: return self._forward_map[im] def unhash(self, i: str) -> N: return self._backward_map[i] def valid_hash(self, i: str) -> bool: return i in self._backward_map class IdeficsAdapter: PAD_TOKEN_ID = 0 LABEL_MASK_ID = 32001 # idefics2: image_token_id LEGAL_TOKEN_IDS = LEGAL_TOKEN_IDS LEGAL_TOKEN_MASK = torch.zeros(32003, requires_grad=False)\ .index_fill_(0, torch.tensor(LEGAL_TOKEN_IDS), 1).to(device=device(), dtype=torch.bool) SUPPRESS_TOKEN_IDS = list(set(range(32003)) - set(LEGAL_TOKEN_IDS)) def __init__(self, image_folder: str, processor: Idefics2Processor) -> None: self.t_max_length = 2048 self.image_folder = Path(image_folder) self.image_cache = {} self.processor = processor self.tokenizer: PreTrainedTokenizer = self.processor.tokenizer # type: ignore def get_image(self, im_name: N) -> Image.Image: if im_name not in self.image_cache: self.image_cache[im_name] = Image.open( self.image_folder.joinpath(im_name)) return self.image_cache[im_name] def unhash(self, context: List[N], c: str) -> N: return AlphabeticNameHash(tuple(context)).unhash(c) def valid_hash(self, context: List[N], c: str) -> bool: return AlphabeticNameHash(tuple(context)).valid_hash(c) def parse(self, context: List[N], decoded_out: str, currently_selected: List[N]) -> List[str]: h = AlphabeticNameHash(tuple(context)) logging.debug(f"{context=}") # do inference logging.debug(f"{decoded_out=}") selection, deselection = self.parse_raw(decoded_out) hashed_currently_selected = {h.hash(n) for n in currently_selected} desel_to_remove = deselection - hashed_currently_selected if len(desel_to_remove) > 0: logging.debug(f"warn! {desel_to_remove=}") deselection = deselection - desel_to_remove sel_to_remove = selection & hashed_currently_selected if len(sel_to_remove) > 0: logging.debug(f"warn! {sel_to_remove=}") selection = selection - sel_to_remove logging.debug("post strict cleaning") logging.debug(f"{selection=}") logging.debug(f"{deselection=}") model_clicks = selection | deselection logging.debug(f"{model_clicks=}") model_clicks_png = [h.unhash(n) for n in model_clicks if h.valid_hash(n)] logging.debug(f"{model_clicks_png=}") return model_clicks_png @staticmethod def parse_raw(text: str) -> Tuple[Set[N], Set[N]]: last_answer = text.strip() if ":" in text: last_answer_pattern = r":.*$" xs = re.findall(last_answer_pattern, text) last_answer = xs[0].removeprefix(":").strip() xs = re.search(RE_PATTERN, last_answer) if xs is None: print(f"{last_answer=}") print("did not pass regex") return set(), set() select_pattern = r"(? Tuple[List[N], List[N], List[N]]: # currently selected AFTER i-th turn num_turns = len(previous_selected) selected: List[List[str]] = [] # turn-wise selection deselected: List[List[str]] = [] # turn-wise deselection clicks: List[List[str]] = [] # combining turn-wise newly selected and newly deselected prev_selected = set() for turn in range(num_turns): curr_selected = set(previous_selected[turn]) newly_selected = curr_selected - prev_selected newly_deselected = prev_selected - curr_selected selected.append(sorted_list(newly_selected)) deselected.append(sorted_list(newly_deselected)) clicks.append(sorted_list(newly_selected | newly_deselected)) prev_selected = curr_selected.copy() return selected, deselected, clicks