# Spaces:
# Runtime error
# Runtime error
import json
import logging
import os
import random
import shutil
import subprocess
from urllib.parse import urlsplit

import cv2
import numpy as np
import requests
import runpod
import torch
from sam2.build_sam import build_sam2_video_predictor
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
def detect_body_keypoints(frame):
    """Build five SAM2 point prompts from the most confident person detection.

    Args:
        frame: BGR image (as returned by ``cv2.imread``).

    Returns:
        ``float32`` array of shape (5, 2) holding (x, y) pixel coordinates: the
        centre of the best-scoring person box plus four points offset from it by
        20% of the box width (left/right) and height (up/down), clipped to the
        box. If no person is detected, the frame centre repeated five times.
    """
    # COCO category id of "person" in torchvision's detection models.
    person_label = 1
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Follow the detector's device instead of hard-coding 'cuda'.
    device = next(body_detector.parameters()).device
    img_tensor = F.to_tensor(frame_rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = body_detector(img_tensor)[0]
    # The detector scores every COCO class; without this filter the "body"
    # prompt could land on the most confident chair, car, etc.
    is_person = prediction['labels'] == person_label
    person_scores = prediction['scores'][is_person]
    if len(person_scores) == 0:
        height, width = frame.shape[:2]
        center = np.array([[width // 2, height // 2]], dtype=np.float32)
        return np.tile(center, (5, 1))
    best_box = prediction['boxes'][is_person][person_scores.argmax()].cpu().numpy()
    x1, y1, x2, y2 = best_box
    center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
    offset_x, offset_y = (x2 - x1) * 0.2, (y2 - y1) * 0.2
    keypoints = np.array([
        [center_x, center_y],
        [center_x - offset_x, center_y],
        [center_x + offset_x, center_y],
        [center_x, center_y - offset_y],
        [center_x, center_y + offset_y],
    ], dtype=np.float32)
    keypoints[:, 0] = np.clip(keypoints[:, 0], x1, x2)
    keypoints[:, 1] = np.clip(keypoints[:, 1], y1, y2)
    return keypoints
def remove_background(frame, mask, bg_color):
    """Replace everything outside *mask* in *frame* with a flat *bg_color*.

    Args:
        frame: BGR uint8 image.
        mask: boolean array, or numeric array where values > 0 mark the
            foreground; singleton axes are squeezed and it is resized
            (nearest-neighbour) to the frame size.
        bg_color: (B, G, R) fill colour.

    Returns:
        The composited frame after the ``clean_hair_area`` edge pass.
    """
    height, width = frame.shape[:2]
    # For a bool mask `> 0` is the mask itself, so one expression covers both cases.
    binary = (np.squeeze(mask) > 0).astype(np.uint8) * 255
    binary = cv2.resize(binary, (width, height), interpolation=cv2.INTER_NEAREST)
    inverse = cv2.bitwise_not(binary)
    backdrop = np.full(frame.shape, bg_color, dtype=np.uint8)
    # Foreground and background regions are disjoint, so adding them composites.
    composed = cv2.add(
        cv2.bitwise_and(frame, frame, mask=binary),
        cv2.bitwise_and(backdrop, backdrop, mask=inverse),
    )
    return clean_hair_area(frame, composed, binary, bg_color)
def clean_hair_area(original, processed, mask, bg_color):
    """Blend *bg_color* into the band just outside *mask* where the frame looks like background.

    Args:
        original: BGR uint8 source frame.
        processed: frame already composited onto the background; modified in place.
        mask: uint8 foreground mask (0 or 255) with the frames' height/width.
        bg_color: (B, G, R) background colour.

    Returns:
        *processed*, after blending.
    """
    kernel = np.ones((5, 5), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=2)
    # Pixels reached by the dilation but outside the mask itself.
    hair_edge_mask = cv2.subtract(dilated_mask, mask)
    # Average only the real background pixels: passing the mask to cv2.mean
    # avoids averaging in the zeroed-out foreground, which biased it to black.
    bg_average = np.array(
        cv2.mean(original, mask=cv2.bitwise_not(dilated_mask))[:3], dtype=np.float32
    )
    color_distances = np.sqrt(np.sum((original.astype(np.float32) - bg_average) ** 2, axis=2))
    # Normalise to [0, 1]; a flat frame (e.g. a black fade) has zero range and
    # would otherwise yield NaN for every pixel.
    d_min, d_max = color_distances.min(), color_distances.max()
    if d_max > d_min:
        color_distances = (color_distances - d_min) / (d_max - d_min)
    else:
        color_distances = np.zeros_like(color_distances)
    # NOTE(review): when called from remove_background, pixels in this ring are
    # already bg_color in `processed` (they lie outside the mask), so blending
    # towards bg_color there leaves them unchanged. If the aim is to recover
    # fine hair outside the mask, the blend probably needs to start from
    # `original` in the ring — confirm the intended behaviour.
    alpha = np.clip((1 - color_distances) * (hair_edge_mask / 255.0), 0, 1)[..., np.newaxis]
    bg = np.asarray(bg_color[:3], dtype=np.float64)
    blended = processed[..., :3] * (1 - alpha) + bg * alpha
    # Round instead of truncating: truncation turns e.g. 254.9999 into 254.
    processed[..., :3] = np.clip(np.rint(blended), 0, 255)
    return processed
# Load both models once at module import; `generate` and
# `detect_body_keypoints` use these globals for every job.
with torch.inference_mode():
    checkpoint = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
    predictor = build_sam2_video_predictor(model_cfg, checkpoint)
    # `weights="DEFAULT"` selects the same COCO weights as the deprecated
    # `pretrained=True` flag, without torchvision's deprecation warning.
    body_detector = fasterrcnn_resnet50_fpn(weights="DEFAULT")
    body_detector.eval()
    body_detector.to("cuda")
def download_file(url, save_dir, timeout=60, chunk_size=1 << 20):
    """Download *url* into *save_dir* and return the local file path.

    The file name is the last path segment of the URL with any query string or
    fragment removed (``.../clip.mp4?sig=abc`` -> ``clip.mp4``); ``download``
    is used when the URL path has no usable last segment.

    Args:
        url: HTTP(S) URL to fetch.
        save_dir: directory to write into; created if missing.
        timeout: seconds ``requests`` waits to connect and between received bytes.
        chunk_size: bytes per chunk written to disk.

    Returns:
        Path of the downloaded file.

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Use only the path component so signed-URL tokens never leak into the name.
    file_name = os.path.basename(urlsplit(url).path)
    if file_name in ('', '.', '..'):
        file_name = 'download'
    file_path = os.path.join(save_dir, file_name)
    # Stream to disk so large videos are never held in memory all at once.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(file_path, 'wb') as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    return file_path
def _setting_or_env(values, key, env_var):
    """Pop *key* from *values*; if missing or equal to the placeholder *key*, read env var *env_var*."""
    value = values.pop(key, key)
    return os.getenv(env_var) if value == key else value


def _hex_to_bgr(hex_color):
    """Convert a ``#RRGGBB`` string (``#`` optional) into an OpenCV (B, G, R) tuple."""
    digits = hex_color.lstrip('#')
    red, green, blue = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return (blue, green, red)


def _reset_dir(path):
    """Remove directory *path* if it exists and recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _extract_frames(video_path, frames_dir):
    """Decode *video_path* into ``frames_dir/00000.jpg, ...``; return the names in frame order."""
    _reset_dir(frames_dir)
    ffmpeg_cmd = ["ffmpeg", "-i", str(video_path), "-q:v", "2", "-start_number", "0",
                  f"{frames_dir}/%05d.jpg"]
    subprocess.run(ffmpeg_cmd, capture_output=True, text=True, check=True)
    frame_names = [p for p in os.listdir(frames_dir)
                   if p.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))]
    if not frame_names:
        raise RuntimeError(f"ffmpeg extracted no frames from {video_path}")
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    return frame_names


def _segment_video(frames_dir, first_frame_name):
    """Prompt SAM2 on frame 0 with body keypoints and propagate the mask through the video.

    Returns ``{frame_idx: {obj_id: bool mask array}}``. Logits are thresholded
    at 0 right away (``remove_background`` treats ``> 0`` as foreground), so
    each mask takes 1 byte per pixel instead of 4 for float32 logits.
    """
    inference_state = predictor.init_state(video_path=frames_dir)
    first_frame = cv2.imread(os.path.join(frames_dir, first_frame_name))
    keypoints = detect_body_keypoints(first_frame)
    predictor.add_new_points(
        inference_state=inference_state, frame_idx=0, obj_id=1,
        points=keypoints, labels=np.ones(len(keypoints), dtype=np.int32),
    )
    return {
        frame_idx: {obj_id: (mask_logits[i] > 0.0).cpu().numpy()
                    for i, obj_id in enumerate(obj_ids)}
        for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state)
    }


def _render_frames(frames_dir, frame_names, video_segments, bg_color, output_frames_dir):
    """Write every frame, composited onto *bg_color*, as ``output_frames_dir/%05d.jpg``."""
    # Start empty: frames left over from an earlier failed (longer) job would
    # otherwise be swallowed by the encoder's %05d.jpg input pattern.
    _reset_dir(output_frames_dir)
    for frame_idx, frame_name in enumerate(frame_names):
        frame = cv2.imread(os.path.join(frames_dir, frame_name))
        output_frame_path = os.path.join(output_frames_dir, f"{frame_idx:05d}.jpg")
        for mask in video_segments[frame_idx].values():
            cv2.imwrite(output_frame_path, remove_background(frame, mask, bg_color))


def _video_fps(video_path, default=30.0):
    """Return the frame rate OpenCV reports for *video_path*, or *default* when it is unknown."""
    capture = cv2.VideoCapture(str(video_path))
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
    # OpenCV reports 0 (or NaN) when no usable rate is available.
    return fps if fps > 0 else default


def _encode_video(output_frames_dir, output_video_path, fps):
    """Encode the numbered JPEGs in *output_frames_dir* into an H.264 MP4 at *fps*."""
    final_video_cmd = [
        "ffmpeg", "-y", "-framerate", str(fps), "-i", f"{output_frames_dir}/%05d.jpg",
        # libx264 rejects odd widths/heights with yuv420p, so round them down to even.
        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
        "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video_path,
    ]
    subprocess.run(final_video_cmd, capture_output=True, text=True, check=True)


def _upload_to_discord(video_path, values, discord_id, discord_channel, discord_token):
    """Post *video_path* to a Discord channel, echoing *values*; return the attachment URL."""
    with open(video_path, "rb") as file:
        files = {os.path.basename(video_path): file.read()}
    payload = {"content": f"{json.dumps(values)} <@{discord_id}>"}
    response = requests.post(
        f"https://discord.com/api/v9/channels/{discord_channel}/messages",
        data=payload,
        headers={"Authorization": f"Bot {discord_token}"},
        files=files,
        timeout=300,
    )
    response.raise_for_status()
    return response.json()['attachments'][0]['url']


def _post_json(uri, payload, token, timeout=60):
    """POST *payload* as JSON to *uri* with *token* as the Authorization header."""
    requests.post(uri, data=json.dumps(payload),
                  headers={'Content-Type': 'application/json', "Authorization": token},
                  timeout=timeout)


def _notify(payload, notify_uri, notify_token, web_notify_uri, web_notify_token):
    """Always notify the web endpoint; also notify *notify_uri* unless it is the "notify_uri" placeholder."""
    _post_json(web_notify_uri, payload, web_notify_token)
    if notify_uri != "notify_uri":
        _post_json(notify_uri, payload, notify_token)


def generate(input):
    """RunPod handler: replace the background of a video and report the result.

    ``input["input"]`` must provide ``input_video`` (URL) and ``bg_color``
    (``#RRGGBB``). Optional keys ``notify_uri``, ``notify_token``,
    ``discord_id``, ``discord_channel``, ``discord_token`` and ``job_id``
    configure delivery. The Discord keys fall back to the
    ``com_camenduru_*`` environment variables when missing or left as their
    placeholder value (the key name itself).

    Returns:
        ``{"jobId", "result", "status"}``: ``result`` is the Discord attachment
        URL on success, or ``"FAILED: <error>"`` on failure. Processing errors
        are reported this way (and sent to the notify endpoints), not raised.
    """
    values = input["input"]
    # Take delivery settings out of `values` before any processing: the failure
    # path needs them even if processing dies early, and they must not be
    # echoed into the Discord message built from the remaining `values`.
    notify_uri = values.pop('notify_uri', 'notify_uri')
    notify_token = values.pop('notify_token', None)
    discord_id = _setting_or_env(values, 'discord_id', 'com_camenduru_discord_id')
    discord_channel = _setting_or_env(values, 'discord_channel', 'com_camenduru_discord_channel')
    discord_token = _setting_or_env(values, 'discord_token', 'com_camenduru_discord_token')
    job_id = values.pop('job_id', None)
    web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
    web_notify_token = os.getenv('com_camenduru_web_notify_token')

    frames_dir = "/content/frames"
    output_frames_dir = '/content/output_frames'
    output_video_path = '/content/sam2_rm_bg_tost.mp4'
    input_video = None
    try:
        input_video = download_file(url=values['input_video'], save_dir='/content')
        bg_color = _hex_to_bgr(values['bg_color'])
        frame_names = _extract_frames(input_video, frames_dir)
        video_segments = _segment_video(frames_dir, frame_names[0])
        _render_frames(frames_dir, frame_names, video_segments, bg_color, output_frames_dir)
        _encode_video(output_frames_dir, output_video_path, _video_fps(input_video))
        result_url = _upload_to_discord(output_video_path, values, discord_id,
                                        discord_channel, discord_token)
        _notify({"jobId": job_id, "result": result_url, "status": "DONE"},
                notify_uri, notify_token, web_notify_uri, web_notify_token)
        return {"jobId": job_id, "result": result_url, "status": "DONE"}
    except Exception as e:
        logging.getLogger(__name__).exception("job %s failed", job_id)
        try:
            _notify({"jobId": job_id, "status": "FAILED"},
                    notify_uri, notify_token, web_notify_uri, web_notify_token)
        except Exception:
            # Best effort: a broken notifier must not hide the original error.
            logging.getLogger(__name__).exception(
                "could not send FAILED notification for job %s", job_id)
        return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        # Runs on every path so a failed job never leaves files for the next one.
        for path in (output_video_path, input_video):
            if path and os.path.exists(path):
                os.remove(path)
        for directory in (output_frames_dir, frames_dir):
            if os.path.exists(directory):
                shutil.rmtree(directory)
# Register `generate` as the RunPod serverless handler and start the worker.
runpod.serverless.start(dict(handler=generate))