import os
import torch
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from random import randint
import asyncio
import base64
import io
from threading import RLock
from externalmod import gr_Interface_load
# Define model IDs
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"
model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"
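# How the three models are used below: the Mistral endpoint generates chat
# replies, BLIP captions uploaded images, and the SDXL checkpoint produces
# the gallery images.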
# Initialize BLIP processor and model
processor = BlipProcessor.from_pretrained(blip_model_id)
model = BlipForConditionalGeneration.from_pretrained(blip_model_id)
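# Note: BLIP is the only model instantiated locally; the LLM is reached via
# HuggingFaceEndpoint and the SDXL model (presumably) via the remote
# inference API behind externalmod's gr_Interface_load wrapper.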
# Initialize the model loading function
lock = RLock()
model_load = None
def load_fn(model):
    global model_load
    try:
        model_load = gr_Interface_load(f'models/{model}')
    except Exception as error:
        print(error)
        # Fall back to a no-op interface so the rest of the app can still start
        model_load = gr.Interface(lambda: None, ['text'], ['image'])

load_fn(model_name)
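# Design note (inferred from the code below): infer() offloads the blocking
# model call to a worker thread, enforces a timeout, and pads the prompt with
# a random run of spaces, presumably to defeat response caching on the remote
# endpoint so repeated identical prompts still yield fresh images.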
async def infer(prompt, timeout):
    # Append a random run of 1-500 spaces to the prompt (see note above)
    noise = " " * randint(1, 500)
    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except (Exception, asyncio.TimeoutError) as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        # result is expected to be a path to the generated image file
        with lock:
            image = Image.open(result).convert('RGBA')
        return image
    return None
def gen_fn(prompt):
    # Create the loop before the try block so it is always defined
    # when the finally clause runs
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
        result = None
    finally:
        loop.close()
    return result
def add_gallery(image, gallery):
    if gallery is None:
        gallery = []
    with lock:
        if image is not None:
            gallery.insert(0, (image, model_name))
    return gallery
def gen_fn_gallery(prompt, gallery):
    if gallery is None:
        gallery = []
    # As in gen_fn, create the loop before the try block so it is always
    # defined in the finally clause
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
        with lock:
            if result:
                gallery.insert(0, result)
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
    finally:
        loop.close()
    yield gallery
def generate_caption(image, min_len=30, max_len=100):
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        caption = processor.decode(out[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        print(f"Captioning failed: {e}")
        return 'Unable to generate caption.'
def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    try:
        llm = HuggingFaceEndpoint(
            repo_id=model_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            token=os.getenv("HF_TOKEN")
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        llm = None
    return llm
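# Assumption: HF_TOKEN is expected in the Space's secrets; without it,
# HuggingFaceEndpoint may fail to authenticate against gated models.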
def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if hf is None:
        return "Error with model inference.", chat_history
    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    # Keep only the text after the final "AI:" marker
    response = response.split("AI:")[-1]
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history
def chat_function(user_text, uploaded_image, system_message, chat_history):
    # If an image is uploaded, caption it and record both in the history
    if uploaded_image is not None:
        caption = generate_caption(uploaded_image)
        # Encode the PIL image as base64 PNG so the Markdown data URI actually renders
        buffer = io.BytesIO()
        uploaded_image.save(buffer, format='PNG')
        img_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        chat_history.append({'role': 'user', 'content': f'![uploaded image](data:image/png;base64,{img_b64})'})
        chat_history.append({'role': 'assistant', 'content': caption})
        # Return the caption as the visible response plus the updated history
        return caption, chat_history
    # If no image is uploaded, generate a response from the chat model
    response, updated_history = get_response(system_message, chat_history, user_text)
    return response, updated_history
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Personal HuggingFace ChatBot")
        with gr.Row():
            with gr.Column():
                txt_input = gr.Textbox(label='Enter your text here', lines=4)
                img_input = gr.Image(label='Upload an image', type='pil')
                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])
                submit_btn = gr.Button('Submit')
                response_output = gr.Markdown()
                gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery", interactive=False, show_share_button=True, container=True)
        submit_btn.click(chat_function, inputs=[txt_input, img_input, system_message, chat_history], outputs=[response_output, chat_history])
        # List the gallery as an input so its current value (not the component object) reaches add_gallery
        img_input.change(
            lambda img, gallery: add_gallery(gen_fn("Generate image of a fantasy scene"), gallery),
            inputs=[img_input, gallery_output],
            outputs=[gallery_output],
        )
    demo.launch()

gradio_interface()