"""Gradio demo for YALLA, a multimodal chatbot built on Idefics2.

At import time the module installs a few packages with pip, loads the
Idefics2 model and processor, builds a FAISS vector store over a slice of the
``m-ric/huggingface_doc`` dataset, wires a retriever tool into a
``ReactJsonAgent`` and finally launches the Gradio interface.
"""

import copy
import io
import json
import os
import subprocess
import sys
import time
import urllib
import urllib.request
from threading import Thread
from typing import Dict, List, Optional, Tuple, Union


def _pip_install(*pip_args, extra_env=None):
    """Run ``pip install`` for *pip_args* with the current interpreter.

    The process environment is extended (not replaced) with *extra_env*.
    The return code is not checked, so a failed install does not abort start-up.
    """
    env = dict(os.environ)
    if extra_env:
        env.update(extra_env)
    subprocess.run([sys.executable, "-m", "pip", "install", *pip_args], env=env)


# Runtime installs must happen before the third-party imports below, otherwise
# those imports fail on a fresh environment.
# Install flash attention
_pip_install(
    "flash-attn",
    "--no-build-isolation",
    extra_env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
# Install transformers (with the agents extra) from source
_pip_install(
    "git+https://github.com/huggingface/transformers.git#egg=transformers[agents]",
)
# Install RAG dependencies
_pip_install("langchain", "sentence-transformers", "faiss-cpu")

import datasets  # noqa: E402
import gradio as gr  # noqa: E402
import spaces  # noqa: E402
import torch  # noqa: E402
from langchain.docstore.document import Document  # noqa: E402
from langchain.text_splitter import RecursiveCharacterTextSplitter  # noqa: E402
from langchain.vectorstores import FAISS  # noqa: E402
from langchain_community.embeddings import HuggingFaceEmbeddings  # noqa: E402
from langchain_core.vectorstores import VectorStore  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import (  # noqa: E402
    AutoProcessor,
    Idefics2ForConditionalGeneration,
    TextIteratorStreamer,
)
from transformers.agents import HfEngine, ReactJsonAgent, Tool  # noqa: E402

DEVICE = torch.device("cuda")

# Marker stripped from the end of the streamed text before it is shown.
END_OF_UTTERANCE_TOKEN = "<end_of_utterance>"

MODELS = {
    "idefics2-8b-chatty": Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b-chatty",
        # Alternative fine-tuned checkpoint: "Ali-C137/idefics2-8b-chatty-yalla"
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    ).to(DEVICE),
}
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    # Alternative fine-tuned checkpoint: "Ali-C137/idefics2-8b-chatty-yalla"
)

# Load the custom dataset
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")

# Process the documents; the second path component of "source" is kept as
# the filterable source name.
source_docs = [
    Document(
        page_content=doc["text"],
        metadata={"source": doc["source"].split("/")[1]},
    )
    for doc in knowledge_base
]
docs_processed = RecursiveCharacterTextSplitter(chunk_size=500).split_documents(
    source_docs
)[:1000]

# Create embeddings and vector store
embedding_model = HuggingFaceEmbeddings(model_name="thenlper/gte-small")
vectordb = FAISS.from_documents(documents=docs_processed, embedding=embedding_model)


class RetrieverTool(Tool):
    """Agent tool that runs a similarity search against a vector store."""

    name = "retriever"
    description = "Retrieves documents from the knowledge base that have the closest embeddings to the input query."
    inputs = {
        "query": {
            "type": "text",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        },
        "source": {
            "type": "text",
            "description": "",
        },
    }
    output_type = "text"

    def __init__(self, vectordb: VectorStore, all_sources: List[str], **kwargs):
        """Store the vector store and advertise the available sources.

        Args:
            vectordb: Vector store queried by :meth:`forward`.
            all_sources: Source names interpolated into the description of
                the ``source`` input.
            **kwargs: Forwarded to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.vectordb = vectordb
        # Work on a per-instance copy so the class-level template is not
        # mutated and shared between instances.
        self.inputs = copy.deepcopy(self.inputs)
        self.inputs["source"]["description"] = (
            f"The source of the documents to search, as a str representation of a list. Possible values in the list are: {all_sources}. If this argument is not provided, all sources will be searched."
        )

    def forward(self, query: str, source: Optional[str] = None) -> str:
        """Return the 3 closest documents to *query*, optionally filtered.

        Args:
            query: Search query; must be a string.
            source: Either a single source name, or the string
                representation of a list of source names. Falsy means no
                filtering.

        Returns:
            The retrieved page contents joined into one string, or a hint
            message when nothing matched.
        """
        assert isinstance(query, str), "Your search query must be a string"

        if source and isinstance(source, str):
            if "[" in source:
                # String representation of a list; tolerate single quotes.
                source = json.loads(source.replace("'", '"'))
            else:
                # A bare source name stands for a one-element list.
                source = [source]

        docs = self.vectordb.similarity_search(
            query,
            filter=({"source": source} if source else None),
            k=3,
        )

        if not docs:
            return "No documents found with this filtering. Try removing the source filter."
        return "Retrieved documents:\n\n" + "\n===Document===\n".join(
            doc.page_content for doc in docs
        )


# Initialize the LLM engine and the agent with the retriever tool
llm_engine = HfEngine("meta-llama/Meta-Llama-3-8B-Instruct")
all_sources = list({doc.metadata["source"] for doc in docs_processed})
retriever_tool = RetrieverTool(vectordb, all_sources)
agent = ReactJsonAgent(tools=[retriever_tool], llm_engine=llm_engine)

# Should change this section for the finetuned model
SYSTEM_PROMPT = [
    {
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "You are YALLA, a personalized AI chatbot assistant designed to enhance the user's experience in Morocco. Your mission is to provide accurate, real-time, and culturally rich information to make their visit enjoyable and stress-free. You can handle text and image inputs, offering recommendations on transport, event schedules, dining, accommodations, and cultural experiences. You can also perform real-time web searches and use various APIs to assist users effectively. Always be respectful, polite, and inclusive, and strive to offer truthful and helpful responses.",
            },
        ],
    },
    {
        "role": "assistant",
        "content": [
            {
                "type": "text",
                "text": "Hello, I'm YALLA, your personalized AI assistant for exploring Morocco. How can I assist you today?",
            },
        ],
    },
]

examples_path = os.path.dirname(__file__)
EXAMPLES = [
    [
        {
            "text": "For 2024, the interest expense is twice what it was in 2014, and the long-term debt is 10% higher than its 2015 level. Can you calculate the combined total of the interest and long-term debt for 2024?",
            "files": [f"{examples_path}/example_images/mmmu_example_2.png"],
        }
    ],
    [
        {
            "text": "What's in the image?",
            "files": [f"{examples_path}/example_images/plant_bulb.webp"],
        }
    ],
    [
        {
            "text": "Describe the image",
            "files": [f"{examples_path}/example_images/baguettes_guarding_paris.png"],
        }
    ],
    [
        {
            "text": "Read what's written on the paper",
            "files": [f"{examples_path}/example_images/paper_with_text.png"],
        }
    ],
]

# Previous avatar: "IDEFICS_logo.png"
BOT_AVATAR = "YALLA_logo.png"


# Chatbot utils
def turn_is_pure_media(turn):
    """Return True when the history turn has no assistant part (``turn[1] is None``)."""
    return turn[1] is None


def load_image_from_url(url):
    """Download *url* and open the payload as a PIL image."""
    with urllib.request.urlopen(url) as response:
        image_data = response.read()
    return Image.open(io.BytesIO(image_data))


def img_to_bytes(image_path):
    """Return the image at *image_path* re-encoded as RGB JPEG bytes."""
    # The context manager closes the file handle even if conversion fails.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert(mode="RGB")
    buffer = io.BytesIO()
    try:
        rgb_image.save(buffer, format="JPEG")
    finally:
        rgb_image.close()
    return buffer.getvalue()


def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> Tuple[List[Dict[str, Union[List, str]]], List[Image.Image]]:
    """Build the message list and image list that go into the processor.

    It handles the potential image(s), the history and the system
    conditioning.

    Args:
        user_prompt: Mapping with a ``"text"`` entry and a ``"files"`` entry
            (paths of the images attached to the current message).
        chat_history: Sequence of ``(user, assistant)`` turns. A turn whose
            second element is ``None`` is treated as an image upload whose
            path is ``turn[0][0]``.

    Returns:
        ``(messages, images)``: the chat messages (system prompt, history,
        current input) and the images in the order their ``{"type": "image"}``
        placeholders appear.
    """
    resulting_messages = copy.deepcopy(SYSTEM_PROMPT)
    resulting_images = []

    # Collect images referenced by user messages of the system prompt, if any.
    for resulting_message in resulting_messages:
        if resulting_message["role"] != "user":
            continue
        for content in resulting_message["content"]:
            if content["type"] == "image":
                resulting_images.append(load_image_from_url(content["image"]))

    # Format history
    for turn in chat_history:
        # Open a new user message unless one is already being filled
        # (consecutive media turns accumulate in the same user message).
        if not resulting_messages or resulting_messages[-1]["role"] != "user":
            resulting_messages.append({"role": "user", "content": []})

        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            # The assistant message carries the assistant's own reply, not a
            # copy of the user's text.
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": assistant_utterance.strip()}
                    ],
                }
            )

    # Format current input
    if not user_prompt["files"]:
        resulting_messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "text": user_prompt["text"]}],
            }
        )
    else:
        # Choosing to put the image first (i.e. before the text), but this is
        # an arbitrary choice.
        image_placeholders = [{"type": "image"} for _ in user_prompt["files"]]
        resulting_messages.append(
            {
                "role": "user",
                "content": image_placeholders
                + [{"type": "text", "text": user_prompt["text"]}],
            }
        )
        resulting_images.extend(Image.open(path) for path in user_prompt["files"])

    return resulting_messages, resulting_images


def extract_images_from_msg_list(msg_list):
    """Return every PIL image found in the ``"content"`` of the given messages."""
    return [
        item
        for msg in msg_list
        for item in msg["content"]
        if isinstance(item, Image.Image)
    ]


@spaces.GPU(duration=60, queue=False)
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    """Stream the model's answer to *user_prompt* as a generator of strings.

    Args:
        user_prompt: Mapping with ``"text"`` and ``"files"`` entries.
        chat_history: Previous turns, passed to the prompt formatter.
        model_selector: Key into ``MODELS``.
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature (only used for Top P Sampling).
        max_new_tokens: Upper bound on generated tokens.
        repetition_penalty: Repetition penalty passed to ``generate``.
        top_p: Nucleus sampling threshold (only used for Top P Sampling).

    Yields:
        The accumulated generated text after each streamed chunk.

    Raises:
        gr.Error: If the text query is empty.
    """
    # The error must be raised; merely constructing it does nothing.
    if user_prompt["text"].strip() == "" and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")

    if user_prompt["text"].strip() == "" and user_prompt["files"]:
        raise gr.Error("Please input a text query along the image(s).")

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

    # Common parameters to all decoding strategies
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Creating model inputs
    resulting_text, resulting_images = (
        format_user_prompt_with_im_history_and_system_conditioning(
            user_prompt=user_prompt,
            chat_history=chat_history,
        )
    )
    prompt = PROCESSOR.apply_chat_template(resulting_text, add_generation_prompt=True)
    inputs = PROCESSOR(
        text=prompt,
        images=resulting_images if resulting_images else None,
        return_tensors="pt",
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generation_args.update(inputs)

    # Use the agent to perform RAG
    # NOTE(review): the agent output is only logged; it is not fed into the
    # generation below. Confirm whether it should be added to the prompt.
    agent_output = agent.run(user_prompt["text"])
    print("Agent output:", agent_output)

    # Stream the generated text; generation runs in a background thread while
    # this generator consumes the streamer.
    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)
        acc_text += text_token
        # Only trim when the end-of-utterance marker is actually present;
        # testing against "" would match always and eat real output.
        if acc_text.endswith(END_OF_UTTERANCE_TOKEN):
            acc_text = acc_text[: -len(END_OF_UTTERANCE_TOKEN)]
        yield acc_text
    print("Success - generated the following text:", acc_text)
    print("-----")


FEATURES = datasets.Features(
    {
        "model_selector": datasets.Value("string"),
        "images": datasets.Sequence(datasets.Image(decode=True)),
        "conversation": datasets.Sequence(
            {"User": datasets.Value("string"), "Assistant": datasets.Value("string")}
        ),
        "decoding_strategy": datasets.Value("string"),
        "temperature": datasets.Value("float32"),
        "max_new_tokens": datasets.Value("int32"),
        "repetition_penalty": datasets.Value("float32"),
        # top_p is a fraction (slider range 0.01-0.99), so it must be a float.
        "top_p": datasets.Value("float32"),
    }
)


# Hyper-parameters for generation
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.1,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Higher values is equivalent to sampling more low-probability tokens.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    visible=False,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    visible=False,
    interactive=True,
    label="Top P",
    info="Higher values is equivalent to sampling more low-probability tokens.",
)


chatbot = gr.Chatbot(
    label="YALLA-Chatty",
    avatar_images=[None, BOT_AVATAR],
    height=450,
)

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown("# 🇲🇦 YALLA ")
    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=list(MODELS),
            value=list(MODELS.keys())[0],
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=False,
        )

    # Show the sampling controls only for strategies that use them.
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection
                in [
                    "contrastive_sampling",
                    "beam_sampling",
                    "Top P Sampling",
                    "sampling_top_k",
                ]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=EXAMPLES,
        # The handler and EXAMPLES use the {"text", "files"} message format,
        # which is what the multimodal textbox produces.
        multimodal=True,
        cache_examples=False,
        additional_inputs=[
            model_selector,
            decoding_strategy,
            temperature,
            max_new_tokens,
            repetition_penalty,
            top_p,
        ],
    )

demo.launch()