# sam2-remove-video-background / worker_runpod.py
# camenduru's picture
# Update worker_runpod.py
# 8c65d88 verified
# raw
# history blame contribute delete
# No virus
# 9.06 kB
import os, json, requests, random, runpod
import torch
import numpy as np
import cv2
from sam2.build_sam import build_sam2_video_predictor
import shutil
import subprocess
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.transforms import functional as F
def detect_body_keypoints(frame):
    """Choose five prompt points on the most likely person in ``frame``.

    The module-level ``body_detector`` is run on the frame. The best bounding
    box yields a plus-shaped set of points: the box centre plus four points
    offset by 20% of the box width/height, all clipped to the box.

    Args:
        frame: BGR image array (it is converted to RGB before detection).

    Returns:
        ``np.float32`` array of shape ``(5, 2)`` holding ``(x, y)`` pixel
        coordinates. When nothing is detected, all five points are the frame
        centre.
    """
    # COCO category id of "person" in torchvision's detection models.
    person_label = 1

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Follow the detector's device instead of assuming a fixed one.
    device = next(body_detector.parameters()).device
    img_tensor = F.to_tensor(frame_rgb).unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = body_detector(img_tensor)[0]

    boxes = prediction['boxes']
    scores = prediction['scores']
    if len(boxes) == 0:
        height, width = frame.shape[:2]
        center = np.array([[width // 2, height // 2]], dtype=np.float32)
        return np.tile(center, (5, 1))

    # The detector reports every COCO class; the highest score overall may be
    # furniture or a pet. Prefer person detections and only fall back to the
    # overall best box when no person was found.
    is_person = prediction['labels'] == person_label
    if is_person.any():
        scores = scores.masked_fill(~is_person, float('-inf'))
    best_box = boxes[scores.argmax()].cpu().numpy()

    x1, y1, x2, y2 = best_box
    center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
    width, height = x2 - x1, y2 - y1
    offset_x, offset_y = width * 0.2, height * 0.2
    keypoints = np.array([
        [center_x, center_y],
        [center_x - offset_x, center_y],
        [center_x + offset_x, center_y],
        [center_x, center_y - offset_y],
        [center_x, center_y + offset_y],
    ], dtype=np.float32)
    keypoints[:, 0] = np.clip(keypoints[:, 0], x1, x2)
    keypoints[:, 1] = np.clip(keypoints[:, 1], y1, y2)
    return keypoints
def remove_background(frame, mask, bg_color):
    """Composite the masked subject of ``frame`` over a solid ``bg_color``.

    Args:
        frame: image array whose first two dimensions are (height, width).
        mask: array that squeezes to 2-D. Boolean masks are used directly;
            any other dtype counts as foreground where the value is > 0.
        bg_color: fill value for the background, in the same channel order
            as ``frame``.

    Returns:
        The composited image after ``clean_hair_area`` has refined the edge.
    """
    squeezed = mask.squeeze()
    foreground = squeezed if squeezed.dtype == bool else squeezed > 0
    binary_mask = foreground.astype(np.uint8) * 255

    # Bring the mask to the frame's resolution without creating grey values.
    rows, cols = frame.shape[:2]
    binary_mask = cv2.resize(binary_mask, (cols, rows), interpolation=cv2.INTER_NEAREST)
    inverse_mask = cv2.bitwise_not(binary_mask)

    subject = cv2.bitwise_and(frame, frame, mask=binary_mask)
    backdrop = np.full(frame.shape, bg_color, dtype=np.uint8)
    backdrop = cv2.bitwise_and(backdrop, backdrop, mask=inverse_mask)

    composite = cv2.add(subject, backdrop)
    return clean_hair_area(frame, composite, binary_mask, bg_color)
def clean_hair_area(original, processed, mask, bg_color):
    """Blend a thin band just outside ``mask`` towards ``bg_color``.

    The band is the difference between the mask dilated twice with a 5x5
    kernel and the mask itself. Inside it, pixels of ``original`` whose
    colour is close to the average background colour are pushed towards
    ``bg_color``; pixels far from it are left mostly untouched.

    Args:
        original: source image, shape (H, W, 3).
        processed: composited image of the same shape; modified in place.
        mask: single-channel ``uint8`` mask (0 or 255) of shape (H, W).
        bg_color: three-component replacement colour.

    Returns:
        ``processed`` (the same array object) after blending.
    """
    kernel = np.ones((5, 5), np.uint8)
    dilated_mask = cv2.dilate(mask, kernel, iterations=2)
    hair_edge_mask = cv2.subtract(dilated_mask, mask)

    # Average only over the pixels outside the dilated mask. Zeroing the
    # subject and averaging the whole image would divide by every pixel and
    # drag the estimate towards black.
    background_region = cv2.bitwise_not(dilated_mask)
    bg_average = cv2.mean(original, mask=background_region)[:3]

    color_distances = np.sqrt(np.sum((original.astype(np.float32) - bg_average) ** 2, axis=2))
    lowest = color_distances.min()
    spread = color_distances.max() - lowest
    if spread > 0:
        color_distances = (color_distances - lowest) / spread
    else:
        # Every pixel is equally far from the background colour; normalising
        # would divide by zero and fill alpha with NaN.
        color_distances = np.zeros_like(color_distances)

    alpha = np.clip((1 - color_distances) * (hair_edge_mask / 255.0), 0, 1)
    alpha = alpha[:, :, np.newaxis]
    replacement = np.asarray(bg_color[:3], dtype=np.float64)
    processed[:, :, :3] = processed[:, :, :3] * (1 - alpha) + replacement * alpha
    return processed
# Both models are built once when the module is loaded and then shared by
# every call: ``predictor`` is used by ``generate`` and ``body_detector`` by
# ``detect_body_keypoints``.
with torch.inference_mode():
    checkpoint = 'sam2_hiera_large.pt'
    model_cfg = 'sam2_hiera_l.yaml'
    predictor = build_sam2_video_predictor(model_cfg, checkpoint)
    # ``eval()`` and ``to()`` both return the module, so the calls chain.
    body_detector = fasterrcnn_resnet50_fpn(pretrained=True).eval().to("cuda")
def download_file(url, save_dir):
    """Download ``url`` into ``save_dir`` and return the saved file's path.

    The response is streamed to disk in chunks, so large files are never held
    in memory in full. The file is named after the last component of the URL
    path, or ``download`` when that component is empty.

    Args:
        url: HTTP(S) address of the file.
        save_dir: target directory; created if missing.

    Returns:
        Path of the downloaded file inside ``save_dir``.

    Raises:
        requests.RequestException: on connection errors, timeouts or a
            non-success HTTP status.
    """
    from urllib.parse import urlsplit

    os.makedirs(save_dir, exist_ok=True)
    # Use only the path part so a query string or fragment does not end up
    # in the file name.
    file_name = os.path.basename(urlsplit(url).path) or 'download'
    file_path = os.path.join(save_dir, file_name)

    # (connect, read) timeouts: a stalled server must not hang the worker.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        try:
            with open(file_path, 'wb') as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    file.write(chunk)
        except Exception:
            # Do not leave a truncated file behind for later steps to pick up.
            if os.path.exists(file_path):
                os.remove(file_path)
            raise
    return file_path
def _hex_to_bgr(hex_color):
    """Convert an ``RRGGBB`` hex string (optional leading ``#``) to a (B, G, R) tuple."""
    digits = hex_color.lstrip('#')
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))[::-1]


def _reset_dir(path):
    """Delete ``path`` if it exists and recreate it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def _render_video(input_video, bg_color, frames_dir, output_frames_dir, output_video_path):
    """Extract frames, segment the subject, and encode the background-replaced video."""
    # Start from empty directories so frames of an earlier, interrupted job
    # can never leak into this one.
    _reset_dir(frames_dir)
    _reset_dir(output_frames_dir)

    ffmpeg_cmd = ["ffmpeg", "-i", str(input_video), "-q:v", "2", "-start_number", "0", f"{frames_dir}/%05d.jpg"]
    subprocess.run(ffmpeg_cmd, capture_output=True, text=True, check=True)
    frame_names = [p for p in os.listdir(frames_dir) if p.endswith(('.jpg', '.jpeg', '.JPG', '.JPEG'))]
    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
    if not frame_names:
        raise RuntimeError(f"ffmpeg extracted no frames from {input_video}")

    inference_state = predictor.init_state(video_path=frames_dir)
    first_frame = cv2.imread(os.path.join(frames_dir, frame_names[0]))
    keypoints = detect_body_keypoints(first_frame)
    predictor.add_new_points(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=keypoints,
        labels=np.ones(len(keypoints), dtype=np.int32),
    )

    # Render each frame as soon as its mask arrives instead of keeping the
    # logits of the whole video in memory first.
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
        frame = cv2.imread(os.path.join(frames_dir, frame_names[out_frame_idx]))
        output_frame_path = os.path.join(output_frames_dir, f"{out_frame_idx:05d}.jpg")
        for i, _out_obj_id in enumerate(out_obj_ids):
            # Logits above zero are foreground; pass a boolean mask on.
            out_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
            frame_with_bg_removed = remove_background(frame, out_mask, bg_color)
            cv2.imwrite(output_frame_path, frame_with_bg_removed)

    # NOTE(review): the output is always encoded at 30 fps and has no audio
    # input, so sources at other frame rates change speed -- confirm intended.
    final_video_cmd = ["ffmpeg", "-y", "-framerate", "30", "-i", f"{output_frames_dir}/%05d.jpg", "-c:v", "libx264", "-pix_fmt", "yuv420p", output_video_path]
    subprocess.run(final_video_cmd, capture_output=True, text=True, check=True)


def _upload_to_discord(video_path, message, discord_channel, discord_token):
    """Post the video to a Discord channel and return the attachment URL."""
    default_filename = os.path.basename(video_path)
    with open(video_path, "rb") as file:
        files = {default_filename: file.read()}
    response = requests.post(
        f"https://discord.com/api/v9/channels/{discord_channel}/messages",
        data={"content": message},
        headers={"Authorization": f"Bot {discord_token}"},
        files=files,
        timeout=(10, 300),
    )
    response.raise_for_status()
    return response.json()['attachments'][0]['url']


def _send_notifications(payload, notify_uri, notify_token):
    """POST ``payload`` to the web hook from the environment and, if given, the caller's hook."""
    web_notify_uri = os.getenv('com_camenduru_web_notify_uri')
    web_notify_token = os.getenv('com_camenduru_web_notify_token')
    targets = [(web_notify_uri, web_notify_token)]
    # The literal "notify_uri" is the placeholder for "no custom hook".
    if notify_uri != "notify_uri":
        targets.append((notify_uri, notify_token))
    for uri, token in targets:
        requests.post(
            uri,
            data=json.dumps(payload),
            headers={'Content-Type': 'application/json', "Authorization": token},
            timeout=60,
        )


def _cleanup(file_paths, dir_paths):
    """Remove temporary files and directories, ignoring ones that are already gone."""
    for path in file_paths:
        if path and os.path.exists(path):
            os.remove(path)
    for path in dir_paths:
        shutil.rmtree(path, ignore_errors=True)


@torch.inference_mode()
def generate(input):
    """Handle one job: replace the background of a video and publish the result.

    ``input["input"]`` must contain ``input_video`` (URL) and ``bg_color``
    (``RRGGBB`` hex, optional ``#``). The routing fields ``notify_uri``,
    ``notify_token``, ``discord_id``, ``discord_channel``, ``discord_token``
    and ``job_id`` are removed from the dict before it is echoed into the
    Discord message. A Discord field equal to its own name is a placeholder
    for the matching ``com_camenduru_discord_*`` environment variable, and a
    missing field is treated like that placeholder.

    Args:
        input: job dictionary with the job parameters under the ``"input"`` key.

    Returns:
        ``{"jobId", "result", "status"}``. ``status`` is ``"DONE"`` with the
        attachment URL as ``result``, or ``"FAILED"`` with an error string
        when uploading or notifying fails.

    Raises:
        Exception: errors raised while downloading, extracting frames,
            segmenting or encoding propagate to the caller. Temporary files
            are removed in every case.
    """
    values = input["input"]

    # Take the routing fields out up front. The failure handler below needs
    # them, and reading them lazily left them unbound on early errors.
    notify_uri = values.pop('notify_uri', 'notify_uri')
    notify_token = values.pop('notify_token', None)
    discord_id = values.pop('discord_id', 'discord_id')
    discord_channel = values.pop('discord_channel', 'discord_channel')
    discord_token = values.pop('discord_token', 'discord_token')
    job_id = values.pop('job_id', None)
    if discord_id == "discord_id":
        discord_id = os.getenv('com_camenduru_discord_id')
    if discord_channel == "discord_channel":
        discord_channel = os.getenv('com_camenduru_discord_channel')
    if discord_token == "discord_token":
        discord_token = os.getenv('com_camenduru_discord_token')

    frames_dir = "/content/frames"
    output_frames_dir = '/content/output_frames'
    result = "/content/sam2_rm_bg_tost.mp4"
    input_video = None
    try:
        input_video = download_file(url=values['input_video'], save_dir='/content')
        bg_color = _hex_to_bgr(values['bg_color'])
        _render_video(input_video, bg_color, frames_dir, output_frames_dir, result)

        try:
            message = f"{json.dumps(values)} <@{discord_id}>"
            result_url = _upload_to_discord(result, message, discord_channel, discord_token)
            notify_payload = {"jobId": job_id, "result": result_url, "status": "DONE"}
            _send_notifications(notify_payload, notify_uri, notify_token)
            return {"jobId": job_id, "result": result_url, "status": "DONE"}
        except Exception as e:
            error_payload = {"jobId": job_id, "status": "FAILED"}
            try:
                _send_notifications(error_payload, notify_uri, notify_token)
            except Exception as notify_error:
                # Best effort only: report it, but still return the original failure.
                print(f"Could not deliver failure notification: {notify_error}")
            return {"jobId": job_id, "result": f"FAILED: {str(e)}", "status": "FAILED"}
    finally:
        _cleanup([result, input_video], [output_frames_dir, frames_dir])
# Hand ``generate`` to RunPod's serverless runtime as the job handler.
runpod.serverless.start({"handler": generate})