import os
import json
import time
import kiui
from typing import List, Optional
import replicate
import subprocess
from gradio_client import Client
# from .client import Gau2Mesh_client
from constants import OFFLINE_GIF_DIR, REPLICATE_API_TOKEN

# To supply the token via the environment instead of constants:
# os.environ["REPLICATE_API_TOKEN"] = "yourKey"


class BaseModelWorker:
    """Common base for model workers.

    Stores the model's configuration and, if a file named
    ``<OFFLINE_GIF_DIR>/<model_name>.json`` exists, loads it into
    ``self.urls_json`` as a table of pre-computed ("offline") results keyed by
    stringified index. Subclasses override ``inference`` and ``render``.
    """

    def __init__(self,
                 model_name: str,
                 i2s_model: bool,
                 online_model: bool,
                 model_api: Optional[str] = None,
                 ):
        """
        Args:
            model_name: Model name; also the stem of the offline JSON file.
            i2s_model: Stored as-is. Presumably "image-to-shape" vs.
                text-to-shape -- TODO confirm against callers.
            online_model: Whether the model is served online; read by
                ``check_online``.
            model_api: Optional API identifier (the commented-out LGM worker
                uses a replicate model-version string here).
        """
        self.model_name = model_name
        self.i2s_model = i2s_model
        self.online_model = online_model
        self.model_api = model_api

        # Offline results table; stays None when no JSON file exists.
        self.urls_json = None
        urls_json_path = os.path.join(OFFLINE_GIF_DIR, f"{model_name}.json")
        if os.path.exists(urls_json_path):
            with open(urls_json_path, 'r', encoding='utf-8') as f:
                self.urls_json = json.load(f)

    def check_online(self) -> bool:
        """Return True when the worker is online and has no local ``model`` set.

        ``self.model`` is not assigned by this base class. It is read with
        ``getattr`` so that workers which never set it do not raise
        ``AttributeError``.
        """
        return bool(self.online_model and not getattr(self, "model", None))

    def load_offline(self, offline: bool, offline_idx):
        """Return the pre-computed result stored under ``str(offline_idx)``.

        Returns None when ``offline`` is false, when no offline JSON file was
        found for this model, or when the index is not present in it.
        """
        if not offline or self.urls_json is None:
            return None
        return self.urls_json.get(str(offline_idx))

    def inference(self, prompt):
        """Run the model on ``prompt``. Base implementation does nothing and returns None."""
        pass

    def render(self, shape, rgb_on=True, normal_on=True):
        """Render ``shape``. Base implementation does nothing and returns None."""
        pass


class HuggingfaceApiWorker(BaseModelWorker):
    """Worker for models reached through a Hugging Face API.

    Currently adds nothing to ``BaseModelWorker`` except that ``model_api`` is
    a required argument.
    """

    def __init__(
        self,
        model_name: str,
        i2s_model: bool,
        online_model: bool,
        model_api: str,
    ):
        super().__init__(
            model_name,
            i2s_model,
            online_model,
            model_api,
        )


# class PointE_Worker(BaseModelWorker):
#     def __init__(self,
#                  model_name: str,
#                  i2s_model: bool,
#                  online_model: bool,
#                  model_api: str):
#         super().__init__(model_name, i2s_model, online_model, model_api)

# class TriplaneGaussian(BaseModelWorker):
#     def __init__(self, model_name: str, i2s_model: bool, online_model: bool, model_api: str = None):
#         super().__init__(model_name, i2s_model, online_model, model_api)

# class LGM_Worker(BaseModelWorker):
#     def __init__(self,
#                  model_name: str,
#                  i2s_model: bool,
#                  online_model: bool,
#                  model_api: str = "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
#                  ):
#         super().__init__(model_name, i2s_model, online_model,
#                          model_api)
#         self.model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)

#     def inference(self, image):
#         output = self.model_client.run(
#             self.model_api,
#             input={"input_image": image}
#         )
#         #=> .mp4 .ply
#         return output[1]

#     def render(self, shape):
#         mesh = Gau2Mesh_client.run(shape)
#         path_normal = ""
#         cmd_normal = f"python -m ..kiuikit.kiui.render {mesh} --save {path_normal} \
#                         --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode normal"
#         subprocess.run(cmd_normal, shell=True, check=True)
#         path_rgb = ""
#         cmd_rgb = f"python -m ..kiuikit.kiui.render {mesh} --save {path_rgb} \
#                         --wogui --H 512 --W 512 --radius 3 --elevation 0 --num_azimuth 40 --front_dir='+z' --mode rgb"
#         subprocess.run(cmd_rgb, shell=True, check=True)
#         return path_normal, path_rgb

# class V3D_Worker(BaseModelWorker):
#     def __init__(self,
#                  model_name: str,
#                  i2s_model: bool,
#                  online_model: bool,
#                  model_api: str = None):
#         super().__init__(model_name, i2s_model, online_model, model_api)

# model = 'LGM'
# # model = 'TriplaneGaussian'
# folder = 'glbs_full'
# form = 'glb'
# pose = '+z'

# pair = ('OpenLRM', 'meshes', 'obj', '-y')
# pair = ('TriplaneGaussian', 'glbs_full', 'glb', '-y')
# pair = ('LGM', 'glbs_full', 'glb', '+z')


if __name__ == "__main__":
    # Live replicate call, kept for reference (needs REPLICATE_API_TOKEN):
    # input = {
    #     "input_image": "https://replicate.delivery/pbxt/KN0hQI9pYB3NOpHLqktkkQIblwpXt0IG7qI90n5hEnmV9kvo/bird_rgba.png",
    # }
    # print("Start...")
    # model_client = replicate.Client(api_token=REPLICATE_API_TOKEN)
    # output = model_client.run(
    #     "camenduru/lgm:d2870893aa115773465a823fe70fd446673604189843f39a99642dd9171e05e2",
    #     input=input
    # )
    # print("output: ", output)
    # #=> ['https://replicate.delivery/pbxt/toffawxRE3h6AUofI9sPtiAsoYI0v73zuGDZjZWBWAPzHKSlA/gradio_output.mp4', 'https://replicate.delivery/pbxt/oSn1XPfoJuw2UKOUIAue2iXeT7aXncVjC4QwHKU5W5x0HKSlA/gradio_output.ply']

    # Hard-coded result URLs (.mp4 first, .ply second), presumably from an
    # earlier LGM run, so the mesh-conversion step can be tried on its own.
    _mp4_url, ply_url = (
        'https://replicate.delivery/pbxt/RPSTEes37lzAJav3jy1lPuzizm76WGU4IqDcFcAMxhQocjUJA/gradio_output.mp4',
        'https://replicate.delivery/pbxt/2Vy8yrPO3PYiI1YJBxPXAzryR0SC0oyqW3XKPnXiuWHUuRqE/gradio_output.ply',
    )

    # Send the .ply to the splat-to-mesh Gradio Space and print what its
    # "/run" endpoint returns.
    splat_to_mesh = Client(
        "https://dylanebert-splat-to-mesh.hf.space/",
        upload_files=True,
        download_files=True,
    )
    print(splat_to_mesh.predict(ply_url, api_name="/run"))