import copy
import dataclasses
import logging
import os
from typing import Any, Dict, List, Optional

import gradio as gr
import PIL.Image as Image
import PIL.ImageOps as ImageOps
import spaces
import torch
from peft import PeftModel
from transformers import AutoProcessor
from transformers import Idefics2ForConditionalGeneration, Idefics2Processor

from adapter import IdeficsAdapter
from config_generator import GameConfig, generate_game_config
from utils import device, nested_to_device, sorted_list

### Constants
IMG_DIR = "tangram_pngs"
# Number of speaker turns a game lasts before it ends unsuccessfully.
MAX_TURNS = 10

### Bot server
# Sampling settings for the listener model's generation step.
GEN_KWS: Dict[str, Any] = {
    "max_new_tokens": 10,
    "do_sample": True,
    "temperature": 1.0,
    "output_logits": True,
    "return_dict_in_generate": True,  # output exposes `.sequences` (used below)
    "remove_invalid_values": True,  # just to be safe
    "renormalize_logits": True,
    "suppress_tokens": IdeficsAdapter.SUPPRESS_TOKEN_IDS,
}


@spaces.GPU(duration=20)
def get_model_response(  # predict
    model: PeftModel,
    adapter_name: str,
    adapter: IdeficsAdapter,
    image_paths: List[str],
    chat: str,
    chats: List[str],
    previous_selected: List[List[str]],
) -> List[str]:
    """Run one listener turn and return the images the model clicks.

    Args:
        model: PEFT-wrapped Idefics2 model holding the LoRA adapters.
        adapter_name: Name of the adapter to activate for this call.
        adapter: Helper that composes model inputs and parses model output.
        image_paths: Listener-side context image filenames.
        chat: The speaker's message for the current turn.
        chats: Speaker messages from earlier turns.
        previous_selected: Listener selection after each earlier turn.

    Returns:
        Image filenames clicked by the model this turn. If nothing could be
        parsed from the generation, falls back to ``[image_paths[0]]`` so the
        listener always acts.
    """
    if model.active_adapter != adapter_name:
        model.set_adapter(adapter_name)
    # NOTE(review): moved on every call, presumably because the GPU is only
    # attached inside the @spaces.GPU-decorated call — confirm.
    model.to(device())

    new_chats = chats + [chat]
    currently_selected = previous_selected[-1] if previous_selected else []
    model_input: Dict[str, Any] = adapter.compose(
        image_paths, new_chats, previous_selected, True, True)
    model_input = nested_to_device(model_input)

    with torch.inference_mode(), torch.autocast(
            device_type=device().type, dtype=torch.bfloat16):
        model_output = model.generate(**model_input, **GEN_KWS)

    decoded_out: str = adapter.tokenizer.decode(
        model_output.sequences[0], skip_special_tokens=True)
    model_clicks = adapter.parse(image_paths, decoded_out, currently_selected)

    if not model_clicks:
        # The game needs an action every turn; pick a deterministic fallback.
        logging.warning("empty clicks by model")
        model_clicks = [image_paths[0]]
        logging.debug(f"{image_paths=}")
        logging.debug(f"selecting {model_clicks}")

    logging.info(f"User input: {chat}")
    logging.info(f"Model selected: {model_clicks}")
    logging.debug(f"Model output: {decoded_out}")
    return model_clicks


def get_model() -> PeftModel:
    """Load Idefics2-8b in bfloat16 with the "r6_bp" and "r0" LoRA adapters.

    Both adapters come from the ``lil-lab/respect`` repo, each at the
    revision of the same name.
    """
    model_id = 'lil-lab/respect'
    checkpoint = "HuggingFaceM4/idefics2-8b"
    model = Idefics2ForConditionalGeneration.from_pretrained(
        checkpoint, torch_dtype=torch.bfloat16)
    peft_model = PeftModel.from_pretrained(
        model, model_id, adapter_name="r6_bp", is_trainable=False,
        revision="r6_bp")

    # Add other adapter - hack to avoid conflict: reuse the first adapter's
    # config but target exactly the modules that already carry LoRA weights.
    # A parameter name like "<module path>.lora_A.r6_bp.weight" is cut back
    # to "<module path>".
    lora_config = copy.deepcopy(peft_model.active_peft_config)
    targets = list({
        name[:name.find('lora') - 1]
        for name, _ in model.named_parameters()
        if 'lora' in name
    })
    lora_config.target_modules = targets
    peft_model.add_adapter("r0", lora_config)
    peft_model.load_adapter(
        model_id, "r0", is_trainable=False, revision="r0",
        peft_config=lora_config)
    return peft_model


def get_processor() -> Idefics2Processor:
    """Return the Idefics2 processor with 224px images and no image splitting."""
    checkpoint = "HuggingFaceM4/idefics2-8b"
    processor = AutoProcessor.from_pretrained(
        checkpoint,
        do_image_splitting=False,
        size={"longest_edge": 224, "shortest_edge": 224})
    return processor


def get_adapter() -> IdeficsAdapter:
    """Build the IdeficsAdapter over IMG_DIR using the shared processor."""
    processor = get_processor()
    return IdeficsAdapter(IMG_DIR, processor)


### Game logic
@dataclasses.dataclass(frozen=False)
class GameState:
    """Snapshot of one game; GameFlow.progress returns a new one per turn."""

    config: GameConfig
    adapter_name: str  # LoRA adapter the listener uses ("r0" or "r6_bp")
    chats: List[str]  # speaker message for each completed turn
    currently_selected: List[str]  # listener's current selection
    selected_accum: List[List[str]]  # selection after each turn
    clicks_accum: List[List[str]]  # model clicks made in each turn
    turn: int = 0

    def has_ended(self) -> bool:
        """True once every target is selected or the turn limit is reached."""
        return self.has_successfully_ended() or self.turn >= MAX_TURNS

    def has_successfully_ended(self) -> bool:
        """True when the selection is exactly the set of targets."""
        return set(self.currently_selected) == set(self.config.targets)

    ### UI helpers
    def serialize_conversation(self) -> str:
        """Render the chat history as one "Turn N: message" line per turn."""
        output = [f"Turn {i + 1}: {message}" for i, message in enumerate(self.chats)]
        return "\n".join(output)

    def markup_images(self) -> List[Image.Image]:
        """Render the speaker's context with selection borders and click markers."""
        context = self.config.speaker_context
        targets = self.config.targets
        selected = self.currently_selected
        # Mark the images the model clicked in the latest turn (both
        # selections and deselections).
        changes = self.clicks_accum[-1] if self.clicks_accum else []
        return self._display_context(context, targets, changes, selected)

    @staticmethod
    def _display_context(context: List[str], targets: List[str],
                         changes: List[str],
                         selected: List[str]) -> List[Image.Image]:
        """Draw each context image as an 88x88 tile with a status border.

        Border colors:
            green: a selected target
            black: an unselected target
            red: a selected non-target
            white: any other image
        Images in ``changes`` get a yellow marker in the top-right corner.
        """
        tangram_list: List[Image.Image] = []
        with Image.open("yellow_circle.png") as marker_src:
            arrow = marker_src.resize((20, 20)).convert("RGBA")
        for img in context:
            with Image.open(os.path.join(IMG_DIR, img)) as src:
                image = src.resize((60, 60)).convert("RGB")
            image = ImageOps.expand(image, border=2, fill="white")
            is_target = img in targets
            is_selected = img in selected
            if is_target and is_selected:
                # listener selected a target image
                border = "green"
            elif is_target:
                # unselected target
                border = "black"
            elif is_selected:
                # listener selected a wrong image
                border = "red"
            else:
                border = "white"
            image = ImageOps.expand(image, border=10, fill=border)
            image = ImageOps.expand(image, border=2, fill="white")
            if img in changes:
                # Tile is 60 + 2*2 + 10*2 + 2*2 = 88 px wide, so x=68 places
                # the 20 px marker flush with the right edge.
                image.paste(arrow, (68, 0), mask=arrow)
            tangram_list.append(image)
        return tangram_list


class GameFlow:
    """Stateless transitions between GameState snapshots."""

    @classmethod
    def initialize(cls, model_iteration: str) -> GameState:
        """Start a new game against the chosen model iteration.

        "Initial System" uses the "r0" adapter; anything else uses "r6_bp".
        """
        config = generate_game_config()
        adapter_name = "r0" if model_iteration == "Initial System" else "r6_bp"
        return GameState(
            config=config,
            adapter_name=adapter_name,
            chats=[],
            currently_selected=[],
            selected_accum=[],
            clicks_accum=[],
            turn=0,
        )

    @classmethod
    def progress(cls, state: GameState, chat: str, model: PeftModel,
                 adapter: IdeficsAdapter) -> GameState:
        """Play one turn with ``chat`` and return the resulting new state.

        ``state`` itself is not modified.
        """
        model_clicks = get_model_response(
            model,
            state.adapter_name,
            adapter,
            state.config.listener_context,
            chat,
            state.chats,
            state.selected_accum,
        )
        # Each click toggles its image (apply deselection, then selection),
        # i.e. the symmetric difference of old selection and clicks.
        currently_selected2 = sorted_list(
            set(state.currently_selected) ^ set(model_clicks))
        return GameState(
            # constants
            config=state.config,
            adapter_name=state.adapter_name,
            # updates (fresh lists, so the previous state stays intact)
            chats=state.chats + [chat],
            currently_selected=currently_selected2,
            selected_accum=state.selected_accum + [currently_selected2],
            clicks_accum=state.clicks_accum + [model_clicks],
            turn=state.turn + 1,
        )


### UI
def create_app_inner():
    """Lay out the game UI and wire callbacks; must run inside gr.Blocks."""
    ### layout
    gr.Markdown("# Tangram Multi-Reference Game")
    gr.Markdown(
        '### You will be playing a multi-reference game against a model. '
        'To start a game, first select whether you wish to play against our '
        'initial trained model ("Initial System") or '
        'our model at the end of continual learning ("Final System") '
        'and press the "Start Game" button.')
    gr.Markdown(
        'You will take on a "speaker" role at each round. '
        'Your goal is to describe the target images (via a message in the textbox) '
        'so that the model can guess what they are. '
        'Targets have black borders. '
        'Correctly selected targets have green borders. '
        'Incorrectly selected images have red borders. '
        'Actions are marked with a yellow dot. '
        'The listener cannot see boxes or colors and the order is different.')
    gr.Markdown(
        '### Press "Send" to submit your action to proceed to the next turn. '
        f'You have {MAX_TURNS} turns in total.')

    with gr.Row():
        model_iteration = gr.Radio(["Initial System", "Final System"],
                                   label="Model Iteration",
                                   value="Final System")
        start_btn = gr.Button("Start Game")
        status = gr.Textbox(label="Status", interactive=False, show_label=False,
                            text_align="center", value="Please start a game.")

    with gr.Row():
        image_output = gr.Gallery(
            label="CONTEXT", show_label=False, elem_id="gallery",
            columns=5, rows=2, object_fit="contain", height="250px",
            allow_preview=False, container=True, interactive=False
        )

    with gr.Row():
        conversation_output = gr.Textbox(label="Interaction History")
        with gr.Column():
            # Disabled until a game is started, so "Send" can never run
            # without a game state.
            user_input = gr.Textbox(label="Your Message as Speaker",
                                    interactive=False)
            send_btn = gr.Button("Send", interactive=False)

    ### globals
    model = get_model()
    adapter = get_adapter()
    game_state = gr.State(value=None)

    ### callbacks
    def output_from_state(state: GameState):
        """Map a state to values for the 7 components in the `outputs` lists."""
        has_ended = state.has_ended()
        success = "Success" if state.has_successfully_ended() else "Failure"
        if has_ended:
            status_text = (f"{success} (Turn {state.turn}/{MAX_TURNS}) "
                           "- Start another game?")
        else:
            status_text = f"Turn {state.turn + 1}/{MAX_TURNS}"
        return (
            state.markup_images(),  # image_output
            state.serialize_conversation(),  # conversation_output
            status_text,  # status
            gr.update(interactive=not has_ended, value=""),  # user_input
            gr.update(interactive=not has_ended),  # send_btn
            gr.update(interactive=has_ended),  # model_iteration
            state,  # game_history
        )

    def on_start_interaction(model_iteration: str):
        if model_iteration not in ("Initial System", "Final System"):
            raise ValueError(f"unknown model iteration: {model_iteration!r}")
        state = GameFlow.initialize(model_iteration)
        return output_from_state(state)

    def on_send_message(message: str, state: Optional[GameState]):
        if state is None:
            # No game started yet: leave the board alone and prompt to start.
            logging.info("Send pressed before a game was started")
            return (
                gr.update(),  # image_output
                gr.update(),  # conversation_output
                "Please start a game.",  # status
                gr.update(interactive=False),  # user_input
                gr.update(interactive=False),  # send_btn
                gr.update(interactive=True),  # model_iteration
                None,  # game_history
            )
        if message.strip() == "":
            logging.info("Empty message")
            return output_from_state(state)
        state = GameFlow.progress(state, message, model, adapter)
        return output_from_state(state)

    start_btn.click(
        on_start_interaction,
        inputs=[model_iteration],
        outputs=[image_output, conversation_output, status,
                 user_input, send_btn, model_iteration, game_state],
        queue=False
    )

    send_btn.click(
        on_send_message,
        inputs=[user_input, game_state],
        outputs=[image_output, conversation_output, status,
                 user_input, send_btn, model_iteration, game_state],
        queue=True
    )


def create_app():
    """Build the themed Gradio Blocks app."""
    with gr.Blocks(theme='saq1b/gradio-theme') as app:
        create_app_inner()
    return app


if __name__ == "__main__":
    app = create_app()
    app.queue()
    app.launch()