Spaces:
Runtime error
Runtime error
Create worker_runpod.py
Browse files- worker_runpod.py +138 -0
worker_runpod.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json, requests, random, time, runpod
|
2 |
+
from urllib.parse import urlsplit
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
import asyncio
|
9 |
+
import execution
|
10 |
+
import server
|
11 |
+
loop = asyncio.new_event_loop()
|
12 |
+
asyncio.set_event_loop(loop)
|
13 |
+
server_instance = server.PromptServer(loop)
|
14 |
+
execution.PromptQueue(server)
|
15 |
+
|
16 |
+
from nodes import load_custom_node
|
17 |
+
from nodes import NODE_CLASS_MAPPINGS
|
18 |
+
|
19 |
+
load_custom_node("/content/ComfyUI/custom_nodes/ComfyUI-CogVideoXWrapper")
|
20 |
+
load_custom_node("/content/ComfyUI/custom_nodes/ComfyUI-VideoHelperSuite")
|
21 |
+
load_custom_node("/content/ComfyUI/custom_nodes/ComfyUI-KJNodes")
|
22 |
+
|
23 |
+
LoadImage = NODE_CLASS_MAPPINGS["LoadImage"]()
|
24 |
+
ImageResizeKJ = NODE_CLASS_MAPPINGS["ImageResizeKJ"]()
|
25 |
+
CogVideoImageEncode = NODE_CLASS_MAPPINGS["CogVideoImageEncode"]()
|
26 |
+
CogVideoLoraSelect = NODE_CLASS_MAPPINGS["CogVideoLoraSelect"]()
|
27 |
+
DownloadAndLoadCogVideoModel = NODE_CLASS_MAPPINGS["DownloadAndLoadCogVideoModel"]()
|
28 |
+
CogVideoTextEncode = NODE_CLASS_MAPPINGS["CogVideoTextEncode"]()
|
29 |
+
CLIPLoader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
|
30 |
+
CogVideoSampler = NODE_CLASS_MAPPINGS["CogVideoSampler"]()
|
31 |
+
CogVideoDecode = NODE_CLASS_MAPPINGS["CogVideoDecode"]()
|
32 |
+
VHS_VideoCombine = NODE_CLASS_MAPPINGS["VHS_VideoCombine"]()
|
33 |
+
|
34 |
+
with torch.inference_mode():
|
35 |
+
lora = CogVideoLoraSelect.getlorapath("orbit_up_lora_weights.safetensors", 1.0, fuse_lora=True)[0]
|
36 |
+
pipeline = DownloadAndLoadCogVideoModel.loadmodel("THUDM/CogVideoX-5b-I2V", "bf16", fp8_transformer="disabled", compile="disabled", enable_sequential_cpu_offload=False, lora=lora)[0]
|
37 |
+
clip = CLIPLoader.load_clip("t5xxl_fp16.safetensors", type="sd3")[0]
|
38 |
+
|
39 |
+
def download_file(url, save_dir, file_name):
|
40 |
+
os.makedirs(save_dir, exist_ok=True)
|
41 |
+
original_file_name = url.split('/')[-1]
|
42 |
+
_, original_file_extension = os.path.splitext(original_file_name)
|
43 |
+
file_path = os.path.join(save_dir, file_name + original_file_extension)
|
44 |
+
response = requests.get(url)
|
45 |
+
response.raise_for_status()
|
46 |
+
with open(file_path, 'wb') as file:
|
47 |
+
file.write(response.content)
|
48 |
+
return file_path
|
49 |
+
|
50 |
+
@torch.inference_mode()
|
51 |
+
def generate(input):
|
52 |
+
values = input["input"]
|
53 |
+
|
54 |
+
input_image=values['input_image_check']
|
55 |
+
input_image=download_file(url=input_image, save_dir='/content/ComfyUI/input', file_name='input_image')
|
56 |
+
prompt = values['prompt']
|
57 |
+
negative_prompt = values['negative_prompt']
|
58 |
+
seed = values['seed']
|
59 |
+
steps = values['steps']
|
60 |
+
cfg = values['cfg']
|
61 |
+
|
62 |
+
if seed == 0:
|
63 |
+
random.seed(int(time.time()))
|
64 |
+
seed = random.randint(0, 18446744073709551615)
|
65 |
+
|
66 |
+
positive = CogVideoTextEncode.process(clip, prompt, strength=1.0, force_offload=True)[0]
|
67 |
+
negative = CogVideoTextEncode.process(clip, negative_prompt, strength=1.0, force_offload=True)[0]
|
68 |
+
|
69 |
+
image, _ = LoadImage.load_image(input_image)
|
70 |
+
image = ImageResizeKJ.resize(image, width=720, height=480, keep_proportion=False, upscale_method="lanczos", divisible_by=16, crop="center")[0]
|
71 |
+
image_cond_latents = CogVideoImageEncode.encode(pipeline, image, chunk_size=16, enable_tiling=True)[0]
|
72 |
+
samples = CogVideoSampler.process(pipeline, positive, negative, steps, cfg, seed, height=480, width=720, num_frames=49, scheduler="CogVideoXDPMScheduler", denoise_strength=1.0, image_cond_latents=image_cond_latents)
|
73 |
+
frames = CogVideoDecode.decode(samples[0], samples[1], enable_vae_tiling=True, tile_sample_min_height=240, tile_sample_min_width=360, tile_overlap_factor_height=0.2, tile_overlap_factor_width=0.2, auto_tile_size=True)[0]
|
74 |
+
|
75 |
+
out_video = VHS_VideoCombine.combine_video(images=frames, frame_rate=8, loop_count=0, filename_prefix="CogVideoX-I2V", format="video/h264-mp4", save_output=True)
|
76 |
+
source = out_video["result"][0][1][1]
|
77 |
+
destination = f"/content/ComfyUI/output/cogvideox-5b-i2v-dimensionx-{seed}-tost.mp4"
|
78 |
+
shutil.move(source, destination)
|
79 |
+
|
80 |
+
result = f"/content/ComfyUI/output/cogvideox-5b-i2v-dimensionx-{seed}-tost.mp4"
|
81 |
+
try:
|
82 |
+
notify_uri = values['notify_uri']
|
83 |
+
del values['notify_uri']
|
84 |
+
notify_token = values['notify_token']
|
85 |
+
del values['notify_token']
|
86 |
+
discord_id = values['discord_id']
|
87 |
+
del values['discord_id']
|
88 |
+
if(discord_id == "discord_id"):
|
89 |
+
discord_id = os.getenv('com_camenduru_discord_id')
|
90 |
+
discord_channel = values['discord_channel']
|
91 |
+
del values['discord_channel']
|
92 |
+
if(discord_channel == "discord_channel"):
|
93 |
+
discord_channel = os.getenv('com_camenduru_discord_channel')
|
94 |
+
discord_token = values['discord_token']
|
95 |
+
del values['discord_token']
|
96 |
+
if(discord_token == "discord_token"):
|
97 |
+
discord_token = os.getenv('com_camenduru_discord_token')
|
98 |
+
job_id = values['job_id']
|
99 |
+
del values['job_id']
|
100 |
+
default_filename = os.path.basename(result)
|
101 |
+
with open(result, "rb") as file:
|
102 |
+
files = {default_filename: file.read()}
|
103 |
+
payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
|
104 |
+
response = requests.post(
|
105 |
+
f"https://discord.com/api/v9/channels/{discord_channel}/messages",
|
106 |
+
data=payload,
|
107 |
+
headers={"Authorization": f"Bot {discord_token}"},
|
108 |
+
files=files
|
109 |
+
)
|
110 |
+
response.raise_for_status()
|
111 |
+
result_url = response.json()['attachments'][0]['url']
|
112 |
+
notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
|
113 |
+
web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
|
114 |
+
web_notify_token = os.getenv('com_camenduru_web_notify_token')
|
115 |
+
if(notify_uri == "notify_uri"):
|
116 |
+
requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
|
117 |
+
else:
|
118 |
+
requests.post(web_notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
|
119 |
+
requests.post(notify_uri, data=json.dumps(notify_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
|
120 |
+
return {"jobId": job_id, "result": result_url, "status": "DONE"}
|
121 |
+
except Exception as e:
|
122 |
+
error_payload = {"jobId": job_id, "status": "FAILED"}
|
123 |
+
try:
|
124 |
+
if(notify_uri == "notify_uri"):
|
125 |
+
requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
|
126 |
+
else:
|
127 |
+
requests.post(web_notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
|
128 |
+
requests.post(notify_uri, data=json.dumps(error_payload), headers={'Content-Type': 'application/json', "Authorization": notify_token})
|
129 |
+
except:
|
130 |
+
pass
|
131 |
+
return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
|
132 |
+
finally:
|
133 |
+
if os.path.exists(result):
|
134 |
+
os.remove(result)
|
135 |
+
if os.path.exists(input_image):
|
136 |
+
os.remove(input_image)
|
137 |
+
|
138 |
+
runpod.serverless.start({"handler": generate})
|