# Krypton / app.py
# Author: sandz7 — commit ced6c96 ("Update app.py", verified)
import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from threading import Thread
import spaces
import accelerate
import time
# HTML header shown above the chat UI (rendered through gr.Markdown).
# The title emoji is U+1F54B (KAABA); it was previously mis-encoded as "πŸ•‹".
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton 🕋</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''
# Hugging Face Hub checkpoint used for both the model weights and the processor.
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Load the LLaVA checkpoint in float16 and place it on the GPU
# (nn.Module.to returns the module itself, so the call can be chained).
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).to('cuda')

# The processor turns (prompt, image) pairs into model inputs; it is also used
# as the tokenizer for streaming decoding.
processor = AutoProcessor.from_pretrained(model_id)

# Unconditionally override the end-of-sequence id used by generate().
# NOTE(review): 128009 is presumably Llama-3's <|eot_id|> token (the marker the
# chat prompt and stream filtering use) — confirm against the tokenizer vocab.
model.generation_config.eos_token_id = 128009
def _find_image_path(message, history):
    """Return the image to use for this turn, or None if none was ever uploaded.

    The last file attached to the current message wins. Otherwise the most recent
    history entry whose first element is a tuple is used, taking the path stored
    at index 0 of that tuple.
    """
    files = message["files"]
    if files:
        latest = files[-1]
        # The file entry may be a dict with a "path" key or a plain path string.
        return latest["path"] if isinstance(latest, dict) else latest

    image_path = None
    for turn in history:
        if isinstance(turn[0], tuple):
            image_path = turn[0][0]  # keep scanning: the last match is the newest
    return image_path


@spaces.GPU(duration=120)
def krypton(message, history):
    """Stream the model's answer about an image for one ChatInterface turn.

    Args:
        message: Multimodal textbox payload; this reads ``message["text"]`` (the
            question) and ``message["files"]`` (a list of path strings or dicts
            with a ``"path"`` key).
        history: Earlier turns; entries whose first element is a tuple are treated
            as image uploads, with the file path at index 0 of that tuple.

    Yields:
        The accumulated response text so far (the whole buffer, not a delta),
        with ``<|eot_id|>`` and anything after it in a chunk removed.

    Raises:
        gr.Error: if no image was attached to this turn or any earlier one.
    """
    print(message)

    image_path = _find_image_path(message, history)
    if image_path is None:
        # gr.Error is an exception: it must be *raised* for Gradio to show it.
        # (Previously it was only constructed, and execution crashed below with
        # an UnboundLocalError.)
        raise gr.Error("Please upload an image so Krypton can understand.")

    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{message['text']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    image = Image.open(image_path)
    inputs = processor(prompt, image, return_tensors='pt').to(0, torch.float16)

    # Special tokens are kept in the decoded stream, so <|eot_id|> is stripped
    # manually below; skip_prompt drops the echoed input.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=False)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    # generate() runs in a background thread and feeds decoded text into the
    # streamer; iterating the streamer below blocks until new text arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Keep only the text before the end-of-turn marker.
        buffer += new_text.split("<|eot_id|>")[0]
        # Throttle UI updates. NOTE(review): up to 1024 chunks * 0.06 s adds about
        # 60 s inside the 120 s GPU window — confirm this pacing is wanted.
        time.sleep(0.06)
        yield buffer

    # The streamer is exhausted once generate() finishes; reap the worker thread.
    thread.join()
# Chat widgets are created up front (outside the Blocks context) and handed to
# ChatInterface below, which places them in its own layout.
chatbot = gr.Chatbot(label="Krypt AI", height=600)
chat_input = gr.MultimodalTextbox(
    placeholder="Enter your question or upload an image.",
    file_types=["image"],
    interactive=True,
    show_label=False,
)

# Page layout: the HTML description header followed by the multimodal chat,
# which streams replies from krypton().
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=krypton,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        fill_height=True,
    )

# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()