"""Gradio demo: index uploaded documents per session with Voyager, retrieve
the best-matching image for a query, and answer it with Qwen2-VL-2B-Instruct."""

import os
import subprocess
import sys

# Install flash-attn at startup without compiling its CUDA kernels.
# The extra variable is merged into the current environment: passing a bare
# ``env`` dict would replace it entirely (dropping PATH, HOME, VIRTUAL_ENV...).
# Running pip via ``sys.executable`` installs into the interpreter that runs
# this app, and the argument list avoids going through a shell.
# As before, a failed install is not fatal (no ``check=True``).
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)

import uuid

import gradio as gr
import spaces
import torch
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from voyager_index import Voyager

MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize the model and processor once at import time; every session
# shares them.
model = (
    Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID, trust_remote_code=True, torch_dtype=torch.bfloat16
    )
    .to(device)
    .eval()
)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)


def create_index(session_id):
    """Return a new Voyager index named after ``session_id``.

    ``override=True`` is forwarded to Voyager; presumably it replaces an
    existing index of the same name -- confirm against voyager_index docs.
    """
    return Voyager(embedding_size=1536, override=True, index_name=f"{session_id}")


def add_to_index(files, index):
    """Add the uploaded files to ``index`` and return a status message.

    Args:
        files: Value of the multi-file ``gr.File`` component: a list of
            uploaded files (objects or path strings exposing ``.name``), or
            ``None``/empty when nothing was uploaded.
        index: The session's Voyager index.

    Returns:
        A human-readable status string for the "Result" textbox.
    """
    if not files:
        return "No files to add."
    # Fall back to the value itself for plain path strings without ``.name``.
    paths = [getattr(file, "name", file) for file in files]
    index.add_documents(paths, batch_size=1)
    return f"Added {len(files)} files to the index."


@spaces.GPU
def generate_answer(query, retrieved_image, max_new_tokens=200):
    """Answer ``query`` about ``retrieved_image`` with Qwen2-VL.

    Args:
        query: The user's question.
        retrieved_image: Image passed as the ``image`` entry of the chat
            message and resolved by ``process_vision_info``.
        max_new_tokens: Upper bound on generated tokens (default 200).

    Returns:
        A list of decoded strings, one per batch entry (a single entry here).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": retrieved_image},
                {"type": "text", "text": query},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion per row; slice off the prompt
    # tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


def query_index(query, index):
    """Retrieve the best-matching image for ``query`` and answer about it.

    Returns:
        ``(answer_text, retrieved_image)`` for the "Answer" textbox and the
        "Retrieved Image" component; ``(message, None)`` when there is
        nothing to answer.
    """
    if not query or not query.strip():
        return "Please enter a query.", None
    res = index(query, k=1)
    documents = res["documents"]
    # NOTE(review): assumes an empty index yields empty result lists rather
    # than raising -- confirm against voyager_index.
    if not documents or not documents[0]:
        return "No matching document found; add files to the index first.", None
    retrieved_image = documents[0][0]["image"]
    output_text = generate_answer(query, retrieved_image)
    return output_text[0], retrieved_image


# Define the Gradio interface
with gr.Blocks() as demo:
    # One index per browser session. The UUID is generated inside the State
    # factory so each session gets its own name. Reading another State's
    # ``.value`` here is not per-session; it is fixed when the UI is built.
    # That previously gave every session the same index name, and with
    # override=True they clobbered each other.
    index = gr.State(lambda: create_index(str(uuid.uuid4())))

    gr.Markdown("# Full vision pipeline demo")

    with gr.Tab("Add to Index"):
        file_input = gr.File(file_count="multiple", label="Upload Files")
        add_button = gr.Button("Add to Index")
        add_output = gr.Textbox(label="Result")
        add_button.click(add_to_index, inputs=[file_input, index], outputs=add_output)

    with gr.Tab("Query Index"):
        query_input = gr.Textbox(label="Enter your query")
        query_button = gr.Button("Submit Query")
        with gr.Row():
            query_output = gr.Textbox(label="Answer")
            image_output = gr.Image(label="Retrieved Image")
        query_button.click(
            query_index,
            inputs=[query_input, index],
            outputs=[query_output, image_output],
        )

# Launch the interface
demo.launch()