# Captured from a Hugging Face Space page whose status read "Runtime error".
import argparse
import os
import subprocess
import sys
import time
import traceback

import gradio as gr

import llava.serve.gradio_web_server as gws
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
    """Build the LLaVA chart-critique Gradio UI and wire up its event handlers.

    Args:
        embed_mode: When true, the title, terms-of-service and "learn more"
            markdown panels are omitted.
        cur_dir: Directory containing the ``examples/`` images offered under
            the image box. Defaults to the directory of this file.
        concurrency_count: ``concurrency_limit`` applied to every
            ``gws.http_bot`` generation event.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).

    Raises:
        ValueError: If ``gws.args.model_list_mode`` is neither ``"once"`` nor
            ``"reload"``.
    """
    # Created outside the Blocks context: gr.Examples below references it
    # before it is placed in the layout with textbox.render().
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(gws.title_markdown)

        with gr.Row():
            # Left column: model choice, image input, examples, sampling parameters.
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=gws.models,
                        value=gws.models[0] if gws.models else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                # Hidden, so the "Default" preprocessing is always used.
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                examples_dir = os.path.join(cur_dir, "examples")
                user_prompt = "Evaluate and explain if this chart is misleading"
                gr.Examples(examples=[
                    [os.path.join(examples_dir, "bar_custom_1.png"), user_prompt],
                    [os.path.join(examples_dir, "fox_news.jpeg"), user_prompt],
                    [os.path.join(examples_dir, "bar_wiki.png"), user_prompt],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            # Right column: conversation, text input and feedback buttons.
            with gr.Column(scale=8):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=650,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                # Buttons start disabled; handlers below return their updated states.
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(gws.tos_markdown)
            gr.Markdown(gws.learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            gws.upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            gws.downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            gws.flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        # Each generation path first updates the conversation state, then
        # chains into gws.http_bot, which is capped at concurrency_count.
        regenerate_btn.click(
            gws.regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            gws.clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            gws.add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            gws.http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        # Populate the model dropdown on page load.
        if gws.args.model_list_mode == "once":
            demo.load(
                gws.load_demo,
                [url_params],
                [state, model_selector],
                js=gws.get_window_url_params
            )
        elif gws.args.model_list_mode == "reload":
            demo.load(
                gws.load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")

    return demo
# Install/upgrade flash-attn before launching (the model worker below is
# started with --use-flash-attn). Guarded so that importing this module
# does not run pip as a side effect; running it as a script still does.
if __name__ == "__main__":
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
def start_controller(host="0.0.0.0", port=10000):
    """Launch the LLaVA controller as a child process.

    Runs ``python -m llava.serve.controller`` with the current interpreter.

    Args:
        host: Interface passed to the controller's ``--host`` option.
        port: Port passed to the controller's ``--port`` option. The worker's
            default controller URL and the ``--controller-url`` default both
            point at port 10000.

    Returns:
        The ``subprocess.Popen`` handle. The caller is responsible for
        terminating it.
    """
    print("Starting the controller")
    controller_command = [
        sys.executable, "-m", "llava.serve.controller",
        "--host", str(host),
        "--port", str(port),
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)
def start_worker(model_path: str, bits=16, controller_url="http://localhost:10000"):
    """Launch a LLaVA model worker as a child process.

    Runs ``python -m llava.serve.model_worker`` with the current interpreter,
    pointed at the controller via ``--controller`` and always started with
    ``--use-flash-attn``.

    Args:
        model_path: Value for ``--model-path`` (e.g.
            ``liuhaotian/llava-v1.6-mistral-7b``). Its last ``/``-separated
            component becomes the advertised model name.
        bits: 16, 8 or 4. For 8 and 4, the model name gets a ``-{bits}bit``
            suffix and ``--load-{bits}bit`` is passed to the worker.
        controller_url: URL given to the worker's ``--controller`` option.

    Returns:
        The ``subprocess.Popen`` handle. The caller is responsible for
        terminating it.

    Raises:
        ValueError: If ``bits`` is not one of 4, 8 or 16.
    """
    # An ``assert`` would be stripped under ``python -O``, so validate explicitly.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")

    print(f"Starting the model worker for the model {model_path}")
    model_name = model_path.strip("/").split("/")[-1]
    if bits != 16:
        model_name += f"-{bits}bit"

    worker_command = [
        sys.executable, "-m", "llava.serve.model_worker",
        "--host", "0.0.0.0",
        "--controller", controller_url,
        "--model-path", model_path,
        "--model-name", model_name,
        "--use-flash-attn",
    ]
    if bits != 16:
        worker_command.append(f"--load-{bits}bit")
    print(worker_command)
    return subprocess.Popen(worker_command)
if __name__ == "__main__":
    # CLI options are stored on gws.args, where the shared web-server module reads them.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []

    gws.title_markdown += """
ONLY WORKS WITH GPU! By default, we load the model with 4-bit quantization to make it fit in smaller hardwares. Set the environment variable `bits` to control the quantization.
Set the environment variable `model` to change the model:
[`liuhaotian/llava-v1.6-mistral-7b`](https://huggingface.co/liuhaotian/llava-v1.6-mistral-7b),
[`liuhaotian/llava-v1.6-vicuna-7b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b),
[`liuhaotian/llava-v1.6-vicuna-13b`](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-13b),
[`liuhaotian/llava-v1.6-34b`](https://huggingface.co/liuhaotian/llava-v1.6-34b).
"""

    print(f"args: {gws.args}")

    # Launch configuration comes from environment variables. For
    # concurrency_count, the --concurrency-count flag is the fallback.
    model_path = os.getenv("model", "liuhaotian/llava-v1.6-mistral-7b")
    bits = int(os.getenv("bits", 4))
    concurrency_count = int(os.getenv("concurrency_count", gws.args.concurrency_count))

    controller_proc = None
    worker_proc = None
    exit_status = 0
    try:
        # Started inside the try so that a failure in start_worker (e.g. an
        # invalid `bits`) still tears down an already-running controller.
        controller_proc = start_controller()
        worker_proc = start_worker(model_path, bits=bits)

        # Wait for worker and controller to start (fixed grace period).
        time.sleep(10)
        for proc_name, proc in (("controller", controller_proc), ("worker", worker_proc)):
            if proc.poll() is not None:
                print(f"Warning: the {proc_name} process exited early with code {proc.returncode}")

        demo = build_demo(embed_mode=gws.args.embed, cur_dir='./', concurrency_count=concurrency_count)
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception:
        # Keep the full traceback; a bare message hides where a launch failed.
        traceback.print_exc()
        exit_status = 1
    finally:
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
                proc.wait()  # reap the child so it does not linger as a zombie
    sys.exit(exit_status)