|
|
|
import gradio as gr |
|
import os |
|
from threading import Thread |
|
|
|
|
|
import cv2 |
|
|
|
import datetime |
|
|
|
import torch |
|
|
|
import spaces |
|
import numpy as np |
|
|
|
from llava import conversation as conversation_lib |
|
|
|
|
|
|
from llava.constants import ( |
|
IMAGE_TOKEN_INDEX, |
|
DEFAULT_IMAGE_TOKEN, |
|
DEFAULT_IM_START_TOKEN, |
|
DEFAULT_IM_END_TOKEN, |
|
) |
|
from llava.conversation import conv_templates, SeparatorStyle |
|
from llava.model.builder import load_pretrained_model |
|
from llava.utils import disable_torch_init |
|
from llava.mm_utils import ( |
|
tokenizer_image_token, |
|
process_images, |
|
get_model_name_from_path, |
|
KeywordsStoppingCriteria, |
|
) |
|
|
|
from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown |
|
|
|
import requests |
|
from PIL import Image |
|
from io import BytesIO |
|
from transformers import TextStreamer, TextIteratorStreamer |
|
|
|
import hashlib |
|
import PIL |
|
import base64 |
|
import json |
|
|
|
|
import gradio_client |
|
import subprocess |
|
import sys |
|
|
|
from huggingface_hub import HfApi |
|
from huggingface_hub import login |
|
from huggingface_hub import revision_exists |
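
# Authenticate with the Hub using a write-capable token (HF_TOKEN) so that
# conversation logs and images can be uploaded to the dataset repo named in LOG_REPO.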
|
|
|
login(token=os.environ["HF_TOKEN"], |
|
write_permission=True) |
|
|
|
api = HfApi() |
|
repo_name = os.environ["LOG_REPO"] |
|
|
|
external_log_dir = "./logs" |
|
LOGDIR = external_log_dir |
|
|
|
|
|
def install_gradio_4_35_0(): |
|
current_version = gr.__version__ |
|
if current_version != "4.35.0": |
|
print(f"Current Gradio version: {current_version}") |
|
print("Installing Gradio 4.35.0...") |
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"]) |
|
print("Gradio 4.35.0 installed successfully.") |
|
else: |
|
print("Gradio 4.35.0 is already installed.") |
|
|
|
|
|
install_gradio_4_35_0() |
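
# Note: gradio was already imported before the reinstall above, so the re-import
# below is a no-op for an already-loaded module; the printed version reflects
# whatever was loaded when this process started.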
|
|
|
import gradio as gr |
|
import gradio_client |
|
print(f"Gradio version: {gr.__version__}") |
|
print(f"Gradio-client version: {gradio_client.__version__}") |
|
|
|
def get_conv_log_filename(): |
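    """Return the path of today's conversation log file under LOGDIR."""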
|
t = datetime.datetime.now() |
|
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json") |
|
return name |
|
|
|
class InferenceDemo(object): |
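    """Bundles the loaded model components with per-session conversation state."""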
|
def __init__( |
|
self, args, model_path, tokenizer, model, image_processor, context_len |
|
) -> None: |
|
disable_torch_init() |
|
|
|
self.tokenizer, self.model, self.image_processor, self.context_len = ( |
|
tokenizer, |
|
model, |
|
image_processor, |
|
context_len, |
|
) |
|
|
|
if "llama-2" in model_name.lower(): |
|
conv_mode = "llava_llama_2" |
|
elif "v1" in model_name.lower() or "pulse" in model_name.lower(): |
|
conv_mode = "llava_v1" |
|
elif "mpt" in model_name.lower(): |
|
conv_mode = "mpt" |
|
elif "qwen" in model_name.lower(): |
|
conv_mode = "qwen_1_5" |
|
else: |
|
conv_mode = "llava_v0" |
|
|
|
if args.conv_mode is not None and conv_mode != args.conv_mode: |
|
print( |
|
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( |
|
conv_mode, args.conv_mode, args.conv_mode |
|
) |
|
) |
|
else: |
|
args.conv_mode = conv_mode |
|
        # Track the conversation mode actually used (which may have been overridden
        # by --conv-mode) so clear_history resets to the same template.
        self.conv_mode = args.conv_mode
|
self.conversation = conv_templates[args.conv_mode].copy() |
|
self.num_frames = args.num_frames |
|
|
|
|
|
def is_valid_video_filename(name): |
|
video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"] |
|
|
|
ext = name.split(".")[-1].lower() |
|
|
|
    return ext in video_extensions
|
|
|
def is_valid_image_filename(name): |
|
image_extensions = ["jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp", "heic", "heif", "jfif", "svg", "eps", "raw"] |
|
|
|
ext = name.split(".")[-1].lower() |
|
|
|
    return ext in image_extensions
|
|
|
|
|
def sample_frames(video_file, num_frames):
    """Uniformly sample up to `num_frames` frames from a video as RGB PIL images."""
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against short videos so the sampling interval is never zero.
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        # Only convert frames that were actually read.
        if not ret:
            continue
        if i % interval == 0:
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    video.release()
    return frames
|
|
|
|
|
def load_image(image_file):
    """Load an image from a URL or a local path and return it as RGB."""
    if image_file.startswith("http://") or image_file.startswith("https://"):
        response = requests.get(image_file)
        if response.status_code != 200:
            raise ValueError(f"Failed to load image from {image_file}")
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        print("Load image from local file:", image_file)
        image = Image.open(image_file).convert("RGB")
    return image
|
|
|
|
|
def clear_history(history):
    # Reset the conversation template for a fresh session; our_chatbot is None
    # until the first message has been sent.
    if our_chatbot is not None:
        our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
    return None
|
|
|
|
|
def clear_response(history): |
|
for index_conv in range(1, len(history)): |
|
|
|
conv = history[-index_conv] |
|
if not (conv[0] is None): |
|
break |
|
question = history[-index_conv][0] |
|
history = history[:-index_conv] |
|
return history, question |
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_message(history, message): |
|
|
|
global our_chatbot |
|
if len(history) == 0: |
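        # First message of a session: build the inference wrapper from the
        # globally loaded model components.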
|
our_chatbot = InferenceDemo( |
|
args, model_path, tokenizer, model, image_processor, context_len |
|
) |
|
|
|
for x in message["files"]: |
|
history.append(((x,), None)) |
|
if message["text"] is not None: |
|
history.append((message["text"], None)) |
|
return history, gr.MultimodalTextbox(value=None, interactive=False) |
|
|
|
|
|
@spaces.GPU |
|
def bot(history, temperature, top_p, max_output_tokens): |
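    """Run one assistant turn: build the prompt, stream generated tokens, and log the result."""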
|
print("### turn start history",history) |
|
print("### turn start conv",our_chatbot.conversation) |
|
text = history[-1][0] |
|
images_this_term = [] |
|
text_this_term = "" |
|
|
|
num_new_images = 0 |
|
    for i, message in enumerate(history[:-1]):
        if isinstance(message[0], tuple):
            images_this_term.append(message[0][0])
            if is_valid_video_filename(message[0][0]):
                # Video input is not supported by this demo.
                raise ValueError("Video is not supported")
            elif is_valid_image_filename(message[0][0]):
                print("#### Load image from local file", message[0][0])
                num_new_images += 1
            else:
                raise ValueError("Invalid image file")
        else:
            num_new_images = 0
|
|
|
|
|
|
|
|
|
assert len(images_this_term) > 0, "must have an image" |
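
    # Save each uploaded image under LOGDIR/serve_images, keyed by its MD5 hash,
    # so that the conversation log written below can reference it.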
|
|
|
|
|
|
|
all_image_hash = [] |
|
all_image_path = [] |
|
for image_path in images_this_term: |
|
with open(image_path, "rb") as image_file: |
|
image_data = image_file.read() |
|
image_hash = hashlib.md5(image_data).hexdigest() |
|
all_image_hash.append(image_hash) |
|
image = PIL.Image.open(image_path).convert("RGB") |
|
t = datetime.datetime.now() |
|
filename = os.path.join( |
|
LOGDIR, |
|
"serve_images", |
|
f"{t.year}-{t.month:02d}-{t.day:02d}", |
|
f"{image_hash}.jpg", |
|
) |
|
all_image_path.append(filename) |
|
if not os.path.isfile(filename): |
|
os.makedirs(os.path.dirname(filename), exist_ok=True) |
|
print("image save to",filename) |
|
image.save(filename) |
|
|
|
image_list = [] |
|
for f in images_this_term: |
|
if is_valid_video_filename(f): |
|
image_list += sample_frames(f, our_chatbot.num_frames) |
|
elif is_valid_image_filename(f): |
|
image_list.append(load_image(f)) |
|
else: |
|
raise ValueError("Invalid image file") |
|
|
|
    # Preprocess each image and move it to the model's device and dtype so the
    # vision tower does not see a float32/float16 mismatch.
    image_tensor = [
        process_images([f], our_chatbot.image_processor, our_chatbot.model.config)[0]
        .to(our_chatbot.model.device, dtype=our_chatbot.model.dtype)
        for f in image_list
    ]
|
|
|
|
|
image_tensor = torch.stack(image_tensor) |
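    # Prepend one <image> placeholder token per newly added image to the user text.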
|
image_token = DEFAULT_IMAGE_TOKEN * num_new_images |
|
|
|
|
|
|
|
inp = text |
|
inp = image_token + "\n" + inp |
|
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp) |
|
|
|
our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None) |
|
prompt = our_chatbot.conversation.get_prompt() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_ids = tokenizer_image_token( |
|
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt" |
|
).unsqueeze(0).to(our_chatbot.model.device) |
|
|
|
stop_str = ( |
|
our_chatbot.conversation.sep |
|
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO |
|
else our_chatbot.conversation.sep2 |
|
) |
|
keywords = [stop_str] |
|
stopping_criteria = KeywordsStoppingCriteria( |
|
keywords, our_chatbot.tokenizer, input_ids |
|
) |
|
|
|
|
|
|
|
streamer = TextIteratorStreamer( |
|
our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True |
|
) |
|
print(our_chatbot.model.device) |
|
print(input_ids.device) |
|
print(image_tensor.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        images=image_tensor,
        # Fall back to greedy decoding when temperature is 0; transformers rejects
        # do_sample=True with a non-positive temperature.
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_output_tokens,
        use_cache=False,
        stopping_criteria=[stopping_criteria],
    )
|
|
|
    # Generation runs on a background thread; TextIteratorStreamer then yields
    # decoded tokens as they are produced so the chatbot can update incrementally.
    t = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
    t.start()
|
|
|
outputs = [] |
|
for stream_token in streamer: |
|
outputs.append(stream_token) |
|
|
|
|
|
history[-1] = [text, "".join(outputs)] |
|
yield history |
|
our_chatbot.conversation.messages[-1][-1] = "".join(outputs) |
|
print("### turn end history", history) |
|
print("### turn end conv",our_chatbot.conversation) |
|
|
|
with open(get_conv_log_filename(), "a") as fout: |
|
data = { |
|
"type": "chat", |
|
"model": "PULSE-7b", |
|
"state": history, |
|
"images": all_image_hash, |
|
"images_path": all_image_path |
|
} |
|
print("#### conv log",data) |
|
fout.write(json.dumps(data) + "\n") |
|
for upload_img in all_image_path: |
|
api.upload_file( |
|
path_or_fileobj=upload_img, |
|
path_in_repo=upload_img.replace("./logs/", ""), |
|
repo_id=repo_name, |
|
repo_type="dataset", |
|
|
|
|
|
) |
|
|
|
api.upload_file( |
|
path_or_fileobj=get_conv_log_filename(), |
|
path_in_repo=get_conv_log_filename().replace("./logs/", ""), |
|
repo_id=repo_name, |
|
repo_type="dataset") |
|
|
|
|
|
|
|
txt = gr.Textbox( |
|
scale=4, |
|
show_label=False, |
|
placeholder="Enter text and press enter.", |
|
container=False, |
|
) |
|
|
|
with gr.Blocks( |
|
css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}", |
|
) as demo: |
|
|
|
cur_dir = os.path.dirname(os.path.abspath(__file__)) |
|
|
|
gr.HTML(html_header) |
|
|
|
with gr.Column(): |
|
with gr.Accordion("Parameters", open=False) as parameter_row: |
|
temperature = gr.Slider( |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=0.0, |
|
step=0.1, |
|
interactive=True, |
|
label="Temperature", |
|
) |
|
top_p = gr.Slider( |
|
minimum=0.0, |
|
maximum=1.0, |
|
value=1, |
|
step=0.1, |
|
interactive=True, |
|
label="Top P", |
|
) |
|
max_output_tokens = gr.Slider( |
|
minimum=0, |
|
maximum=8192, |
|
value=4096, |
|
step=256, |
|
interactive=True, |
|
label="Max output tokens", |
|
) |
|
with gr.Row(): |
|
chatbot = gr.Chatbot([], elem_id="PULSE", bubble_full_width=False, height=750) |
|
|
|
with gr.Row(): |
|
upvote_btn = gr.Button(value="π Upvote", interactive=True) |
|
downvote_btn = gr.Button(value="π Downvote", interactive=True) |
|
flag_btn = gr.Button(value="β οΈ Flag", interactive=True) |
|
|
|
regenerate_btn = gr.Button(value="π Regenerate", interactive=True) |
|
clear_btn = gr.Button(value="ποΈ Clear history", interactive=True) |
|
|
|
|
|
chat_input = gr.MultimodalTextbox( |
|
interactive=True, |
|
file_types=["image"], |
|
placeholder="Enter message or upload file...", |
|
show_label=False, |
|
submit_btn="π" |
|
) |
|
|
|
print(cur_dir) |
|
gr.Examples( |
|
examples_per_page=5, |
|
examples=[ |
|
[ |
|
{ |
|
"files": [ |
|
f"{cur_dir}/examples/ecg_example2.png", |
|
], |
|
"text": "What are the main features in this ECG image?", |
|
}, |
|
], |
|
[ |
|
{ |
|
"files": [ |
|
f"{cur_dir}/examples/ecg_example1.jpg", |
|
], |
|
"text": "What can be inferred from the pattern of the qR complexes and rS complexes in the leads of this ECG image?", |
|
}, |
|
] |
|
], |
|
inputs=[chat_input], |
|
label="Image", |
|
) |
|
|
|
gr.Markdown(tos_markdown) |
|
gr.Markdown(learn_more_markdown) |
|
gr.Markdown(bibtext) |
|
|
|
chat_msg = chat_input.submit( |
|
add_message, [chatbot, chat_input], [chatbot, chat_input] |
|
) |
|
    bot_msg = chat_msg.then(bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response")
|
bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]) |
|
|
|
|
|
clear_btn.click( |
|
fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all" |
|
) |
|
|
|
|
|
demo.queue() |
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
|
|
argparser = argparse.ArgumentParser() |
|
argparser.add_argument("--server_name", default="0.0.0.0", type=str) |
|
argparser.add_argument("--port", default="6123", type=str) |
|
argparser.add_argument( |
|
"--model_path", default="PULSE-ECG/PULSE-7B", type=str |
|
) |
|
|
|
argparser.add_argument("--model-base", type=str, default=None) |
|
argparser.add_argument("--num-gpus", type=int, default=1) |
|
argparser.add_argument("--conv-mode", type=str, default=None) |
|
argparser.add_argument("--temperature", type=float, default=0.0) |
|
argparser.add_argument("--max-new-tokens", type=int, default=1024) |
|
argparser.add_argument("--num_frames", type=int, default=16) |
|
argparser.add_argument("--load-8bit", action="store_true") |
|
argparser.add_argument("--load-4bit", action="store_true") |
|
argparser.add_argument("--debug", action="store_true") |
|
|
|
args = argparser.parse_args() |
|
|
|
model_path = args.model_path |
|
filt_invalid = "cut" |
|
model_name = get_model_name_from_path(args.model_path) |
|
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit) |
|
print("### image_processor",image_processor) |
|
model=model.to(torch.device('cuda')) |
|
our_chatbot = None |
|
demo.launch() |