import datetime import time import json import uuid import gradio as gr import regex as re from pathlib import Path from .utils import * from .log_utils import build_logger from .constants import IMAGE_DIR, VIDEO_DIR import imageio from diffusers.utils import load_image import torch ig_logger = build_logger("gradio_web_server_image_generation", "gr_web_image_generation.log") # ig = image generation, loggers for single model direct chat igm_logger = build_logger("gradio_web_server_image_generation_multi", "gr_web_image_generation_multi.log") # igm = image generation multi, loggers for side-by-side and battle ie_logger = build_logger("gradio_web_server_image_editing", "gr_web_image_editing.log") # ie = image editing, loggers for single model direct chat iem_logger = build_logger("gradio_web_server_image_editing_multi", "gr_web_image_editing_multi.log") # iem = image editing multi, loggers for side-by-side and battle vg_logger = build_logger("gradio_web_server_video_generation", "gr_web_video_generation.log") # vg = video generation, loggers for single model direct chat vgm_logger = build_logger("gradio_web_server_video_generation_multi", "gr_web_video_generation_multi.log") # vgm = video generation multi, loggers for side-by-side and battle def save_any_image(image_file, file_path): if isinstance(image_file, str): image = load_image(image_file) image.save(file_path, 'JPEG') else: image_file.save(file_path, 'JPEG') def vote_last_response_ig(state, vote_type, model_selector, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "model": model_selector, "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def vote_last_response_igm(states, vote_type, model_selectors, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "models": [x for x in model_selectors], "states": [x.dict() for x in states], "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for state in states: output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def vote_last_response_ie(state, vote_type, model_selector, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "model": model_selector, "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg' source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) with open(source_file, 'w') as sf: save_any_image(state.source_image, sf) save_image_file_on_log_server(output_file) save_image_file_on_log_server(source_file) def vote_last_response_iem(states, vote_type, model_selectors, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "models": [x for x in model_selectors], "states": [x.dict() for x in states], "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for state in states: output_file = f'{IMAGE_DIR}/edition/{state.conv_id}.jpg' source_file = f'{IMAGE_DIR}/edition/{state.conv_id}_source.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) with open(source_file, 'w') as sf: save_any_image(state.source_image, sf) save_image_file_on_log_server(output_file) save_image_file_on_log_server(source_file) def vote_last_response_vg(state, vote_type, model_selector, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "model": model_selector, "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) if state.model_name.startswith('fal'): r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) else: print("======== video shape: ========") print(state.output.shape) # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels] if state.output.shape[-1] != 3: state.output = state.output.permute(0, 2, 3, 1) imageio.mimwrite(output_file, state.output, fps=8, quality=9) save_video_file_on_log_server(output_file) def vote_last_response_vgm(states, vote_type, model_selectors, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "models": [x for x in model_selectors], "states": [x.dict() for x in states], "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for state in states: output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) if state.model_name.startswith('fal'): r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) elif isinstance(state.output, torch.Tensor): print("======== video shape: ========") print(state.output.shape) # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels] if state.output.shape[-1] != 3: state.output = state.output.permute(0, 2, 3, 1) imageio.mimwrite(output_file, state.output, fps=8, quality=9) else: r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) save_video_file_on_log_server(output_file) ## Image Generation (IG) Single Model Direct Chat def upvote_last_response_ig(state, model_selector, request: gr.Request): ip = get_ip(request) ig_logger.info(f"upvote. ip: {ip}") vote_last_response_ig(state, "upvote", model_selector, request) return ("",) + (disable_btn,) * 3 def downvote_last_response_ig(state, model_selector, request: gr.Request): ip = get_ip(request) ig_logger.info(f"downvote. ip: {ip}") vote_last_response_ig(state, "downvote", model_selector, request) return ("",) + (disable_btn,) * 3 def flag_last_response_ig(state, model_selector, request: gr.Request): ip = get_ip(request) ig_logger.info(f"flag. ip: {ip}") vote_last_response_ig(state, "flag", model_selector, request) return ("",) + (disable_btn,) * 3 ## Image Generation Multi (IGM) Side-by-Side and Battle def leftvote_last_response_igm( state0, state1, model_selector0, model_selector1, request: gr.Request ): igm_logger.info(f"leftvote (named). ip: {get_ip(request)}") vote_last_response_igm( [state0, state1], "leftvote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True)) def rightvote_last_response_igm( state0, state1, model_selector0, model_selector1, request: gr.Request ): igm_logger.info(f"rightvote (named). ip: {get_ip(request)}") vote_last_response_igm( [state0, state1], "rightvote", [model_selector0, model_selector1], request ) print(model_selector0) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True)) def tievote_last_response_igm( state0, state1, model_selector0, model_selector1, request: gr.Request ): igm_logger.info(f"tievote (named). ip: {get_ip(request)}") vote_last_response_igm( [state0, state1], "tievote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True)) def bothbad_vote_last_response_igm( state0, state1, model_selector0, model_selector1, request: gr.Request ): igm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") vote_last_response_igm( [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + (gr.Markdown(state0.model_name, visible=True), gr.Markdown(state1.model_name, visible=True)) ## Image Editing (IE) Single Model Direct Chat def upvote_last_response_ie(state, model_selector, request: gr.Request): ip = get_ip(request) ie_logger.info(f"upvote. ip: {ip}") vote_last_response_ie(state, "upvote", model_selector, request) return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 def downvote_last_response_ie(state, model_selector, request: gr.Request): ip = get_ip(request) ie_logger.info(f"downvote. ip: {ip}") vote_last_response_ie(state, "downvote", model_selector, request) return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 def flag_last_response_ie(state, model_selector, request: gr.Request): ip = get_ip(request) ie_logger.info(f"flag. ip: {ip}") vote_last_response_ie(state, "flag", model_selector, request) return ("", "", gr.Image(height=512, width=512, type="pil"), "",) + (disable_btn,) * 3 ## Image Editing Multi (IEM) Side-by-Side and Battle def leftvote_last_response_iem( state0, state1, model_selector0, model_selector1, request: gr.Request ): iem_logger.info(f"leftvote (anony). ip: {get_ip(request)}") vote_last_response_iem( [state0, state1], "leftvote", [model_selector0, model_selector1], request ) # names = ( # "### Model A: " + state0.model_name, # "### Model B: " + state1.model_name, # ) # names = (state0.model_name, state1.model_name) if model_selector0 == "": names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 def rightvote_last_response_iem( state0, state1, model_selector0, model_selector1, request: gr.Request ): iem_logger.info(f"rightvote (anony). ip: {get_ip(request)}") vote_last_response_iem( [state0, state1], "rightvote", [model_selector0, model_selector1], request ) # names = ( # "### Model A: " + state0.model_name, # "### Model B: " + state1.model_name, # ) if model_selector0 == "": names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 def tievote_last_response_iem( state0, state1, model_selector0, model_selector1, request: gr.Request ): iem_logger.info(f"tievote (anony). ip: {get_ip(request)}") vote_last_response_iem( [state0, state1], "tievote", [model_selector0, model_selector1], request ) if model_selector0 == "": names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 def bothbad_vote_last_response_iem( state0, state1, model_selector0, model_selector1, request: gr.Request ): iem_logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}") vote_last_response_iem( [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request ) if model_selector0 == "": names = (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: names = (gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) return names + ("", "", gr.Image(height=512, width=512, type="pil"), "") + (disable_btn,) * 4 ## Video Generation (VG) Single Model Direct Chat def upvote_last_response_vg(state, model_selector, request: gr.Request): ip = get_ip(request) vg_logger.info(f"upvote. ip: {ip}") vote_last_response_vg(state, "upvote", model_selector, request) return ("",) + (disable_btn,) * 3 def downvote_last_response_vg(state, model_selector, request: gr.Request): ip = get_ip(request) vg_logger.info(f"downvote. ip: {ip}") vote_last_response_vg(state, "downvote", model_selector, request) return ("",) + (disable_btn,) * 3 def flag_last_response_vg(state, model_selector, request: gr.Request): ip = get_ip(request) vg_logger.info(f"flag. ip: {ip}") vote_last_response_vg(state, "flag", model_selector, request) return ("",) + (disable_btn,) * 3 ## Image Generation Multi (IGM) Side-by-Side and Battle def leftvote_last_response_vgm( state0, state1, model_selector0, model_selector1, request: gr.Request ): vgm_logger.info(f"leftvote (named). ip: {get_ip(request)}") vote_last_response_vgm( [state0, state1], "leftvote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + (gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + ( gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) def rightvote_last_response_vgm( state0, state1, model_selector0, model_selector1, request: gr.Request ): vgm_logger.info(f"rightvote (named). ip: {get_ip(request)}") vote_last_response_vgm( [state0, state1], "rightvote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + ( gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) def tievote_last_response_vgm( state0, state1, model_selector0, model_selector1, request: gr.Request ): vgm_logger.info(f"tievote (named). ip: {get_ip(request)}") vote_last_response_vgm( [state0, state1], "tievote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + ( gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) def bothbad_vote_last_response_vgm( state0, state1, model_selector0, model_selector1, request: gr.Request ): vgm_logger.info(f"bothbad_vote (named). ip: {get_ip(request)}") vote_last_response_vgm( [state0, state1], "bothbad_vote", [model_selector0, model_selector1], request ) if model_selector0 == "": return ("",) + (disable_btn,) * 4 + ( gr.Markdown(f"### Model A: {state0.model_name.split('_')[1]}", visible=True), gr.Markdown(f"### Model B: {state1.model_name.split('_')[1]}", visible=True)) else: return ("",) + (disable_btn,) * 4 + ( gr.Markdown(state0.model_name, visible=False), gr.Markdown(state1.model_name, visible=False)) share_js = """ function (a, b, c, d) { const captureElement = document.querySelector('#share-region-named'); html2canvas(captureElement) .then(canvas => { canvas.style.display = 'none' document.body.appendChild(canvas) return canvas }) .then(canvas => { const image = canvas.toDataURL('image/png') const a = document.createElement('a') a.setAttribute('download', 'chatbot-arena.png') a.setAttribute('href', image) a.click() canvas.remove() }); return [a, b, c, d]; } """ def share_click_igm(state0, state1, model_selector0, model_selector1, request: gr.Request): igm_logger.info(f"share (anony). ip: {get_ip(request)}") if state0 is not None and state1 is not None: vote_last_response_igm( [state0, state1], "share", [model_selector0, model_selector1], request ) def share_click_iem(state0, state1, model_selector0, model_selector1, request: gr.Request): iem_logger.info(f"share (anony). ip: {get_ip(request)}") if state0 is not None and state1 is not None: vote_last_response_iem( [state0, state1], "share", [model_selector0, model_selector1], request ) ## All Generation Gradio Interface class ImageStateIG: def __init__(self, model_name): self.conv_id = uuid.uuid4().hex self.model_name = model_name self.prompt = None self.output = None def dict(self): base = { "conv_id": self.conv_id, "model_name": self.model_name, "prompt": self.prompt } return base class ImageStateIE: def __init__(self, model_name): self.conv_id = uuid.uuid4().hex self.model_name = model_name self.source_prompt = None self.target_prompt = None self.instruct_prompt = None self.source_image = None self.output = None def dict(self): base = { "conv_id": self.conv_id, "model_name": self.model_name, "source_prompt": self.source_prompt, "target_prompt": self.target_prompt, "instruct_prompt": self.instruct_prompt } return base class VideoStateVG: def __init__(self, model_name): self.conv_id = uuid.uuid4().hex self.model_name = model_name self.prompt = None self.output = None def dict(self): base = { "conv_id": self.conv_id, "model_name": self.model_name, "prompt": self.prompt } return base def generate_ig(gen_func, state, text, model_name, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name: raise gr.Warning("Model name cannot be empty.") state = ImageStateIG(model_name) ip = get_ip(request) ig_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() generated_image = gen_func(text, model_name) state.prompt = text state.output = generated_image state.model_name = model_name if generated_image == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state, generated_image finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_ig_museum(gen_func, state, model_name, request: gr.Request): if not model_name: raise gr.Warning("Model name cannot be empty.") state = ImageStateIG(model_name) ip = get_ip(request) ig_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() generated_image, text = gen_func(model_name) state.prompt = text state.output = generated_image state.model_name = model_name yield state, generated_image, text finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_igm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name0: raise gr.Warning("Model name A cannot be empty.") if not model_name1: raise gr.Warning("Model name B cannot be empty.") state0 = ImageStateIG(model_name0) state1 = ImageStateIG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() # Remove ### Model (A|B): from model name model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) generated_image0, generated_image1 = gen_func(text, model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_image0 state1.output = generated_image1 state0.model_name = model_name0 state1.model_name = model_name1 if generated_image0 == '' and generated_image1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state0, state1, generated_image0, generated_image1 finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_igm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): if not model_name0: raise gr.Warning("Model name A cannot be empty.") if not model_name1: raise gr.Warning("Model name B cannot be empty.") state0 = ImageStateIG(model_name0) state1 = ImageStateIG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() # Remove ### Model (A|B): from model name model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) generated_image0, generated_image1, text = gen_func(model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_image0 state1.output = generated_image1 state0.model_name = model_name0 state1.model_name = model_name1 yield state0, state1, generated_image0, generated_image1, text finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_igm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") state0 = ImageStateIG(model_name0) state1 = ImageStateIG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = "" model_name1 = "" generated_image0, generated_image1, model_name0, model_name1 = gen_func(text, model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_image0 state1.output = generated_image1 state0.model_name = model_name0 state1.model_name = model_name1 if generated_image0 == '' and generated_image1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state0, state1, generated_image0, generated_image1, \ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_igm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): state0 = ImageStateIG(model_name0) state1 = ImageStateIG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() # model_name0 = re.sub(r"### Model A: ", "", model_name0) # model_name1 = re.sub(r"### Model B: ", "", model_name1) model_name0 = "" model_name1 = "" generated_image0, generated_image1, model_name0, model_name1, text = gen_func(model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_image0 state1.output = generated_image1 state0.model_name = model_name0 state1.model_name = model_name1 yield state0, state1, generated_image0, generated_image1, text,\ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{IMAGE_DIR}/generation/{state.conv_id}.jpg' os.makedirs(os.path.dirname(output_file), exist_ok=True) with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(output_file) def generate_ie(gen_func, state, source_text, target_text, instruct_text, source_image, model_name, request: gr.Request): if not source_text: raise gr.Warning("Source prompt cannot be empty.") if not target_text: raise gr.Warning("Target prompt cannot be empty.") if not instruct_text: raise gr.Warning("Instruction prompt cannot be empty.") if not source_image: raise gr.Warning("Source image cannot be empty.") if not model_name: raise gr.Warning("Model name cannot be empty.") state = ImageStateIE(model_name) ip = get_ip(request) ig_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() generated_image = gen_func(source_text, target_text, instruct_text, source_image, model_name) state.source_prompt = source_text state.target_prompt = target_text state.instruct_prompt = instruct_text state.source_image = source_image state.output = generated_image state.model_name = model_name if generated_image == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state, generated_image finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_ie_museum(gen_func, state, model_name, request: gr.Request): if not model_name: raise gr.Warning("Model name cannot be empty.") state = ImageStateIE(model_name) ip = get_ip(request) ig_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() source_image, generated_image, source_text, target_text, instruct_text = gen_func(model_name) state.source_prompt = source_text state.target_prompt = target_text state.instruct_prompt = instruct_text state.source_image = source_image state.output = generated_image state.model_name = model_name yield state, generated_image, source_image, source_text, target_text, instruct_text finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_iem(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request): if not source_text: raise gr.Warning("Source prompt cannot be empty.") if not target_text: raise gr.Warning("Target prompt cannot be empty.") if not instruct_text: raise gr.Warning("Instruction prompt cannot be empty.") if not source_image: raise gr.Warning("Source image cannot be empty.") if not model_name0: raise gr.Warning("Model name A cannot be empty.") if not model_name1: raise gr.Warning("Model name B cannot be empty.") state0 = ImageStateIE(model_name0) state1 = ImageStateIE(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) generated_image0, generated_image1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1) state0.source_prompt = source_text state0.target_prompt = target_text state0.instruct_prompt = instruct_text state0.source_image = source_image state0.output = generated_image0 state0.model_name = model_name0 state1.source_prompt = source_text state1.target_prompt = target_text state1.instruct_prompt = instruct_text state1.source_image = source_image state1.output = generated_image1 state1.model_name = model_name1 if generated_image0 == '' and generated_image1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state0, state1, generated_image0, generated_image1 finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_iem_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): if not model_name0: raise gr.Warning("Model name A cannot be empty.") if not model_name1: raise gr.Warning("Model name B cannot be empty.") state0 = ImageStateIE(model_name0) state1 = ImageStateIE(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) source_image, generated_image0, generated_image1, source_text, target_text, instruct_text = gen_func(model_name0, model_name1) state0.source_prompt = source_text state0.target_prompt = target_text state0.instruct_prompt = instruct_text state0.source_image = source_image state0.output = generated_image0 state0.model_name = model_name0 state1.source_prompt = source_text state1.target_prompt = target_text state1.instruct_prompt = instruct_text state1.source_image = source_image state1.output = generated_image1 state1.model_name = model_name1 yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_iem_annoy(gen_func, state0, state1, source_text, target_text, instruct_text, source_image, model_name0, model_name1, request: gr.Request): if not source_text: raise gr.Warning("Source prompt cannot be empty.") if not target_text: raise gr.Warning("Target prompt cannot be empty.") if not instruct_text: raise gr.Warning("Instruction prompt cannot be empty.") if not source_image: raise gr.Warning("Source image cannot be empty.") state0 = ImageStateIE(model_name0) state1 = ImageStateIE(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = "" model_name1 = "" generated_image0, generated_image1, model_name0, model_name1 = gen_func(source_text, target_text, instruct_text, source_image, model_name0, model_name1) state0.source_prompt = source_text state0.target_prompt = target_text state0.instruct_prompt = instruct_text state0.source_image = source_image state0.output = generated_image0 state0.model_name = model_name0 state1.source_prompt = source_text state1.target_prompt = target_text state1.instruct_prompt = instruct_text state1.source_image = source_image state1.output = generated_image1 state1.model_name = model_name1 if generated_image0 == '' and generated_image1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") yield state0, state1, generated_image0, generated_image1, \ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_iem_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): state0 = ImageStateIE(model_name0) state1 = ImageStateIE(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = "" model_name1 = "" source_image, generated_image0, generated_image1, source_text, target_text, instruct_text, model_name0, model_name1 = gen_func(model_name0, model_name1) state0.source_prompt = source_text state0.target_prompt = target_text state0.instruct_prompt = instruct_text state0.source_image = source_image state0.output = generated_image0 state0.model_name = model_name0 state1.source_prompt = source_text state1.target_prompt = target_text state1.instruct_prompt = instruct_text state1.source_image = source_image state1.output = generated_image1 state1.model_name = model_name1 yield state0, state1, generated_image0, generated_image1, source_image, source_text, target_text, instruct_text, \ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): src_img_file = f'{IMAGE_DIR}/edition/{state.conv_id}_src.jpg' os.makedirs(os.path.dirname(src_img_file), exist_ok=True) with open(src_img_file, 'w') as f: save_any_image(state.source_image, f) output_file = f'{IMAGE_DIR}/edition/{state.conv_id}_out.jpg' with open(output_file, 'w') as f: save_any_image(state.output, f) save_image_file_on_log_server(src_img_file) save_image_file_on_log_server(output_file) def generate_vg(gen_func, state, text, model_name, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name: raise gr.Warning("Model name cannot be empty.") state = VideoStateVG(model_name) ip = get_ip(request) vg_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() generated_video = gen_func(text, model_name) state.prompt = text state.output = generated_video state.model_name = model_name if generated_video == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") # yield state, generated_video finish_tstamp = time.time() with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) if model_name.startswith('fal'): r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) else: print("======== video shape: ========") print(state.output.shape) # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels] if state.output.shape[-1] != 3: state.output = state.output.permute(0, 2, 3, 1) imageio.mimwrite(output_file, state.output, fps=8, quality=9) save_video_file_on_log_server(output_file) yield state, output_file def generate_vg_museum(gen_func, state, model_name, request: gr.Request): state = VideoStateVG(model_name) ip = get_ip(request) vg_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() generated_video, text = gen_func(model_name) state.prompt = text state.output = generated_video state.model_name = model_name # yield state, generated_video finish_tstamp = time.time() with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) save_video_file_on_log_server(output_file) yield state, output_file, text def generate_vgm(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") if not model_name0: raise gr.Warning("Model name A cannot be empty.") if not model_name1: raise gr.Warning("Model name B cannot be empty.") state0 = VideoStateVG(model_name0) state1 = VideoStateVG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() # Remove ### Model (A|B): from model name model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) generated_video0, generated_video1 = gen_func(text, model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_video0 state1.output = generated_video1 state0.model_name = model_name0 state1.model_name = model_name1 if generated_video0 == '' and generated_video1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") # yield state0, state1, generated_video0, generated_video1 print("====== model name =========") print(state0.model_name) print(state1.model_name) finish_tstamp = time.time() with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) print(state.model_name) if state.model_name.startswith('fal'): r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) else: print("======== video shape: ========") print(state.output) print(state.output.shape) # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels] if state.output.shape[-1] != 3: state.output = state.output.permute(0, 2, 3, 1) imageio.mimwrite(output_file, state.output, fps=8, quality=9) save_video_file_on_log_server(output_file) yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4' def generate_vgm_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): state0 = VideoStateVG(model_name0) state1 = VideoStateVG(model_name1) ip = get_ip(request) igm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() # Remove ### Model (A|B): from model name model_name0 = re.sub(r"### Model A: ", "", model_name0) model_name1 = re.sub(r"### Model B: ", "", model_name1) generated_video0, generated_video1, text = gen_func(model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_video0 state1.output = generated_video1 state0.model_name = model_name0 state1.model_name = model_name1 # yield state0, state1, generated_video0, generated_video1 print("====== model name =========") print(state0.model_name) print(state1.model_name) finish_tstamp = time.time() with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) print(state.model_name) r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) save_video_file_on_log_server(output_file) yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text def generate_vgm_annoy(gen_func, state0, state1, text, model_name0, model_name1, request: gr.Request): if not text: raise gr.Warning("Prompt cannot be empty.") state0 = VideoStateVG(model_name0) state1 = VideoStateVG(model_name1) ip = get_ip(request) vgm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = "" model_name1 = "" generated_video0, generated_video1, model_name0, model_name1 = gen_func(text, model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_video0 state1.output = generated_video1 state0.model_name = model_name0 state1.model_name = model_name1 if generated_video0 == '' and generated_video1 == '': with open(get_nsfw_conv_log_filename(), "a") as fout: data = { "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) data = { "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_nsfw_conv_log_filename()) raise gr.Warning("Input prompt is blocked by the NSFW filter, please input safe content and try again!") # yield state0, state1, generated_video0, generated_video1, \ # gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}") finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) if state.model_name.startswith('fal'): r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) else: print("======== video shape: ========") print(state.output.shape) # Assuming state.output has to be a tensor with shape [num_frames, height, width, num_channels] if state.output.shape[-1] != 3: state.output = state.output.permute(0, 2, 3, 1) imageio.mimwrite(output_file, state.output, fps=8, quality=9) save_video_file_on_log_server(output_file) yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', \ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False) def generate_vgm_annoy_museum(gen_func, state0, state1, model_name0, model_name1, request: gr.Request): state0 = VideoStateVG(model_name0) state1 = VideoStateVG(model_name1) ip = get_ip(request) vgm_logger.info(f"generate. ip: {ip}") start_tstamp = time.time() model_name0 = "" model_name1 = "" generated_video0, generated_video1, model_name0, model_name1, text = gen_func(model_name0, model_name1) state0.prompt = text state1.prompt = text state0.output = generated_video0 state1.output = generated_video1 state0.model_name = model_name0 state1.model_name = model_name1 # yield state0, state1, generated_video0, generated_video1, \ # gr.Markdown(f"### Model A: {model_name0}"), gr.Markdown(f"### Model B: {model_name1}") finish_tstamp = time.time() # logger.info(f"===output===: {output}") with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name0, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state0.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) data = { "tstamp": round(finish_tstamp, 4), "type": "chat", "model": model_name1, "gen_params": {}, "start": round(start_tstamp, 4), "finish": round(finish_tstamp, 4), "state": state1.dict(), "ip": get_ip(request), } fout.write(json.dumps(data) + "\n") append_json_item_on_log_server(data, get_conv_log_filename()) for i, state in enumerate([state0, state1]): output_file = f'{VIDEO_DIR}/generation/{state.conv_id}.mp4' os.makedirs(os.path.dirname(output_file), exist_ok=True) r = requests.get(state.output) with open(output_file, 'wb') as outfile: outfile.write(r.content) save_video_file_on_log_server(output_file) yield state0, state1, f'{VIDEO_DIR}/generation/{state0.conv_id}.mp4', f'{VIDEO_DIR}/generation/{state1.conv_id}.mp4', text,\ gr.Markdown(f"### Model A: {model_name0}", visible=False), gr.Markdown(f"### Model B: {model_name1}", visible=False)