# 3DGen-Arena / model /model_worker.py
# ZhangYuhan's picture
# update serve
# 6c1c5e7
# raw
# history blame
# 5.23 kB
import os
import json
import time
import kiui
from typing import List
import replicate
import subprocess
from gradio_client import Client
# from .client import Gau2Mesh_client
from constants import OFFLINE_GIF_DIR, REPLICATE_API_TOKEN
# os.environ["REPLICATE_API_TOKEN"] = "yourKey"
class BaseModelWorker:
    """Common state and hooks shared by the model workers in this module.

    On construction the worker looks for ``<OFFLINE_GIF_DIR>/<model_name>.json``
    and, when that file exists, parses it into ``urls_json``. When the file is
    absent ``urls_json`` stays ``None``.
    """

    # Default for the attribute read by ``check_online``. ``__init__`` never
    # assigns ``self.model``, so without this default the check raised
    # AttributeError for every worker created with a truthy ``online_model``.
    # Subclasses can still assign ``self.model`` on the instance.
    model = None

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: str = None
                 ):
        """Store the worker configuration and load its offline URL table.

        Args:
            model_name: Name of the model; also the stem of the JSON file
                looked up under ``OFFLINE_GIF_DIR``.
            i2s_model: Stored as-is; not interpreted by this class.
            online_model: Flag consulted by ``check_online``.
            model_api: Optional API identifier, stored as-is.
        """
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api

        # Stays None when no per-model JSON file exists on disk.
        self.urls_json = None
        urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
        if os.path.exists(urls_json_path):
            with open(urls_json_path, 'r', encoding='utf-8') as f:
                self.urls_json = json.load(f)

    def check_online(self) -> bool:
        """Return True when ``online_model`` is truthy and ``model`` is falsy."""
        return bool(self.online_model and not self.model)

    def load_offline(self, offline: bool, offline_idx):
        """Look up the precomputed entry for ``offline_idx``.

        Args:
            offline: When falsy, no lookup is attempted.
            offline_idx: Index of the entry; converted with ``str()`` before
                being used as the key into ``urls_json``.

        Returns:
            The stored entry, or ``None`` when ``offline`` is falsy, no URL
            table was loaded, or the key is not present.
        """
        # ``urls_json`` is None when the JSON file was missing; guard against
        # that instead of failing on ``None.keys()``.
        if not offline or not self.urls_json:
            return None
        return self.urls_json.get(str(offline_idx))

    def inference(self, prompt):
        """Hook with no implementation here; returns ``None``."""
        pass

    def render(self, shape, rgb_on=True, normal_on=True):
        """Hook with no implementation here; returns ``None``."""
        pass
class HuggingfaceApiWorker(BaseModelWorker):
    """Worker whose ``model_api`` argument is required rather than optional.

    Adds no behaviour of its own; construction is delegated entirely to
    ``BaseModelWorker``.
    """

    def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str):
        """Forward all four settings unchanged to the base initializer."""
        super().__init__(
            model_name=model_name,
            i2s_model=i2s_model,
            online_model=online_model,
            model_api=model_api,
        )
# class PointE_Worker(BaseModelWorker):
# def __init__(self,
# model_name: str,
# i2s_model: bool,
# online_model: bool,
# model_api: str):
# super().__init__(model_name, i2s_model, online_model, model_api)
# class TriplaneGaussian(BaseModelWorker):
# def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str = None):
# super().__init__(model_name, i2s_model, online_model, model_api)
# class LGM_Worker(BaseModelWorker):
# def __init__(self,
# model_name: str,
# i2s_model: bool,
# online_model: bool,
# model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
# ):
# super().__init__(model_name, i2s_model, online_model, model_api)
# self.model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
# def inference(self, image):
# output = self.model_client.run(
# self.model_api,
# input={"input_image": image}
# )
# #=> .mp4 .ply
# return output[1]
# def render(self, shape):
# mesh = Gau2Mesh_client.run(shape)
# path_normal = ""
# cmd_normal = f"python -m ..kiuikit.kiui.render {mesh} --save {path_normal} \
# --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode normal"
# subprocess.run(cmd_normal, shell=True, check=True)
# path_rgb = ""
# cmd_rgb = f"python -m ..kiuikit.kiui.render {mesh} --save {path_rgb} \
# --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode rgb"
# subprocess.run(cmd_rgb, shell=True, check=True)
# return path_normal, path_rgb
# class V3D_Worker(BaseModelWorker):
# def __init__(self,
# model_name: str,
# i2s_model: bool,
# online_model: bool,
# model_api: str = None):
# super().__init__(model_name, i2s_model, online_model, model_api)
# model = 'LGM'
# # model = 'TriplaneGaussian'
# folder = 'glbs_full'
# form = 'glb'
# pose = '+z'
# pair = ('OpenLRM', 'meshes', 'obj', '-y')
# pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
# pair = ('LGM', 'glbs_full', 'glb', '+z')
if __name__ == "__main__":
    # Manual smoke test: send a splat (.ply) URL to the splat-to-mesh Space
    # and print whatever the Space returns.
    #
    # Disabled example of the Replicate call that yields a URL pair like
    # the one hard-coded below:
    # input = {
    #     "input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png",
    # }
    # print("Start...")
    # model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    # output = model_client.run(
    #     "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #     input=input
    # )
    # print("output: ", output)
    #=> ['https://replicate.delivery/pbxt/toffawxRE3h6AUofI9sPtiAsoYI0v73zuGDZjZWBWAPzHKSlA/gradio_output.mp4', 'https://replicate.delivery/pbxt/oSn1XPfoJuw2UKOUIAue2iXeT7aXncVjC4QwHKU5W5x0HKSlA/gradio_output.ply']

    # Hard-coded [.mp4, .ply] URL pair; only the second entry is used.
    sample_urls = [
        'https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4',
        'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply',
    ]
    splat_url = sample_urls[1]

    converter = Client(
        "https://dylanebert-splat-to-mesh.hf.space/",
        upload_files=True,
        download_files=True,
    )
    print(converter.predict(splat_url, api_name="/run"))