import argparse import datetime import json import os import time import gradio as gr import requests from mplug_owl2.conversation import (default_conversation, conv_templates, SeparatorStyle) from mplug_owl2.constants import LOGDIR from mplug_owl2.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) from model_worker import ModelWorker import hashlib logger = build_logger("gradio_web_server_local", "gradio_web_server_local.log") headers = {"User-Agent": "mPLUG-Owl2 Client"} no_change_btn = gr.Button() enable_btn = gr.Button(interactive=True) disable_btn = gr.Button(interactive=False) def get_conv_log_filename(): t = datetime.datetime.now() name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") return name get_window_url_params = """ function() { const params = new URLSearchParams(window.location.search); url_params = Object.fromEntries(params); console.log(url_params); return url_params; } """ def load_demo(url_params, request: gr.Request): logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") state = default_conversation.copy() return state def vote_last_response(state, vote_type, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "state": state.dict(), "ip": request.client.host, } fout.write(json.dumps(data) + "\n") def upvote_last_response(state, request: gr.Request): logger.info(f"upvote. ip: {request.client.host}") vote_last_response(state, "upvote", request) return ("",) + (disable_btn,) * 3 def downvote_last_response(state, request: gr.Request): logger.info(f"downvote. ip: {request.client.host}") vote_last_response(state, "downvote", request) return ("",) + (disable_btn,) * 3 def flag_last_response(state, request: gr.Request): logger.info(f"flag. ip: {request.client.host}") vote_last_response(state, "flag", request) return ("",) + (disable_btn,) * 3 def regenerate(state, image_process_mode, request: gr.Request): logger.info(f"regenerate. ip: {request.client.host}") state.messages[-1][-1] = None prev_human_msg = state.messages[-2] if type(prev_human_msg[1]) in (tuple, list): prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def clear_history(request: gr.Request): logger.info(f"clear_history. ip: {request.client.host}") state = default_conversation.copy() return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def add_text(state, text, image, image_process_mode, request: gr.Request): logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") if len(text) <= 0 and image is None: state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 if args.moderate: flagged = violates_moderation(text) if flagged: state.skip_next = True return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( no_change_btn,) * 5 text = text[:3584] # Hard cut-off if image is not None: text = text[:3500] # Hard cut-off for images if '<|image|>' not in text: text = '<|image|>' + text text = (text, image, image_process_mode) if len(state.get_images(return_pil=True)) > 0: state = default_conversation.copy() state.append_message(state.roles[0], text) state.append_message(state.roles[1], None) state.skip_next = False return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 def http_bot(state, temperature, top_p, max_new_tokens, request: gr.Request): logger.info(f"http_bot. ip: {request.client.host}") start_tstamp = time.time() if state.skip_next: # This generate call is skipped due to invalid inputs yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return if len(state.messages) == state.offset + 2: # First round of conversation template_name = "mplug_owl2" new_state = conv_templates[template_name].copy() new_state.append_message(new_state.roles[0], state.messages[-2][1]) new_state.append_message(new_state.roles[1], None) state = new_state # Construct prompt prompt = state.get_prompt() all_images = state.get_images(return_pil=True) all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] for image, hash in zip(all_images, all_image_hash): t = datetime.datetime.now() filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") if not os.path.isfile(filename): os.makedirs(os.path.dirname(filename), exist_ok=True) image.save(filename) # Make requests pload = { "prompt": prompt, "temperature": float(temperature), "top_p": float(top_p), "max_new_tokens": min(int(max_new_tokens), 2048), "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, "images": f'List of {len(state.get_images())} images: {all_image_hash}', } logger.info(f"==== request ====\n{pload}") pload['images'] = state.get_images() state.messages[-1][-1] = "â" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 try: # Stream output # response = requests.post(worker_addr + "/worker_generate_stream", # headers=headers, json=pload, stream=True, timeout=10) # for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): response = model.generate_stream_gate(pload) for chunk in response: if chunk: data = json.loads(chunk.decode()) if data["error_code"] == 0: output = data["text"][len(prompt):].strip() state.messages[-1][-1] = output + "â" yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 else: output = data["text"] + f" (error_code: {data['error_code']})" state.messages[-1][-1] = output yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return time.sleep(0.03) except requests.exceptions.RequestException as e: state.messages[-1][-1] = server_error_msg yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) return state.messages[-1][-1] = state.messages[-1][-1][:-1] yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 finish_tstamp = time.time() logger.info(f"{output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "start": round(start_tstamp, 4), "finish": round(start_tstamp, 4), "state": state.dict(), "images": all_image_hash, "ip": request.client.host, } fout.write(json.dumps(data) + "\n") title_markdown = ("""