import time
from threading import Thread
import subprocess
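
# Install the repo's bundled llava package at startup so the llava imports below resolve.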
subprocess.run(["pip", "install", "."], check=True)
import gradio as gr
from io import BytesIO
import requests
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
import spaces
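
# Device and model setup: load the LLaVA-Phi-3 LoRA checkpoint on top of the
# Phi-3-mini-128k-instruct base model.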
device = "cuda:0"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-phi3-3.8b-lora",
    model_name="llava-phi3-3.8b-lora",
    model_base="microsoft/Phi-3-mini-128k-instruct",
    load_8bit=False,
    load_4bit=False,
    device=device,
)
model.to(device)
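
# Fetch an image from a URL or open a local file, normalized to RGB.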
def load_image(image_file):
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
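
# Streaming chat handler: generation runs in a background thread while this
# generator yields the accumulating text to the Gradio UI.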
@spaces.GPU
def bot_streaming(message, history):
    if message["files"]:
        # message["files"][-1] is either a dict or a plain path string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        raise gr.Error("You need to upload an image for LLaVA to work.")
# print(f"prompt: {prompt}")
image_data = load_image(str(image))
image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)
# just one turn, always prepend image token
prompt = f"<|user|>{chr(10)}{DEFAULT_IMAGE_TOKEN + chr(10) + message['text']}<|end|>{chr(10)}<|assistant|>"
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=20.0)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        top_p=1.0,
        max_new_tokens=1024,
        streamer=streamer,
        use_cache=True))
    thread.start()

    buffer = ""
    time.sleep(0.5)  # give generation a moment to start before polling the streamer
    for new_text in streamer:
        # Drop the <|end|> stop token and anything after it.
        if "<|end|>" in new_text:
            new_text = new_text.split("<|end|>")[0]
        buffer += new_text
        time.sleep(0.06)  # throttle UI updates
        yield buffer
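
# Build the Gradio UI: a multimodal chat interface backed by bot_streaming.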
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="Multimodal Phi 3 3.8B",
        examples=[
            {"text": "What is the color of the cat?", "files": ["./FELV-cat.jpg"]},
            {"text": "What is the type of the fish?", "files": ["./fish.jpg"]},
        ],
        description="Try [Multimodal Phi3-3.8B LoRA](https://huggingface.co/praysimanjuntak/llava-phi3-3.8b-lora). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )
demo.queue(api_open=False)
demo.launch(show_api=False, share=False)