Spaces:
Runtime error
Runtime error
Create worker_runpod.py
Browse files- worker_runpod.py +187 -0
worker_runpod.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, json, requests, random, runpod
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
from sam2.build_sam import build_sam2_video_predictor
|
7 |
+
import shutil
|
8 |
+
import subprocess
|
9 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
10 |
+
from torchvision.transforms import functional as F
|
11 |
+
|
12 |
+
def detect_body_keypoints(frame):
    """Pick five prompt points on the main subject of ``frame``.

    The frame is converted from BGR to RGB and run through the module-level
    ``body_detector``. The highest-scoring *person* detection is preferred;
    if the detector found objects but no person, the highest-scoring box of
    any class is used instead. The returned points are the box centre plus
    four points offset by 20% of the box width/height to the left, right,
    top and bottom of the centre.

    Args:
        frame: Image array whose first two dimensions are (height, width),
            in BGR channel order.

    Returns:
        A ``(5, 2)`` float32 array of ``(x, y)`` pixel coordinates. When
        nothing at all is detected, all five rows are the frame centre.
    """
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img_tensor = F.to_tensor(frame_rgb).unsqueeze(0).to('cuda')
    with torch.no_grad():
        prediction = body_detector(img_tensor)[0]

    boxes = prediction['boxes']
    scores = prediction['scores']
    if len(boxes) == 0:
        # Nothing detected: fall back to five copies of the frame centre.
        height, width = frame.shape[:2]
        center = np.array([[width // 2, height // 2]], dtype=np.float32)
        return np.tile(center, (5, 1))

    # The detector scores every COCO category, so the global arg-max can be
    # a chair or a car. In torchvision's COCO label map, 1 is "person".
    person_indices = torch.nonzero(prediction['labels'] == 1, as_tuple=True)[0]
    if len(person_indices) > 0:
        best_index = person_indices[scores[person_indices].argmax()]
    else:
        best_index = scores.argmax()

    x1, y1, x2, y2 = boxes[best_index].cpu().numpy()
    center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
    width, height = x2 - x1, y2 - y1
    offset_x, offset_y = width * 0.2, height * 0.2
    keypoints = np.array([
        [center_x, center_y],
        [center_x - offset_x, center_y],
        [center_x + offset_x, center_y],
        [center_x, center_y - offset_y],
        [center_x, center_y + offset_y],
    ], dtype=np.float32)
    # Keep every point inside the detection box.
    keypoints[:, 0] = np.clip(keypoints[:, 0], x1, x2)
    keypoints[:, 1] = np.clip(keypoints[:, 1], y1, y2)
    return keypoints
|
37 |
+
|
38 |
+
def remove_background(frame, mask, bg_color):
    """Composite the masked part of ``frame`` over a solid ``bg_color``.

    Args:
        frame: Image array; its first two dimensions are (height, width).
        mask: Foreground mask. Boolean masks are used as-is; any other dtype
            is thresholded at ``> 0``. Singleton dimensions are squeezed out
            and the mask is resized to the frame with nearest-neighbour
            interpolation.
        bg_color: Fill value broadcast over the background (one value per
            channel of ``frame``).

    Returns:
        The composited image after ``clean_hair_area`` post-processing.
    """
    binary = mask.squeeze()
    if binary.dtype != bool:
        # Non-boolean masks count as foreground wherever they are positive.
        binary = binary > 0
    binary = binary.astype(np.uint8) * 255

    frame_h, frame_w = frame.shape[:2]
    binary = cv2.resize(binary, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
    inverse = cv2.bitwise_not(binary)

    # Foreground keeps the frame's pixels; everything else takes bg_color.
    foreground = cv2.bitwise_and(frame, frame, mask=binary)
    backdrop = np.full(frame.shape, bg_color, dtype=np.uint8)
    backdrop = cv2.bitwise_and(backdrop, backdrop, mask=inverse)
    composite = cv2.add(foreground, backdrop)

    return clean_hair_area(frame, composite, binary, bg_color)
|
51 |
+
|
52 |
+
def clean_hair_area(original, processed, mask, bg_color):
    """Blend ``bg_color`` into the thin band just outside ``mask``.

    The band is the difference between ``mask`` dilated twice with a 5x5
    kernel and ``mask`` itself. Inside that band, pixels of ``original``
    whose colour is close to the average colour of the original background
    (everything outside the dilated mask) are pushed towards ``bg_color``
    in ``processed``; pixels far from that average are left mostly alone.

    Args:
        original: The unmodified source image.
        processed: The composited image; its first three channels are
            modified in place.
        mask: uint8 foreground mask (0 or 255) with the same height and
            width as the images.
        bg_color: Replacement colour, indexable per channel.

    Returns:
        ``processed`` (the same array object that was passed in).
    """
    kernel = np.ones((5, 5), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=2)
    hair_edge_mask = cv2.subtract(dilated_mask, mask)

    # Average only over true background pixels. Taking the mean of an image
    # whose foreground was zeroed out would count those zeros too and drag
    # the estimate towards black.
    background_only = cv2.bitwise_not(dilated_mask)
    bg_average = np.array(cv2.mean(original, mask=background_only)[:3], dtype=np.float32)

    color_distances = np.sqrt(np.sum((original.astype(np.float32) - bg_average) ** 2, axis=2))
    distance_min = color_distances.min()
    distance_range = color_distances.max() - distance_min
    if distance_range > 0:
        color_distances = (color_distances - distance_min) / distance_range
    else:
        # Every pixel is equally far from the background average (e.g. a
        # uniformly coloured frame); avoid dividing by zero and treat them
        # all as "background-like".
        color_distances = np.zeros_like(color_distances)

    alpha = np.clip((1 - color_distances) * (hair_edge_mask / 255.0), 0, 1)
    alpha = alpha[:, :, np.newaxis]
    fill = np.asarray(bg_color[:3], dtype=np.float64)
    processed[:, :, :3] = processed[:, :, :3] * (1 - alpha) + fill * alpha
    return processed
|
65 |
+
|
66 |
+
# Build both models once at module import; the functions in this file read
# `predictor` and `body_detector` as module-level globals.
with torch.inference_mode():
    checkpoint = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
    # SAM 2 video predictor used to propagate the subject mask across frames.
    predictor = build_sam2_video_predictor(model_cfg, checkpoint)
    # Faster R-CNN detector, switched to eval mode and moved to the GPU.
    body_detector = fasterrcnn_resnet50_fpn(pretrained=True).eval().to("cuda")
|
73 |
+
|
74 |
+
def download_file(url, save_dir, timeout=60):
    """Download ``url`` into ``save_dir`` and return the local file path.

    The file name is the last component of the URL *path*, so a query string
    or fragment (for example a signed-URL token) does not end up in the name.
    If the path has no final component, ``download`` is used.

    Args:
        url: HTTP(S) URL to fetch.
        save_dir: Directory to write into; created if it does not exist.
        timeout: Seconds to wait for the connection and between received
            bytes before giving up.

    Returns:
        Path of the written file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    from urllib.parse import urlparse

    os.makedirs(save_dir, exist_ok=True)
    file_name = os.path.basename(urlparse(url).path) or 'download'
    file_path = os.path.join(save_dir, file_name)

    # Stream to disk in chunks so a large video is never held in memory whole.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    return file_path
|
83 |
+
|
84 |
+
def _hex_to_bgr(hex_color):
    """Convert an ``RRGGBB`` hex string (leading ``#`` optional) to a ``(B, G, R)`` tuple."""
    digits = hex_color.lstrip('#')
    red, green, blue = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return (blue, green, red)


def _pop_or_env(values, key):
    """Pop ``key`` from ``values``; a value equal to the key name selects the ``com_camenduru_<key>`` env var."""
    value = values.pop(key)
    if value == key:
        value = os.getenv(f'com_camenduru_{key}')
    return value


def _post_job_status(payload, notify_uri, notify_token):
    """POST ``payload`` to the env-configured webhook and, unless ``notify_uri`` is the placeholder, to ``notify_uri`` too."""
    body = json.dumps(payload)
    web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
    web_notify_token = os.getenv('com_camenduru_web_notify_token')
    requests.post(web_notify_uri, data=body, headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
    if notify_uri != "notify_uri":
        requests.post(notify_uri, data=body, headers={'Content-Type': 'application/json', "Authorization": notify_token})


def _discard(path):
    """Remove a file or directory tree without ever raising."""
    if os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)
        return
    try:
        os.remove(path)
    except OSError:
        # Already gone or not removable; cleanup must not mask the job result.
        pass


@torch.inference_mode()
def generate(input):
    """RunPod handler: replace the background of a video with a solid colour.

    Reads these keys from ``input["input"]``: ``input_video`` (URL),
    ``bg_color`` (``RRGGBB`` hex, ``#`` optional), ``notify_uri``,
    ``notify_token``, ``discord_id``, ``discord_channel``, ``discord_token``
    and ``job_id``. The last six are removed from the dict before the
    remaining values are echoed in the Discord message.

    Steps: download the video, split it into JPEG frames with ffmpeg, prompt
    the SAM 2 predictor on frame 0 with points from
    ``detect_body_keypoints``, propagate the mask through the video, write
    background-replaced frames, encode them to H.264, upload the file to a
    Discord channel and post the attachment URL to the notification
    webhook(s).

    NOTE: the output is encoded at a fixed 30 fps from the image sequence
    only, so the source frame rate and audio are not carried over.

    Returns:
        ``{"jobId", "result", "status"}``. ``status`` is ``"DONE"`` with the
        attachment URL on success, or ``"FAILED"`` with the error text when
        the upload/notification stage fails.

    Raises:
        Exceptions from the download / ffmpeg / segmentation stage propagate
        to the caller (after the scratch files have been cleaned up).
    """
    values = input["input"]

    # Bound before any work starts so the failure path can always build its
    # payload, even when one of these keys is missing from the request.
    job_id = values.get('job_id')
    notify_uri = values.get('notify_uri', 'notify_uri')
    notify_token = values.get('notify_token')

    frames_dir = "/content/frames"
    output_frames_dir = '/content/output_frames'
    result = "/content/sam2_rm_bg_tost.mp4"

    try:
        input_video = download_file(url=values['input_video'], save_dir='/content')
        bg_color = _hex_to_bgr(values['bg_color'])

        # Start from empty scratch directories so frames left over from an
        # earlier job can never leak into this video.
        for directory in (frames_dir, output_frames_dir):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.makedirs(directory, exist_ok=True)

        ffmpeg_cmd = ["ffmpeg", "-i", str(input_video), "-q:v", "2", "-start_number", "0", f"{frames_dir}/%05d.jpg"]
        subprocess.run(ffmpeg_cmd, capture_output=True, text=True, check=True)
        frame_names = [p for p in os.listdir(frames_dir) if p.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))]
        frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

        inference_state = predictor.init_state(video_path=frames_dir)
        first_frame = cv2.imread(os.path.join(frames_dir, frame_names[0]))
        keypoints = detect_body_keypoints(first_frame)
        # All keypoints are positive prompts (label 1) for object id 1.
        predictor.add_new_points(inference_state=inference_state, frame_idx=0, obj_id=1, points=keypoints, labels=np.ones(len(keypoints), dtype=np.int32))

        video_segments = {}
        for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
            video_segments[out_frame_idx] = {
                out_obj_id: out_mask_logits[i].cpu().numpy()
                for i, out_obj_id in enumerate(out_obj_ids)
            }

        for out_frame_idx, frame_name in enumerate(frame_names):
            frame = cv2.imread(os.path.join(frames_dir, frame_name))
            output_frame_path = os.path.join(output_frames_dir, f"{out_frame_idx:05d}.jpg")
            for out_mask in video_segments[out_frame_idx].values():
                frame_with_bg_removed = remove_background(frame, out_mask, bg_color)
                cv2.imwrite(output_frame_path, frame_with_bg_removed)

        final_video_cmd = ["ffmpeg", "-y", "-framerate", "30", "-i", f"{output_frames_dir}/%05d.jpg", "-c:v", "libx264", "-pix_fmt", "yuv420p", result]
        subprocess.run(final_video_cmd, capture_output=True, text=True, check=True)

        try:
            notify_uri = values.pop('notify_uri')
            notify_token = values.pop('notify_token')
            discord_id = _pop_or_env(values, 'discord_id')
            discord_channel = _pop_or_env(values, 'discord_channel')
            discord_token = _pop_or_env(values, 'discord_token')
            job_id = values.pop('job_id')

            default_filename = os.path.basename(result)
            with open(result, "rb") as file:
                files = {default_filename: file.read()}
            payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
            response = requests.post(
                f"https://discord.com/api/v9/channels/{discord_channel}/messages",
                data=payload,
                headers={"Authorization": f"Bot {discord_token}"},
                files=files
            )
            response.raise_for_status()
            result_url = response.json()['attachments'][0]['url']

            notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
            _post_job_status(notify_payload, notify_uri, notify_token)
            return {"jobId": job_id, "result": result_url, "status": "DONE"}
        except Exception as e:
            error_payload = {"jobId": job_id, "status": "FAILED"}
            try:
                _post_job_status(error_payload, notify_uri, notify_token)
            except Exception:
                # Failure notification is best effort; the FAILED result
                # returned below still reaches the caller.
                pass
            return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        # The frame locations are directories, so they need a tree removal;
        # os.remove() on them raises and would replace the return value.
        for path in (result, output_frames_dir, frames_dir):
            _discard(path)
|
186 |
+
|
187 |
+
# Hand `generate` to the RunPod serverless runtime as the job handler.
runpod.serverless.start({"handler": generate})
|