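# A Hugging Face Space chatbot that combines three models:
#   - Salesforce BLIP for captioning uploaded images,
#   - Mistral-7B-Instruct (via a Hugging Face inference endpoint) for chat,
#   - an SDXL text-to-image model loaded through externalmod.gr_Interface_load.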
import os
import base64
from io import BytesIO
import torch
from PIL import Image
import gradio as gr
from transformers import BlipProcessor, BlipForConditionalGeneration
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from random import randint
import asyncio
from threading import RLock
from externalmod import gr_Interface_load
# Define model IDs
llm_model_id = "mistralai/Mistral-7B-Instruct-v0.3"
blip_model_id = "Salesforce/blip-image-captioning-large"
model_name = "John6666/wai-ani-hentai-pony-v3-sdxl"
# Initialize BLIP processor and model
processor = BlipProcessor.from_pretrained(blip_model_id)
model = BlipForConditionalGeneration.from_pretrained(blip_model_id)
# Shared lock and handle for the text-to-image model
lock = RLock()
model_load = None
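
# Load the text-to-image model as a callable Gradio interface. On failure,
# fall back to a no-op stub so the rest of the app can still start.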
def load_fn(model):
    global model_load
    try:
        model_load = gr_Interface_load(f'models/{model}')
    except Exception as error:
        print(error)
        model_load = gr.Interface(lambda: None, ['text'], ['image'])

load_fn(model_name)
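
# Run one image generation with a hard timeout. The random whitespace padding
# appears intended to make each prompt unique, so repeated requests are not
# served from an upstream cache.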
async def infer(prompt, timeout):
    noise = " " * randint(1, 500)
    task = asyncio.create_task(asyncio.to_thread(model_load, f'{prompt} {noise}'))
    await asyncio.sleep(0)
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except (Exception, asyncio.TimeoutError) as e:
        print(e)
        print(f"Task timed out: {model_name}")
        if not task.done():
            task.cancel()
        result = None
    if task.done() and result is not None:
        with lock:
            image = Image.open(result).convert('RGBA')
        return image
    return None
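
# Synchronous wrapper around infer() for use as a Gradio callback: runs the
# coroutine to completion on a fresh event loop and returns the image (or None).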
def gen_fn(prompt):
    # Create the loop before the try block so `loop.close()` in `finally`
    # cannot raise NameError if loop creation itself fails.
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
        result = None
    finally:
        loop.close()
    return result
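
# Prepend a newly generated image (paired with the model name as its caption)
# to the gallery, guarding shared state with the module-level lock.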
def add_gallery(image, gallery):
    if gallery is None:
        gallery = []
    with lock:
        if image is not None:
            gallery.insert(0, (image, model_name))
    return gallery
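
# Generator variant of gen_fn() that yields the updated gallery, so Gradio can
# stream the result into a gr.Gallery output.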
def gen_fn_gallery(prompt, gallery):
    if gallery is None:
        gallery = []
    loop = asyncio.new_event_loop()
    try:
        result = loop.run_until_complete(infer(prompt, timeout=300))
        with lock:
            if result:
                gallery.insert(0, result)
    except (Exception, asyncio.CancelledError) as e:
        print(e)
        print(f"Task aborted: {model_name}")
    finally:
        loop.close()
    yield gallery
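
# Caption an image with BLIP; min_len/max_len bound the generated token count.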
def generate_caption(image, min_len=30, max_len=100):
    try:
        inputs = processor(image, return_tensors="pt")
        out = model.generate(**inputs, min_length=min_len, max_length=max_len)
        caption = processor.decode(out[0], skip_special_tokens=True)
        return caption
    except Exception as e:
        print(f"Captioning failed: {e}")
        return 'Unable to generate caption.'
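
# Build a LangChain HuggingFaceEndpoint client for the chat model. Reads the
# API token from the HF_TOKEN environment variable; returns None on failure.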
def get_llm_hf_inference(model_id=llm_model_id, max_new_tokens=128, temperature=0.1):
    try:
        llm = HuggingFaceEndpoint(
            repo_id=model_id,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            token=os.getenv("HF_TOKEN"),
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        llm = None
    return llm
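
# Query the LLM through a prompt | llm | parser chain using Mistral's [INST]
# format, then append the exchange to the chat history.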
def get_response(system_message, chat_history, user_text, max_new_tokens=256):
    hf = get_llm_hf_inference(max_new_tokens=max_new_tokens, temperature=0.1)
    if hf is None:
        return "Error with model inference.", chat_history
    prompt = PromptTemplate.from_template(
        "[INST] {system_message}\nCurrent Conversation:\n{chat_history}\n\nUser: {user_text}.\n [/INST]\nAI:"
    )
    chat = prompt | hf.bind(skip_prompt=True) | StrOutputParser(output_key='content')
    response = chat.invoke(input=dict(system_message=system_message, user_text=user_text, chat_history=chat_history))
    response = response.split("AI:")[-1]
    chat_history.append({'role': 'user', 'content': user_text})
    chat_history.append({'role': 'assistant', 'content': response})
    return response, chat_history
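
# Route a turn: uploaded images are captioned with BLIP; plain text goes to
# the Mistral chat model.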
def chat_function(user_text, uploaded_image, system_message, chat_history):
    # If an image is uploaded, generate a caption for it
    if uploaded_image is not None:
        caption = generate_caption(uploaded_image)
        # Encode the PIL image as base64 so the markdown data URI actually
        # renders; interpolating the Image object directly would not.
        buffer = BytesIO()
        uploaded_image.save(buffer, format='PNG')
        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        chat_history.append({'role': 'user', 'content': f'![uploaded image](data:image/png;base64,{encoded})'})
        chat_history.append({'role': 'assistant', 'content': caption})
        # Return the caption (for the Markdown output) and the updated history
        return caption, chat_history
    # If no image is uploaded, generate a response from the chat model
    response, updated_history = get_response(system_message, chat_history, user_text)
    return response, updated_history
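
# Build the Blocks UI and wire the events: the Submit button drives the chat,
# and uploading an image triggers a fixed-prompt image generation into the gallery.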
def gradio_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Personal HuggingFace ChatBot")
        with gr.Row():
            with gr.Column():
                txt_input = gr.Textbox(label='Enter your text here', lines=4)
                img_input = gr.Image(label='Upload an image', type='pil')
                system_message = gr.Textbox(label='System Message', value="You are a friendly AI conversing with a human user.")
                chat_history = gr.State(value=[{'role': 'assistant', 'content': 'Hello, there! How can I help you today?'}])
                submit_btn = gr.Button('Submit')
                response_output = gr.Markdown()
                gallery_output = gr.Gallery(label="Generated Images", show_download_button=True, elem_classes="gallery", interactive=False, show_share_button=True, container=True)
        submit_btn.click(chat_function, inputs=[txt_input, img_input, system_message, chat_history], outputs=[response_output, chat_history])
        # Pass the gallery's current value through `inputs` rather than closing
        # over the component object, which is not a list and cannot be updated.
        img_input.change(
            lambda img, gallery: add_gallery(gen_fn("Generate image of a fantasy scene"), gallery),
            inputs=[img_input, gallery_output],
            outputs=[gallery_output],
        )
    demo.launch()

gradio_interface()