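"""Gradio Space combining three models: BLIP image captioning, a Mistral
chat model served through the HF Inference API, and an SDXL image
generator loaded via the local externalmod helper."""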
import os
import base64
from io import BytesIO

import torch
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from random import randint
import asyncio
from threading import RLock
from externalmod import gr_Interface_load
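# Note: externalmod is a local helper module, expected to sit next to this
# app.py in the Space, that wraps loading of hosted models as Gradio
# interfaces.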
# Define model IDs
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"
model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"

# Initialize BLIP processor and model
processor = BlipProcessor.from_pretrained(blip_model_id)
model = BlipForConditionalGeneration.from_pretrained(blip_model_id)
# Initialize the model loading function
lock = RLock()
model_load = None

def load_fn(model):
    global model_load
    try:
        model_load = gr_Interface_load(f'models/{model}')
    except Exception as error:
        print(error)
        # Fall back to a no-op interface so the rest of the app still works
        model_load = gr.Interface(lambda: None, ['text'], ['image'])

load_fn(model_name)
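# infer() runs the text-to-image call in a worker thread and races it
# against a timeout. The random run of trailing spaces makes each prompt
# string unique, which (presumably) bypasses result caching on the
# inference endpoint so repeated prompts still produce fresh images.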
async def infer(prompt, timeout):
    # Random-length whitespace suffix used as a cache-buster
    noise = " " * randint(1, 500)
    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done(): task.cancel()
        result = None
    except Exception as e:
        print(e)
        if not task.done(): task.cancel()
        result = None
    if task.done() and result is not None:
        with lock:
            image = Image.open(result).convert('RGBA')
        return image
    return None
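# Gradio event handlers run in worker threads with no running event loop,
# so each synchronous wrapper below creates and closes its own loop
# around the async infer() call.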
def gen_fn(prompt):
    # Create the loop before the try block so the finally clause cannot
    # hit an unbound variable if loop creation itself fails
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
        result = None
    finally:
        loop.close()
    return result
def add_gallery(image, gallery):
    if gallery is None: gallery = []
    with lock:
        if image is not None: gallery.insert(0, (image, model_name))
    return gallery
# Generator variant of gen_fn (not wired into the UI below); yields the
# updated gallery once.
def gen_fn_gallery(prompt, gallery):
    if gallery is None: gallery = []
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
        with lock:
            if result: gallery.insert(0, result)
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
    finally:
        loop.close()
    yield gallery
def generate_caption(image, min_len=30, max_len=100):
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        caption = processor.decode(out[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        print(f"Captioning failed: {e}")
        return 'Unable to generate caption.'
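# Builds a LangChain client for the hosted Mistral endpoint. HF_TOKEN is
# read from the environment; on Spaces it would typically be configured
# as a repository secret.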
def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    try:
        llm = HuggingFaceEndpoint(
            repo_id=model_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            # HuggingFaceEndpoint expects the token under this parameter name
            huggingfacehub_api_token=os.getenv("HF_TOKEN")
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        llm = None
    return llm
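# Mistral-Instruct models expect prompts wrapped in [INST] ... [/INST]
# tags; the trailing "AI:" marker lets us strip any echoed prompt text
# from the completion below.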
def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if hf is None:
        return "Error with model inference.", chat_history
    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    response = response.split("AI:")[-1]
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history
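# Routes each submission: an uploaded image is captioned with BLIP, plain
# text goes to the LLM chain. Both branches return (markdown text,
# updated history) to match the click handler's outputs.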
def chat_function(user_text, uploaded_image, system_message, chat_history):
    # If an image is uploaded, generate a caption for it
    if uploaded_image is not None:
        caption = generate_caption(uploaded_image)
        # Encode the PIL image as base64 PNG so the data URI in the chat
        # history is valid (the raw image object is not base64 text)
        buffer = BytesIO()
        uploaded_image.save(buffer, format="PNG")
        img_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        chat_history.append({'role': 'user', 'content': f'![uploaded image](data:image/png;base64,{img_b64})'})
        chat_history.append({'role': 'assistant', 'content': caption})
        # Return the caption and the updated chat history
        return caption, chat_history
    # If no image is uploaded, generate a response from the chat model
    response, updated_history = get_response(system_message, chat_history, user_text)
    return response, updated_history
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Personal HuggingFace ChatBot")
        with gr.Row():
            with gr.Column():
                txt_input = gr.Textbox(label='Enter your text here', lines=4)
                img_input = gr.Image(label='Upload an image', type='pil')
                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])
                submit_btn = gr.Button('Submit')
                response_output = gr.Markdown()
                gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery", interactive=False, show_share_button=True, container=True)
        submit_btn.click(chat_function, inputs=[txt_input, img_input, system_message, chat_history], outputs=[response_output, chat_history])
        # Pass the gallery's current value through inputs; handing the
        # component object itself to add_gallery would not update its state
        img_input.change(
            lambda img, gallery: add_gallery(gen_fn("Generate image of a fantasy scene"), gallery),
            inputs=[img_input, gallery_output],
            outputs=[gallery_output],
        )
    demo.launch()

gradio_interface()