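# app.py for the "Multimodal Phi 3 3.8B" Hugging Face Space: a Gradio chat demo
# for the llava-phi3-3.8b-lora model (a LLaVA vision adapter on Phi-3-mini).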
import time
from threading import Thread
import subprocess

# Install the local llava package before importing from it (a common Hugging Face Spaces pattern).
subprocess.run(["pip", "install", "."])

import gradio as gr
from io import BytesIO
import requests
from PIL import Image
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer
import spaces

device = "cuda:0"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="./checkpoints/llava-phi3-3.8b-lora",
    model_name="llava-phi3-3.8b-lora",
    model_base="microsoft/Phi-3-mini-128k-instruct",
    load_8bit=False,
    load_4bit=False,
    device=device,
)
model.to(device)
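
# Fetch an image from a URL or open a local file, converting it to RGB.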
def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
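
# Chat handler: builds a single-turn Phi-3 prompt containing the image token
# and streams the model's reply back to the Gradio interface.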
def bot_streaming(message, history):
    if message["files"]:
        # message["files"][-1] is either a dict with a "path" key or a plain path string
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    image_data = load_image(str(image))
    image_tensor = image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().to(device)
    # single turn only: always prepend the image token to the user message
    prompt = f"<|user|>\n{DEFAULT_IMAGE_TOKEN}\n{message['text']}<|end|>\n<|assistant|>"
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True, timeout=20.0)
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        top_p=1.0,
        max_new_tokens=1024,
        streamer=streamer,
        use_cache=True))
    thread.start()

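    # Accumulate streamed tokens and yield the growing reply so Gradio renders it incrementally.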
buffer = "" | |
time.sleep(0.5) | |
for new_text in streamer: | |
# find <|end|> and remove it from the new_text | |
if "<|end|>" in new_text: | |
new_text = new_text.split("<|end|>")[0] | |
buffer += new_text | |
# generated_text_without_prompt = buffer[len(text_prompt):] | |
generated_text_without_prompt = buffer | |
# print(generated_text_without_prompt) | |
time.sleep(0.06) | |
# print(f"new_text: {generated_text_without_prompt}") | |
yield generated_text_without_prompt | |
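# Build the Gradio interface: a multimodal textbox feeding the streaming chat handler.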
chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="Multimodal Phi 3 3.8B",
        examples=[{"text": "What is the color of the cat?", "files": ["./FELV-cat.jpg"]},
                  {"text": "What is the type of the fish?", "files": ["./fish.jpg"]}],
        description="Try [Multimodal Phi3-3.8B LoRA](https://huggingface.co/praysimanjuntak/llava-phi3-3.8b-lora). Upload an image and start chatting about it, or simply try one of the examples below. If you don't upload an image, you will receive an error.",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)