camenduru committed on
Commit
8067938
1 Parent(s): 5dc8c78

Create worker_runpod.py

Browse files
Files changed (1) hide show
  1. worker_runpod.py +187 -0
worker_runpod.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, requests, random, runpod
2
+
3
+ import torch
4
+ import numpy as np
5
+ import cv2
6
+ from sam2.build_sam import build_sam2_video_predictor
7
+ import shutil
8
+ import subprocess
9
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
10
+ from torchvision.transforms import functional as F
11
+
12
def detect_body_keypoints(frame):
    """Return five (x, y) prompt points for the top-scoring detection in *frame*.

    The frame is converted from BGR to RGB, run through the module-level
    ``body_detector`` on CUDA, and the box with the highest score is used.
    The points are the box centre plus four points offset from the centre by
    20% of the box width (left/right) and 20% of the box height (up/down).

    Args:
        frame: Image array as returned by ``cv2.imread`` (BGR channel order).

    Returns:
        A ``(5, 2)`` ``float32`` array of points. When the detector returns
        no boxes, all five points are the integer centre of the frame.
    """
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    batch = F.to_tensor(rgb).unsqueeze(0).to('cuda')
    with torch.no_grad():
        detections = body_detector(batch)[0]

    # Nothing detected: fall back to the middle of the frame, repeated five times.
    if len(detections['boxes']) == 0:
        frame_h, frame_w = frame.shape[:2]
        middle = np.array([[frame_w // 2, frame_h // 2]], dtype=np.float32)
        return np.tile(middle, (5, 1))

    top_idx = detections['scores'].argmax()
    x1, y1, x2, y2 = detections['boxes'][top_idx].cpu().numpy()
    cx = (x1 + x2) / 2
    cy = (y1 + y2) / 2
    dx = (x2 - x1) * 0.2
    dy = (y2 - y1) * 0.2
    points = np.array(
        [
            [cx, cy],
            [cx - dx, cy],
            [cx + dx, cy],
            [cx, cy - dy],
            [cx, cy + dy],
        ],
        dtype=np.float32,
    )
    # Keep every point inside the detected box.
    points[:, 0] = np.clip(points[:, 0], x1, x2)
    points[:, 1] = np.clip(points[:, 1], y1, y2)
    return points
37
+
38
def remove_background(frame, mask, bg_color):
    """Composite the masked part of *frame* over a solid *bg_color* background.

    Args:
        frame: Image array; its first two dimensions give the output size.
        mask: Boolean mask, or numeric array where values ``> 0`` count as
            foreground. Singleton dimensions are squeezed away and the mask is
            resized to the frame with nearest-neighbour interpolation.
        bg_color: Fill colour for everything outside the mask, in the same
            channel order as *frame*.

    Returns:
        The composited image after the edge clean-up done by
        ``clean_hair_area``.
    """
    flat_mask = mask.squeeze()
    if flat_mask.dtype != bool:
        flat_mask = flat_mask > 0
    binary_mask = flat_mask.astype(np.uint8) * 255

    frame_h, frame_w = frame.shape[:2]
    binary_mask = cv2.resize(binary_mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
    inverse_mask = cv2.bitwise_not(binary_mask)

    solid_bg = np.full(frame.shape, bg_color, dtype=np.uint8)
    foreground = cv2.bitwise_and(frame, frame, mask=binary_mask)
    background = cv2.bitwise_and(solid_bg, solid_bg, mask=inverse_mask)
    composite = cv2.add(foreground, background)
    return clean_hair_area(frame, composite, binary_mask, bg_color)
51
+
52
def clean_hair_area(original, processed, mask, bg_color):
    """Fade a thin band just outside *mask* toward *bg_color*.

    The band is the difference between the mask dilated twice with a 5x5
    kernel and the mask itself. Inside that band each pixel of *processed* is
    blended toward *bg_color*; the blend weight is higher the closer the
    corresponding pixel of *original* is to the average colour of the area
    outside the dilated mask.

    Args:
        original: Source image the colour distances are measured on.
        processed: Composited image; modified in place.
        mask: Single-channel ``uint8`` mask (0 / 255), as built by
            ``remove_background``.
        bg_color: Three-component colour to blend toward.

    Returns:
        *processed*, after in-place blending.
    """
    kernel = np.ones((5, 5), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=2)
    hair_edge_mask = cv2.subtract(dilated_mask, mask)

    # Average only over real background pixels. Taking the mean of a
    # masked-out copy would count every zeroed foreground pixel as black and
    # drag the estimate toward (0, 0, 0).
    bg_average = cv2.mean(original, mask=cv2.bitwise_not(dilated_mask))[:3]

    color_distances = np.sqrt(np.sum((original.astype(np.float32) - bg_average) ** 2, axis=2))
    lowest = color_distances.min()
    distance_span = color_distances.max() - lowest
    if distance_span > 0:
        color_distances = (color_distances - lowest) / distance_span
    else:
        # Uniform image: every pixel is equally close to the background
        # average, so avoid the 0/0 division that would fill alpha with NaN.
        color_distances = np.zeros_like(color_distances)

    alpha = np.clip((1 - color_distances) * (hair_edge_mask / 255.0), 0, 1)
    alpha = alpha[:, :, np.newaxis]
    target = np.asarray(bg_color[:3], dtype=np.float64)
    # Slice assignment keeps the in-place update (and dtype) of ``processed``.
    processed[:, :, :3] = processed[:, :, :3] * (1 - alpha) + target * alpha
    return processed
65
+
66
# Load both models once at import time so every job reuses them.
# NOTE(review): the scraped source lost its indentation; all statements are
# assumed to belong inside the ``with`` block — confirm against the repo.
with torch.inference_mode():
    checkpoint = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
    # SAM 2 video predictor used for mask propagation across frames.
    predictor = build_sam2_video_predictor(model_cfg, checkpoint)
    # Detector used to seed the prompt points; switched to eval mode and
    # moved to the GPU (both calls return the module itself).
    body_detector = fasterrcnn_resnet50_fpn(pretrained=True).eval().to("cuda")
73
+
74
def download_file(url, save_dir):
    """Download *url* into *save_dir* and return the path of the saved file.

    The file name is the last component of the URL path, so a query string or
    fragment never ends up in the name. The body is streamed to disk in
    chunks instead of being held in memory as a whole.

    Args:
        url: HTTP(S) address of the file.
        save_dir: Target directory; created when missing.

    Returns:
        Path of the downloaded file inside *save_dir*.

    Raises:
        requests.HTTPError: When the server answers with an error status.
        requests.RequestException: On connection problems or timeouts.
    """
    from urllib.parse import urlparse

    os.makedirs(save_dir, exist_ok=True)
    # Fall back to a fixed name when the URL path ends with a slash.
    file_name = os.path.basename(urlparse(url).path) or 'download'
    file_path = os.path.join(save_dir, file_name)
    # (connect timeout, read timeout between chunks) so a dead server cannot
    # block the worker forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                file.write(chunk)
    return file_path
83
+
84
def _hex_to_bgr(hex_color):
    """Convert an ``RRGGBB`` hex string (optional leading ``#``) to a BGR int tuple."""
    digits = hex_color.lstrip('#')
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))[::-1]


def _pop_with_env_fallback(values, key, env_name):
    """Pop *key* from *values*; a value equal to the key name means "use env var"."""
    value = values.pop(key)
    if value == key:
        value = os.getenv(env_name)
    return value


def _remove_path(path):
    """Delete *path* whether it is a file or a directory tree; ignore if absent."""
    if os.path.isdir(path):
        shutil.rmtree(path, ignore_errors=True)
    elif os.path.exists(path):
        os.remove(path)


def _send_notifications(payload, notify_uri, notify_token):
    """POST *payload* to the web notify endpoint and, if set, the caller's one."""
    body = json.dumps(payload)
    web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
    web_notify_token = os.getenv('com_camenduru_web_notify_token')
    requests.post(web_notify_uri, data=body, headers={'Content-Type': 'application/json', "Authorization": web_notify_token})
    # The literal "notify_uri" is the placeholder for "no custom endpoint".
    if notify_uri != "notify_uri":
        requests.post(notify_uri, data=body, headers={'Content-Type': 'application/json', "Authorization": notify_token})


def _render_video(input_video, bg_color, frames_dir, output_frames_dir, output_video_path):
    """Split the video into frames, replace the background, and re-encode it."""
    # Start from empty working directories so frames left over from an
    # earlier job can never be encoded into this job's output.
    for directory in (frames_dir, output_frames_dir):
        if os.path.exists(directory):
            shutil.rmtree(directory)
        os.makedirs(directory, exist_ok=True)

    ffmpeg_cmd = ["ffmpeg", "-i", str(input_video), "-q:v", "2", "-start_number", "0", f"{frames_dir}/%05d.jpg"]
    subprocess.run(ffmpeg_cmd, capture_output=True, text=True, check=True)
    frame_names = [p for p in os.listdir(frames_dir) if p.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))

    # Seed the predictor with points found on the first frame, then propagate.
    inference_state = predictor.init_state(video_path=frames_dir)
    first_frame = cv2.imread(os.path.join(frames_dir, frame_names[0]))
    keypoints = detect_body_keypoints(first_frame)
    predictor.add_new_points(inference_state=inference_state, frame_idx=0, obj_id=1, points=keypoints, labels=np.ones(len(keypoints), dtype=np.int32))

    video_segments = {}
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        video_segments[out_frame_idx] = {
            out_obj_id: out_mask_logits[i].cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }

    for frame_idx, frame_name in enumerate(frame_names):
        frame = cv2.imread(os.path.join(frames_dir, frame_name))
        output_frame_path = os.path.join(output_frames_dir, f"{frame_idx:05d}.jpg")
        for out_mask in video_segments[frame_idx].values():
            cv2.imwrite(output_frame_path, remove_background(frame, out_mask, bg_color))

    final_video_cmd = ["ffmpeg", "-y", "-framerate", "30", "-i", f"{output_frames_dir}/%05d.jpg", "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video_path]
    subprocess.run(final_video_cmd, capture_output=True, text=True, check=True)


@torch.inference_mode()
def generate(input):
    """Handle one job: replace a video's background and publish the result.

    ``input["input"]`` is read for these keys: ``input_video`` (URL),
    ``bg_color`` (``RRGGBB`` hex, optional ``#``), ``notify_uri``,
    ``notify_token``, ``discord_id``, ``discord_channel``, ``discord_token``
    and ``job_id``. A Discord field whose value equals its own key name is
    replaced by the matching ``com_camenduru_discord_*`` environment variable.

    The rendered video is uploaded as a Discord attachment, the attachment
    URL is posted to the notify endpoint(s), and all working files are
    removed afterwards.

    Args:
        input: Job dictionary holding the parameters under the ``"input"`` key.

    Returns:
        ``{"jobId", "result", "status"}`` with status ``"DONE"`` and the
        attachment URL, or status ``"FAILED"`` and the error text when
        uploading or notifying fails.

    Raises:
        Exception: Errors raised while downloading or rendering the video
            (before the upload step) propagate to the caller.
    """
    values = input["input"]
    frames_dir = "/content/frames"
    output_frames_dir = '/content/output_frames'
    result = "/content/sam2_rm_bg_tost.mp4"

    # Bind these up front so the failure path can always build and send its
    # report, even when a required key turns out to be missing later.
    job_id = values.get('job_id')
    notify_uri = values.get('notify_uri', "notify_uri")
    notify_token = values.get('notify_token')

    try:
        input_video = download_file(url=values['input_video'], save_dir='/content')
        bg_color = _hex_to_bgr(values['bg_color'])
        _render_video(input_video, bg_color, frames_dir, output_frames_dir, result)

        try:
            # Credentials are popped so they are not echoed into the Discord
            # message built from ``values`` below.
            notify_uri = values.pop('notify_uri')
            notify_token = values.pop('notify_token')
            discord_id = _pop_with_env_fallback(values, 'discord_id', 'com_camenduru_discord_id')
            discord_channel = _pop_with_env_fallback(values, 'discord_channel', 'com_camenduru_discord_channel')
            discord_token = _pop_with_env_fallback(values, 'discord_token', 'com_camenduru_discord_token')
            job_id = values.pop('job_id')

            default_filename = os.path.basename(result)
            with open(result, "rb") as file:
                files = {default_filename: file.read()}
            payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
            response = requests.post(
                f"https://discord.com/api/v9/channels/{discord_channel}/messages",
                data=payload,
                headers={"Authorization": f"Bot {discord_token}"},
                files=files
            )
            response.raise_for_status()
            result_url = response.json()['attachments'][0]['url']

            _send_notifications({"jobId": job_id, "result": result_url, "status": "DONE"}, notify_uri, notify_token)
            return {"jobId": job_id, "result": result_url, "status": "DONE"}
        except Exception as e:
            try:
                _send_notifications({"jobId": job_id, "status": "FAILED"}, notify_uri, notify_token)
            except Exception:
                # Best effort only: a broken notify endpoint must not hide
                # the original error returned below.
                pass
            return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        # The frame locations are directories; os.remove() on them raises and
        # would replace the return value with an exception.
        _remove_path(result)
        _remove_path(output_frames_dir)
        _remove_path(frames_dir)
187
# Hand ``generate`` to the RunPod serverless runtime as the job handler.
serverless_config = {"handler": generate}
runpod.serverless.start(serverless_config)