import gradio as gr
import torch
import os
import json
from groq import Groq
import spaces
from transformers import AutoModel, AutoTokenizer
from diffusers import StableDiffusion3Pipeline
from parler_tts import ParlerTTSForConditionalGeneration
import soundfile as sf
from langchain_groq import ChatGroq
from PIL import Image
from tavily import TavilyClient
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.chains import RetrievalQA

# Initialize models and clients
MODEL = 'llama-3.1-70b-versatile'
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

vqa_model = AutoModel.from_pretrained(
    'openbmb/MiniCPM-V-2',
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2', trust_remote_code=True)

tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-large-v1")
tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-large-v1")

# Image generation model (Stable Diffusion 3)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# Tavily client for web search
tavily_client = TavilyClient(api_key=os.environ.get("TAVILY_API"))


# Function to play voice output
def play_voice_output(response):
    print("Executing play_voice_output function")
    description = ("Jon's voice is monotone yet slightly fast in delivery, with a very close "
                   "recording that almost has no background noise.")
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids.to('cuda')
    prompt_input_ids = tts_tokenizer(response, return_tensors="pt").input_ids.to('cuda')
    generation = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio_arr = generation.cpu().numpy().squeeze()
    sf.write("output.wav", audio_arr, tts_model.config.sampling_rate)
    return "output.wav"


# Function to classify user input using an LLM
def classify_function(user_prompt):
    prompt = f"""
You are a function classifier AI assistant. You are given a user input and you need to classify it into one of the following functions:

- `image_generation`: If the user wants to generate an image.
- `image_vqa`: If the user wants to ask questions about an image.
- `document_qa`: If the user wants to ask questions about a document.
- `text_to_text`: If the user wants a text-based response.

Respond with a JSON object containing only the chosen function.
For example:
```json
{{"function": "image_generation"}}
```

User input: {user_prompt}
"""
    chat_completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama3-8b-8192",
    )
    content = chat_completion.choices[0].message.content
    # The prompt asks for fenced JSON, so strip any ```json ... ``` markers before parsing.
    content = content.strip().removeprefix("```json").removeprefix("```").removesuffix("```").strip()
    try:
        response = json.loads(content)
        return response.get("function")
    except json.JSONDecodeError:
        print(f"Error decoding JSON: {chat_completion.choices[0].message.content}")
        return "text_to_text"  # Default to text-to-text if JSON parsing fails
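
# Illustrative only (not executed): with GROQ_API_KEY set, the classifier is expected to map
# prompts onto one of the four route names, e.g.
#   classify_function("Draw a castle at sunset")      -> "image_generation"
#   classify_function("What does this chart show?")   -> "image_vqa"
#   classify_function("Summarize the attached file")  -> "document_qa"
#   classify_function("Explain transformers briefly") -> "text_to_text"
# The exact label depends on the LLM's reply, so callers treat "text_to_text" as the fallback.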

# Document Question Answering tool
class DocumentQuestionAnswering:
    def __init__(self, document):
        self.document = document
        self.qa_chain = self._setup_qa_chain()

    def _setup_qa_chain(self):
        print("Setting up DocumentQuestionAnswering tool")
        # Load the document, split it into chunks, and index the chunks in a FAISS vector store.
        loader = TextLoader(self.document)
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings()
        db = FAISS.from_documents(texts, embeddings)
        retriever = db.as_retriever()
        qa_chain = RetrievalQA.from_chain_type(
            llm=ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY")),
            chain_type="stuff",
            retriever=retriever,
        )
        return qa_chain

    def run(self, query: str) -> str:
        print("Executing DocumentQuestionAnswering tool")
        response = self.qa_chain.run(query)
        return str(response)
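
# Minimal usage sketch (not executed here): "notes.txt" is a hypothetical plain-text file, and
# this assumes GROQ_API_KEY is set so ChatGroq can be constructed.
#   doc_qa = DocumentQuestionAnswering("notes.txt")
#   print(doc_qa.run("What are the key action items?"))
# In the app itself, handle_input() builds this object from the uploaded document instead.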

# Function to handle different input types and choose the right pipeline
def handle_input(user_prompt, image=None, audio=None, websearch=False, document=None):
    print(f"Handling input: {user_prompt}")

    # Initialize the LLM
    llm = ChatGroq(model=MODEL, api_key=os.environ.get("GROQ_API_KEY"))

    # Handle voice-only mode
    if audio:
        print("Processing audio input")
        # gr.Audio(type="filepath") passes a path, so read the file before sending it to Whisper.
        with open(audio, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                file=(os.path.basename(audio), audio_file.read()),
                model="whisper-large-v3",
            )
        user_prompt = transcription.text
        response = llm.invoke(user_prompt).content
        audio_output = play_voice_output(response)
        return "Response generated.", audio_output

    # Handle websearch mode
    if websearch:
        print("Executing Web Search")
        answer = tavily_client.qna_search(query=user_prompt)
        return answer, None

    # Handle cases with only an image or document input
    if user_prompt is None or user_prompt.strip() == "":
        if image:
            user_prompt = "Describe this image"
        elif document:
            user_prompt = "Summarize this document"

    # Classify user input using the LLM
    function = classify_function(user_prompt)

    # Handle the different functions
    if function == "image_generation":
        print("Executing Image Generation")
        generated_image = pipe(
            user_prompt,
            negative_prompt="",
            num_inference_steps=15,
            guidance_scale=7.0,
        ).images[0]
        generated_image.save("output.jpg")
        return "output.jpg", None

    elif function == "image_vqa":
        print("Executing Image VQA")
        if image:
            # MiniCPM-V-2 takes a PIL image directly; no manual tensor preprocessing is needed.
            pil_image = Image.open(image).convert('RGB')
            messages = [{"role": "user", "content": user_prompt}]
            response, ctxt = vqa_model.chat(
                image=pil_image,
                msgs=messages,
                tokenizer=tokenizer,
                context=None,
                temperature=0.5,
            )
            return response, None
        else:
            return "Please upload an image.", None

    elif function == "document_qa":
        print("Executing Document QA")
        if document:
            document_qa = DocumentQuestionAnswering(document)
            response = document_qa.run(user_prompt)
            return response, None
        else:
            return "Please upload a document.", None

    else:  # function == "text_to_text"
        print("Executing Text-to-Text")
        response = llm.invoke(user_prompt).content
        return response, None


# Main interface function
@spaces.GPU(duration=120)
def main_interface(user_prompt, image=None, audio=None, voice_only=False, websearch=False, document=None):
    print("Starting main_interface function")
    vqa_model.to(device='cuda', dtype=torch.bfloat16)
    tts_model.to("cuda")
    pipe.to("cuda")
    print(f"user_prompt: {user_prompt}, image: {image}, audio: {audio}, "
          f"voice_only: {voice_only}, websearch: {websearch}, document: {document}")

    try:
        response = handle_input(user_prompt, image=image, audio=audio, websearch=websearch, document=document)
        print("handle_input function executed successfully")
    except Exception as e:
        print(f"Error in handle_input: {e}")
        response = "Error occurred during processing.", None  # Match the (label, audio) output pair
    return response
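
# Rough call sketch (not executed): Gradio invokes main_interface with the six UI inputs and
# expects a (label_text, audio_path_or_None) pair back, e.g.
#   main_interface("Explain retrieval-augmented generation")
#   # -> ("...model answer...", None)
#   main_interface("", image="chart.png")   # "chart.png" is a hypothetical upload path
#   # -> ("...image description...", None)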

def create_ui():
    with gr.Blocks(css="""
        /* Overall Styling */
        body { font-family: 'Poppins', sans-serif; background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); margin: 0; padding: 0; color: #333; }

        /* Title Styling */
        .gradio-container h1 { text-align: center; padding: 20px 0; background: linear-gradient(45deg, #007bff, #00c6ff); color: white; font-size: 2.5em; font-weight: bold; letter-spacing: 1px; text-transform: uppercase; margin: 0; box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.2); }

        /* Input Area Styling */
        .gradio-container .gr-row { display: flex; justify-content: space-around; align-items: center; padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1); margin-bottom: 20px; }
        .gradio-container .gr-column { flex: 1; margin: 0 10px; }

        /* Textbox Styling */
        .gradio-container textarea { width: calc(100% - 20px); padding: 15px; border: 2px solid #007bff; border-radius: 8px; font-size: 1.1em; transition: border-color 0.3s, box-shadow 0.3s; }
        .gradio-container textarea:focus { border-color: #00c6ff; box-shadow: 0px 0px 8px rgba(0, 198, 255, 0.5); outline: none; }

        /* Button Styling */
        .gradio-container button { background: linear-gradient(45deg, #007bff, #00c6ff); color: white; padding: 15px 25px; border: none; border-radius: 8px; cursor: pointer; font-size: 1.2em; font-weight: bold; transition: background 0.3s, transform 0.3s; box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1); }
        .gradio-container button:hover { background: linear-gradient(45deg, #0056b3, #009bff); transform: translateY(-3px); }
        .gradio-container button:active { transform: translateY(0); }

        /* Output Area Styling */
        .gradio-container .output-area { padding: 20px; text-align: center; background-color: #f7f9fc; border-radius: 10px; box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.1); margin-top: 20px; }

        /* Image Styling */
        .gradio-container img { max-width: 100%; height: auto; border-radius: 10px; box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1); transition: transform 0.3s, box-shadow 0.3s; }
        .gradio-container img:hover { transform: scale(1.05); box-shadow: 0px 6px 12px rgba(0, 0, 0, 0.2); }

        /* Checkbox Styling */
        .gradio-container input[type="checkbox"] { width: 20px; height: 20px; cursor: pointer; accent-color: #007bff; transition: transform 0.3s; }
        .gradio-container input[type="checkbox"]:checked { transform: scale(1.2); }

        /* Audio and Document Upload Styling */
        .gradio-container .gr-file-upload input[type="file"] { width: 100%; padding: 10px; border: 2px solid #007bff; border-radius: 8px; cursor: pointer; background-color: white; transition: border-color 0.3s, background-color 0.3s; }
        .gradio-container .gr-file-upload input[type="file"]:hover { border-color: #00c6ff; background-color: #f0f8ff; }

        /* Advanced Tooltip Styling */
        .gradio-container .gr-tooltip { position: relative; display: inline-block; cursor: pointer; }
        .gradio-container .gr-tooltip .tooltiptext { visibility: hidden; width: 200px; background-color: black; color: #fff; text-align: center; border-radius: 6px; padding: 5px; position: absolute; z-index: 1; bottom: 125%; left: 50%; margin-left: -100px; opacity: 0; transition: opacity 0.3s; }
        .gradio-container .gr-tooltip:hover .tooltiptext { visibility: visible; opacity: 1; }

        /* Footer Styling */
        .gradio-container footer { text-align: center; padding: 10px; background: #007bff; color: white; font-size: 0.9em; border-radius: 0 0 10px 10px; box-shadow: 0px -2px 8px rgba(0, 0, 0, 0.1); }
    """) as demo:
        gr.Markdown("# AI Assistant")
        with gr.Row():
            with gr.Column(scale=2):
                user_prompt = gr.Textbox(placeholder="Type your message here...", lines=1)
            with gr.Column(scale=1):
                image_input = gr.Image(type="filepath", label="Upload an image", elem_id="image-icon")
                audio_input = gr.Audio(type="filepath", label="Upload audio", elem_id="mic-icon")
                document_input = gr.File(type="filepath", label="Upload a document", elem_id="document-icon")
                voice_only_mode = gr.Checkbox(label="Enable Voice Only Mode", elem_id="voice-only-mode")
                websearch_mode = gr.Checkbox(label="Enable Web Search", elem_id="websearch-mode")
            with gr.Column(scale=1):
                submit = gr.Button("Submit")

        output_label = gr.Label(label="Output")
        audio_output = gr.Audio(label="Audio Output", visible=False)

        submit.click(
            fn=main_interface,
            inputs=[user_prompt, image_input, audio_input, voice_only_mode, websearch_mode, document_input],
            outputs=[output_label, audio_output],
        )

        # Each output component needs its own update, so return one update per component.
        voice_only_mode.change(
            lambda x: [gr.update(visible=not x)] * 5,
            inputs=voice_only_mode,
            outputs=[user_prompt, image_input, websearch_mode, document_input, submit],
        )
        voice_only_mode.change(
            lambda x: gr.update(visible=x),
            inputs=voice_only_mode,
            outputs=[audio_input],
        )

    return demo


# Launch the UI
demo = create_ui()
demo.launch()
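
# Notes (illustrative, based on the configuration above): the app expects GROQ_API_KEY and
# TAVILY_API to be set in the environment before launch, e.g.
#   export GROQ_API_KEY=...
#   export TAVILY_API=...
#   python app.py        # assumes this file is saved as app.py
# The @spaces.GPU(duration=120) decorator requests a GPU for each call on Hugging Face Spaces,
# which is why main_interface re-asserts the CUDA placement of the models per request.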