3DGen-Arena / serve /inference.py
ZhangYuhan's picture
update serve
25646a8
raw
history blame
32.6 kB
## All Generation Gradio Interface
import uuid
import time
from .utils import *
from .vote_utils import t2s_logger, t2s_multi_logger, i2s_logger, i2s_multi_logger
from constants import IMAGE_DIR, OFFLINE_DIR, TEXT_PROMPT_PATH, IMAGE_PROMPT_PATH
# Curated text prompts (a JSON array). A prompt's position in this list is the
# `offline_idx` used by the handlers below for offline lookups.
with open(TEXT_PROMPT_PATH, 'r', encoding='utf-8') as f:
    prompt_list = json.load(f)

# Curated image prompts, mapping offline index -> image URL. Each non-blank
# line is expected to contain a file name ending in `_<idx>.png` and a URL in
# parentheses -- presumably a markdown image link such as
# `![name_<idx>.png](<url>)`; TODO confirm against IMAGE_PROMPT_PATH.
image_list = {}
with open(IMAGE_PROMPT_PATH, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank / trailing lines instead of crashing
        idx = line.split('.png')[0].split('_')[-1]
        url = line.split(')')[0].split('(')[-1]
        # int() rather than eval(): the index is a plain decimal number and
        # file content must never be executed as code.
        image_list[int(idx)] = url
class State:
    """Mutable per-session record for one side of a 3D-generation comparison.

    Tracks which model is shown, the input (text prompt or image), whether the
    result is looked up offline by ``offline_idx`` or generated on demand, and
    the rendered videos as returned by the generation/render callbacks.
    ``dict()`` yields the subset that is written into the conversation log.
    """

    def __init__(self,
                 model_name, i2s_mode=False, offline=False,
                 prompt=None, image=None, offline_idx=None,
                 normal_video=None, rgb_video=None,
                 evaluted_dims=0):
        # Fresh random hex id per state; it is part of every log record.
        self.conv_id = uuid.uuid4().hex
        self.model_name = model_name

        # Input side: True for image-to-shape sessions, False for text-to-shape.
        self.i2s_mode = i2s_mode
        self.prompt = prompt
        self.image = image

        # Offline results are fetched by index instead of being generated.
        self.offline = offline
        self.offline_idx = offline_idx

        # Outputs, filled in by the generate_* handlers.
        self.normal_video = normal_video
        self.rgb_video = rgb_video

        # Reset to 0 on each generation; presumably counts evaluation
        # dimensions already voted on (updated elsewhere) -- confirm.
        # The spelling is kept because it is also a log key.
        self.evaluted_dims = evaluted_dims

    def dict(self):
        """Return the JSON-serialisable log view of this state.

        The image and video fields are deliberately left out; ``offline_idx``
        is appended only when the state is marked offline.
        """
        logged_fields = ("conv_id", "model_name", "i2s_mode",
                         "offline", "prompt", "evaluted_dims")
        record = {field: getattr(self, field) for field in logged_fields}
        if self.offline:
            record["offline_idx"] = self.offline_idx
        return record
# class StateI2S:
# def __init__(self, model_name):
# self.conv_id = uuid.uuid4().hex
# self.model_name = model_name
# self.image = None
# self.output = None
# def dict(self):
# base = {
# "conv_id": self.conv_id,
# "model_name": self.model_name,
# }
# return base
def sample_t2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for a text-to-shape battle.

    Args:
        state_0, state_1: existing ``State`` objects, or None to create them.
        model_list: candidate model names, either as a sequence or as the
            string repr of a Python list/tuple (e.g. ``"['a', 'b']"``).

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)`` with both states
        switched to the sampled models and to text-to-shape mode.

    Raises:
        ValueError: if ``model_list`` is a string that is not a plain literal,
            or holds fewer than two names.
    """
    import ast  # local import keeps this block independent of file-level imports

    # The list presumably arrives as a string from a Gradio component, i.e.
    # from the client: parse it as a literal instead of eval()-ing it.
    models = ast.literal_eval(model_list) if isinstance(model_list, str) else model_list
    model_name_0, model_name_1 = random.sample(list(models), 2)
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False)
    state_0.model_name = model_name_0
    state_0.i2s_mode = False
    state_1.model_name = model_name_1
    state_1.i2s_mode = False
    return state_0, state_1, model_name_0, model_name_1
def sample_i2s_model(state_0, state_1, model_list):
    """Pick two distinct models at random for an image-to-shape battle.

    Args:
        state_0, state_1: existing ``State`` objects, or None to create them.
        model_list: candidate model names, either as a sequence or as the
            string repr of a Python list/tuple (e.g. ``"['a', 'b']"``).

    Returns:
        ``(state_0, state_1, model_name_0, model_name_1)`` with both states
        switched to the sampled models and to image-to-shape mode.

    Raises:
        ValueError: if ``model_list`` is a string that is not a plain literal,
            or holds fewer than two names.
    """
    import ast  # local import keeps this block independent of file-level imports

    # The list presumably arrives as a string from a Gradio component, i.e.
    # from the client: parse it as a literal instead of eval()-ing it.
    models = ast.literal_eval(model_list) if isinstance(model_list, str) else model_list
    model_name_0, model_name_1 = random.sample(list(models), 2)
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True)
    state_0.model_name = model_name_0
    state_0.i2s_mode = True
    state_1.model_name = model_name_1
    state_1.i2s_mode = True
    return state_0, state_1, model_name_0, model_name_1
def sample_prompt(state, model_name):
    """Draw a random text prompt from ``prompt_list`` for a single-model session.

    The state (created if None) is switched to ``model_name``, text-to-shape
    mode and offline mode, with ``offline_idx`` set to the prompt's index.

    Returns:
        ``(state, prompt)``.
    """
    if state is None:
        state = State(model_name)
    idx = random.randrange(len(prompt_list))
    prompt = prompt_list[idx]
    state.model_name = model_name
    state.prompt = prompt
    state.i2s_mode = False
    # A plain bool: a trailing comma here would store the tuple (True,),
    # which State.dict() then writes into the log.
    state.offline = True
    state.offline_idx = idx
    return state, prompt
def sample_prompt_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one random text prompt and attach it to both sides of a battle.

    Missing states are created for ``model_name_0`` / ``model_name_1``;
    states that already exist keep their current ``model_name``. Both states
    are marked as offline text-to-shape samples sharing the same index.

    Returns:
        ``(state_0, state_1, prompt)``.
    """
    state_0 = state_0 if state_0 is not None else State(model_name_0)
    state_1 = state_1 if state_1 is not None else State(model_name_1)

    chosen = random.randint(0, len(prompt_list) - 1)
    text = prompt_list[chosen]
    for side in (state_0, state_1):
        side.i2s_mode = False
        side.offline = True
        side.offline_idx = chosen
        side.prompt = text
    return state_0, state_1, text
def sample_image(state, model_name):
    """Draw a random image prompt from ``image_list`` for a single-model session.

    The state (created if None) is switched to ``model_name``, image-to-shape
    mode and offline mode, with ``offline_idx`` set to the image's key.

    Returns:
        ``(state, img_url)``.
    """
    if state is None:
        state = State(model_name)
    # random.sample() rejects dict views/sets on Python 3.11+, so choose from
    # a list of the keys instead.
    idx = random.choice(list(image_list))
    img_url = image_list[idx]
    state.model_name = model_name
    state.image = img_url
    state.i2s_mode = True
    # A plain bool: a trailing comma here would store the tuple (True,).
    state.offline = True
    state.offline_idx = idx
    return state, img_url
def sample_image_side_by_side(state_0, state_1, model_name_0, model_name_1):
    """Draw one random image prompt and attach it to both sides of a battle.

    Missing states are created for ``model_name_0`` / ``model_name_1``. Both
    states are marked as offline image-to-shape samples sharing the same key.

    Returns:
        ``(state_0, state_1, img_url)``.
    """
    if state_0 is None:
        state_0 = State(model_name_0)
    if state_1 is None:
        state_1 = State(model_name_1)
    # random.sample() rejects dict views/sets on Python 3.11+, so choose from
    # a list of the keys instead.
    idx = random.choice(list(image_list))
    img_url = image_list[idx]
    state_0.i2s_mode, state_1.i2s_mode = True, True
    state_0.offline, state_1.offline = True, True
    state_0.offline_idx, state_1.offline_idx = idx, idx
    state_0.image, state_1.image = img_url, img_url
    return state_0, state_1, img_url
def generate_t2s(gen_func, render_func,
                 state,
                 text,
                 model_name,
                 request: gr.Request):
    """Generate (or look up) one text-to-shape result, then log it.

    Generator used as a Gradio event handler. If the stripped ``text`` is in
    ``prompt_list`` the result is fetched with
    ``gen_func(text, model_name, offline=True, offline_idx=idx)``; otherwise
    the shape is produced by ``gen_func(text, model_name)`` and rendered by
    ``render_func(shape, model_name)``. Either way the callback result is a
    mapping with ``'normal'`` and ``'rgb'`` entries.

    Yields once ``(state, normal_video, rgb_video)``; afterwards one JSON record
    is appended to the conversation log and sent to the log server.

    Raises:
        gr.Error: if the prompt or the model name is empty.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=False, offline=False)
    text = text.strip()
    ip = get_ip(request)
    t2s_logger.info(f"generate. ip: {ip}")
    state.model_name = model_name
    state.prompt = text
    state.evaluted_dims = 0

    # Prompts found in prompt_list are served in offline mode by index.
    try:
        state.offline_idx = prompt_list.index(text)
        state.offline = True
    except ValueError:
        state.offline = False
        state.offline_idx = None

    start_time = time.time()
    # `is not None`: index 0 is a valid offline index and must not fall
    # through to on-demand generation.
    if state.offline and state.offline_idx is not None:
        videos = gen_func(text, model_name, offline=state.offline, offline_idx=state.offline_idx)
        timing = None
    else:
        shape = gen_func(text, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state.normal_video = videos['normal']
    state.rgb_video = videos['rgb']
    yield state, videos['normal'], videos['rgb']

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "model": model_name,
        "type": "offline" if timing is None else "online",
        "gen_params": {},
        "state": state.dict(),
        "start": round(start_time, 4),
        "finish": round(finish_tstamp, 4),
    }
    data.update(timing or {})  # timing fields only exist for online runs
    data["ip"] = ip
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
def generate_t2s_multi(gen_func, render_func,
                       state_0, state_1,
                       text,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Generate (or look up) a side-by-side text-to-shape pair, then log it.

    Generator used as a Gradio event handler. If the stripped ``text`` is in
    ``prompt_list`` both results are fetched with
    ``gen_func(text, model_name_0, model_name_1, offline=True, offline_idx=idx)``;
    otherwise both shapes come from ``gen_func`` and are rendered together by
    ``render_func``. Each side's result is a mapping with ``'normal'`` and
    ``'rgb'`` entries.

    Yields once ``(state_0, state_1, normal_0, rgb_0, normal_1, rgb_1)``;
    afterwards one JSON record per side is appended to the conversation log
    and sent to the log server.

    Raises:
        gr.Error: if the prompt or either model name is empty.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)
    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # Prompts found in prompt_list are served in offline mode by index.
    try:
        idx = prompt_list.index(text)
    except ValueError:
        idx = None
    state_0.offline = state_1.offline = idx is not None
    state_0.offline_idx = state_1.offline_idx = idx

    start_time = time.time()
    # `is not None`: index 0 is a valid offline index.
    if state_0.offline and state_0.offline_idx is not None:
        videos_0, videos_1 = gen_func(text, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
        timing = None
    else:
        shape_0, shape_1 = gen_func(text, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
    state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
    yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb']

    finish_tstamp = time.time()
    records = []
    for name, side in ((model_name_0, state_0), (model_name_1, state_1)):
        record = {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": "offline" if timing is None else "online",
            "gen_params": {},
            "state": side.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
        }
        record.update(timing or {})  # timing fields only exist for online runs
        record["ip"] = ip
        records.append(record)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
def generate_t2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             text,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Anonymous text-to-shape battle: generate/look up both sides, then log.

    Like the named side-by-side handler, but ``gen_func`` also returns the
    model names actually used (presumably it picks them when the arena is
    anonymous), and two Markdown headers showing those names are yielded too.

    Yields once ``(state_0, state_1, normal_0, rgb_0, normal_1, rgb_1,
    header_a, header_b)``; afterwards one JSON record per side is appended to
    the conversation log and sent to the log server.

    Raises:
        gr.Error: if the prompt is empty.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if not text or not text.strip():
        raise gr.Error("Prompt cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=False, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=False, offline=False)
    text = text.strip()
    ip = get_ip(request)
    t2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.prompt, state_1.prompt = text, text
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    # Prompts found in prompt_list are served in offline mode by index.
    try:
        idx = prompt_list.index(text)
    except ValueError:
        idx = None
    state_0.offline = state_1.offline = idx is not None
    state_0.offline_idx = state_1.offline_idx = idx

    start_time = time.time()
    # `is not None`: index 0 is a valid offline index.
    if state_0.offline and state_0.offline_idx is not None:
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1,
            i2s_model=False, offline=state_0.offline, offline_idx=state_0.offline_idx)
        timing = None
    else:
        # Same `i2s_model` keyword as the offline call above.
        shape_0, shape_1, model_name_0, model_name_1 = gen_func(
            text, model_name_0, model_name_1, i2s_model=False)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
    state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
    # The render results are mappings, so index them by key on both paths
    # (positional [0]/[1] would raise KeyError).
    yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    finish_tstamp = time.time()
    records = []
    for name, side in ((model_name_0, state_0), (model_name_1, state_1)):
        record = {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": "offline" if timing is None else "online",
            "gen_params": {},
            "state": side.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
        }
        record.update(timing or {})  # timing fields only exist for online runs
        record["ip"] = ip
        records.append(record)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
def generate_i2s(gen_func, render_func, state, image, model_name, request: gr.Request):
    """Generate (or look up) one image-to-shape result, then log it.

    Generator used as a Gradio event handler. Offline mode is taken from
    ``state.offline`` / ``state.offline_idx`` as already set on the incoming
    state (``sample_image`` sets them). Offline results come from
    ``gen_func(image, model_name, offline=..., offline_idx=...)``; otherwise
    the shape is produced by ``gen_func`` and rendered by ``render_func``.
    Either way the result is a mapping with ``'normal'`` and ``'rgb'`` entries.

    Yields once ``(state, normal_video, rgb_video)``; afterwards one JSON record
    is appended to the conversation log and sent to the log server.

    Raises:
        gr.Error: if the image or the model name is missing.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if not model_name:
        raise gr.Error("Model name cannot be empty.")
    if state is None:
        state = State(model_name, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_logger.info(f"generate. ip: {ip}")
    state.model_name = model_name
    state.image = image
    state.evaluted_dims = 0

    start_time = time.time()
    # NOTE(review): offline/offline_idx are not re-derived from `image`, so an
    # image uploaded after sampling may still be served from the sampled
    # index -- confirm the UI resets the state on upload.
    # `is not None`: index 0 is a valid offline index.
    if state.offline and state.offline_idx is not None:
        videos = gen_func(image, model_name, offline=state.offline, offline_idx=state.offline_idx)
        timing = None
    else:
        shape = gen_func(image, model_name)
        generate_time = time.time() - start_time
        videos = render_func(shape, model_name)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state.normal_video = videos['normal']
    state.rgb_video = videos['rgb']
    yield state, videos['normal'], videos['rgb']

    finish_tstamp = time.time()
    data = {
        "tstamp": round(finish_tstamp, 4),
        "model": model_name,
        "type": "offline" if timing is None else "online",
        "gen_params": {},
        "state": state.dict(),
        "start": round(start_time, 4),
        "finish": round(finish_tstamp, 4),
    }
    data.update(timing or {})  # timing fields only exist for online runs
    data["ip"] = ip
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    append_json_item_on_log_server(data, log_file)
def generate_i2s_multi(gen_func, render_func,
                       state_0, state_1,
                       image,
                       model_name_0, model_name_1,
                       request: gr.Request):
    """Generate (or look up) a side-by-side image-to-shape pair, then log it.

    Generator used as a Gradio event handler. Offline mode is taken from the
    incoming ``state_0`` (``sample_image_side_by_side`` sets it). Offline
    results come from ``gen_func(image, model_name_0, model_name_1,
    offline=..., offline_idx=...)``; otherwise both shapes come from
    ``gen_func`` and are rendered together by ``render_func``.

    Yields once ``(state_0, state_1, normal_0, rgb_0, normal_1, rgb_1,
    header_a, header_b)`` on both paths; afterwards one JSON record per side
    is appended to the conversation log and sent to the log server.

    Raises:
        gr.Error: if the image or either model name is missing.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if not model_name_0:
        raise gr.Error("Model name A cannot be empty.")
    if not model_name_1:
        raise gr.Error("Model name B cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    start_time = time.time()
    # `is not None`: index 0 is a valid offline index.
    if state_0.offline and state_0.offline_idx is not None:
        videos_0, videos_1 = gen_func(image, model_name_0, model_name_1,
                                      offline=state_0.offline, offline_idx=state_0.offline_idx)
        timing = None
    else:
        shape_0, shape_1 = gen_func(image, model_name_0, model_name_1)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
    state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
    # Yield the same number of outputs on every path: the offline path already
    # included the two headers, and a handler that returns fewer values than
    # its registered outputs is an error in Gradio.
    yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    finish_tstamp = time.time()
    records = []
    for name, side in ((model_name_0, state_0), (model_name_1, state_1)):
        record = {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": "offline" if timing is None else "online",
            "gen_params": {},
            "state": side.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
        }
        record.update(timing or {})  # timing fields only exist for online runs
        record["ip"] = ip
        records.append(record)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)
def generate_i2s_multi_annoy(gen_func, render_func,
                             state_0, state_1,
                             image,
                             model_name_0, model_name_1,
                             request: gr.Request):
    """Anonymous image-to-shape battle: generate/look up both sides, then log.

    Offline mode requires both incoming states to be marked offline with an
    index (``sample_image_side_by_side`` sets them). ``gen_func`` also returns
    the model names actually used (presumably it picks them when the arena is
    anonymous), and two Markdown headers showing those names are yielded too.

    Yields once ``(state_0, state_1, normal_0, rgb_0, normal_1, rgb_1,
    header_a, header_b)``; afterwards one JSON record per side is appended to
    the conversation log and sent to the log server.

    Raises:
        gr.Error: if the image is missing.
    """
    # gr.Warning is a function, not an exception; gr.Error is what aborts.
    if image is None:
        raise gr.Error("Image cannot be empty.")
    if state_0 is None:
        state_0 = State(model_name_0, i2s_mode=True, offline=False)
    if state_1 is None:
        state_1 = State(model_name_1, i2s_mode=True, offline=False)
    ip = get_ip(request)
    i2s_multi_logger.info(f"generate. ip: {ip}")
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.image, state_1.image = image, image
    state_0.evaluted_dims, state_1.evaluted_dims = 0, 0

    start_time = time.time()
    # `is not None`: index 0 is a valid offline index.
    both_offline = (state_0.offline and state_0.offline_idx is not None
                    and state_1.offline and state_1.offline_idx is not None)
    if both_offline:
        videos_0, videos_1, model_name_0, model_name_1 = gen_func(
            image, model_name_0, model_name_1,
            i2s_model=True, offline=state_0.offline, offline_idx=state_0.offline_idx)
        timing = None
    else:
        # Assumes gen_func returns (shape_0, shape_1, model_name_0,
        # model_name_1) here too, as for the offline call above and in
        # generate_t2s_multi_annoy -- TODO confirm against the model manager.
        shape_0, shape_1, model_name_0, model_name_1 = gen_func(
            image, model_name_0, model_name_1, i2s_model=True)
        generate_time = time.time() - start_time
        videos_0, videos_1 = render_func(shape_0, model_name_0, shape_1, model_name_1)
        finish_time = time.time()
        timing = {
            "time": round(finish_time - start_time, 4),
            "generate_time": round(generate_time, 4),
            "render_time": round(finish_time - start_time - generate_time, 4),
        }
    state_0.model_name, state_1.model_name = model_name_0, model_name_1
    state_0.normal_video, state_0.rgb_video = videos_0['normal'], videos_0['rgb']
    state_1.normal_video, state_1.rgb_video = videos_1['normal'], videos_1['rgb']
    yield state_0, state_1, videos_0['normal'], videos_0['rgb'], videos_1['normal'], videos_1['rgb'], \
        gr.Markdown(f"### Model A: {model_name_0}"), gr.Markdown(f"### Model B: {model_name_1}")

    finish_tstamp = time.time()
    records = []
    for name, side in ((model_name_0, state_0), (model_name_1, state_1)):
        record = {
            "tstamp": round(finish_tstamp, 4),
            "model": name,
            "type": "offline" if timing is None else "online",
            "gen_params": {},
            "state": side.dict(),
            "start": round(start_time, 4),
            "finish": round(finish_tstamp, 4),
        }
        record.update(timing or {})  # timing fields only exist for online runs
        record["ip"] = ip
        records.append(record)
    log_file = get_conv_log_filename()
    with open(log_file, "a") as fout:
        for record in records:
            fout.write(json.dumps(record) + "\n")
    for record in records:
        append_json_item_on_log_server(record, log_file)