import os
import base64
from io import BytesIO

import torch
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from random import randint
import asyncio
from threading import RLock
from externalmod import gr_Interface_load

# Define model IDs
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"
model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"

# Initialize the BLIP captioning processor and model
processor = BlipProcessor.from_pretrained(blip_model_id)
model = BlipForConditionalGeneration.from_pretrained(blip_model_id)

# Load the image-generation model, falling back to a no-op interface on failure
lock = RLock()
model_load = None

def load_fn(model):
    global model_load
    try:
        model_load = gr_Interface_load(f'models/{model}')
    except Exception as error:
        print(error)
        model_load = gr.Interface(lambda: None, ['text'], ['image'])

load_fn(model_name)

async def infer(prompt, timeout):
    # Pad the prompt with a random amount of whitespace so repeated prompts
    # are not served from cache.
    noise = " " * randint(1, 500)
    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except Exception as e:  # includes asyncio.TimeoutError
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        with lock:
            image = Image.open(result).convert('RGBA')
        return image
    return None

def gen_fn(prompt):
    # Run the async helper on a fresh event loop; the loop is created outside
    # the try block so it is always defined when `finally` closes it.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
        result = None
    finally:
        loop.close()
    return result

def add_gallery(image, gallery):
    if gallery is None:
        gallery = []
    with lock:
        if image is not None:
            gallery.insert(0, (image, model_name))
    return gallery

def gen_fn_gallery(prompt, gallery):
    if gallery is None:
        gallery = []
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
        with lock:
            if result:
                gallery.insert(0, result)
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
    finally:
        loop.close()
    yield gallery

def generate_caption(image, min_len=30, max_len=100):
    # Caption an uploaded PIL image with BLIP.
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        caption = processor.decode(out[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        print(e)
        return 'Unable to generate caption.'
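# Quick smoke test (a sketch, not part of the app flow): the generation helper
# above can be exercised from a REPL before wiring it into the UI. The prompt
# string below is illustrative only.
#
#   >>> img = gen_fn("a watercolor landscape at dusk")
#   >>> if img is not None:
#   ...     img.save("preview.png")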
def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    # Build a HuggingFaceEndpoint client for the chat LLM; requires HF_TOKEN in the env.
    try:
        llm = HuggingFaceEndpoint(
            repo_id=model_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            token=os.getenv("HF_TOKEN")
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        llm = None
    return llm

def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if hf is None:
        return "Error with model inference.", chat_history
    # Mistral-instruct style prompt: system message and running conversation go
    # inside [INST] ... [/INST]; the reply is cut after the trailing "AI:".
    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    response = response.split("AI:")[-1]
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history

def chat_function(user_text, uploaded_image, system_message, chat_history):
    # If an image is uploaded, caption it with BLIP instead of querying the LLM
    if uploaded_image:
        caption = generate_caption(uploaded_image)
        # Encode the PIL image as a real base64 data URI; interpolating the PIL
        # object itself into the markdown link would not yield valid image data.
        buffer = BytesIO()
        uploaded_image.save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode()
        chat_history.append({'role': 'user', 'content': f'![uploaded image](data:image/png;base64,{img_b64})'})
        chat_history.append({'role': 'assistant', 'content': caption})
        # First output feeds the Markdown component, second the chat-history state
        return caption, chat_history
    # If no image is uploaded, generate a response from the chat model
    response, updated_history = get_response(system_message, chat_history, user_text)
    return response, updated_history

def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Personal HuggingFace ChatBot")
        with gr.Row():
            with gr.Column():
                txt_input = gr.Textbox(label='Enter your text here', lines=4)
                img_input = gr.Image(label='Upload an image', type='pil')
                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])
                submit_btn = gr.Button('Submit')
                response_output = gr.Markdown()
                gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery",
                                            interactive=False, show_share_button=True, container=True)
        submit_btn.click(chat_function,
                         inputs=[txt_input, img_input, system_message, chat_history],
                         outputs=[response_output, chat_history])
        # Pass the gallery's current value in as an input so add_gallery receives
        # the list itself rather than the component object.
        img_input.change(lambda img, gallery: add_gallery(gen_fn("Generate image of a fantasy scene"), gallery),
                         inputs=[img_input, gallery_output],
                         outputs=[gallery_output])
    demo.launch()

gradio_interface()
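# Smoke test (a sketch): get_response can also be checked without the UI,
# assuming HF_TOKEN is set in the environment for the HuggingFaceEndpoint call:
#
#   >>> history = []
#   >>> reply, history = get_response(
#   ...     "You are a friendly AI conversing with a human user.", history, "Hello!")
#   >>> print(reply)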