import torch
from diffusers import StableVideoDiffusionPipeline
from PIL import Image
import numpy as np
import cv2
import rembg
import argparse
import imageio
import os
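
# Generate short videos from a single input image with Stable Video Diffusion
# (img2vid): strip the background with rembg, composite the foreground onto a
# white/black (or original) background, optionally pad to 1024x576, and write
# the resulting videos/frames next to the input file.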


def add_margin(pil_img, top, right, bottom, left, color):
    width, height = pil_img.size
    new_width = width + right + left
    new_height = height + top + bottom
    result = Image.new(pil_img.mode, (new_width, new_height), color)
    result.paste(pil_img, (left, top))
    return result


def resize_image(image, output_size=(1024, 576)):
    # resize to a square of the output height, then pad left/right to the
    # target width using the top-left pixel color
    image = image.resize((output_size[1], output_size[1]))
    pad_size = (output_size[0] - output_size[1]) // 2
    image = add_margin(image, 0, pad_size, 0, pad_size, tuple(np.array(image)[0, 0]))
    return image


def load_image(file, W, H, bg='white'):
    # load image, remove its background with rembg, and composite onto `bg`
    print(f'[INFO] load image from {file}...')
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    bg_remover = rembg.new_session()
    img = rembg.remove(img, session=bg_remover)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0
    input_mask = img[..., 3:]
    if bg == 'white':
        # alpha-composite the foreground over a white background
        input_img = img[..., :3] * input_mask + (1 - input_mask)
    elif bg == 'black':
        input_img = img[..., :3]
    else:
        raise NotImplementedError(f'unsupported background: {bg}')
    # bgr to rgb
    input_img = input_img[..., ::-1].copy()
    input_img = Image.fromarray(np.uint8(input_img * 255))
    return input_img


def load_image_w_bg(file, W, H):
    # load image, keeping its original background
    print(f'[INFO] load image from {file}...')
    img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
    img = cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA)
    img = img.astype(np.float32) / 255.0
    input_img = img[..., :3]
    # bgr to rgb
    input_img = input_img[..., ::-1].copy()
    input_img = Image.fromarray(np.uint8(input_img * 255))
    return input_img


def gen_vid(input_path, seed, bg, is_pad):
    name = input_path.split('/')[-1].split('.')[0]
    input_dir = os.path.dirname(input_path)

    pipe = StableVideoDiffusionPipeline.from_pretrained(
        "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
    )
    # pipe.enable_model_cpu_offload()
    pipe.to("cuda")

    if is_pad:
        height, width = 576, 1024
    else:
        height, width = 512, 512

    if seed is None:
        # no seed given: sweep backgrounds and 20 seeds, saving every video
        os.makedirs(f"{input_dir}/videos", exist_ok=True)
        for bg in ['white', 'black', 'orig']:
            if bg == 'orig':
                if 'rgba' in name:
                    continue
                image = load_image_w_bg(input_path, width, height)
            else:
                image = load_image(input_path, width, height, bg)
            if is_pad:
                image = resize_image(image, output_size=(width, height))
            for seed in range(20):
                generator = torch.manual_seed(seed)
                frames = pipe(image, height, width, generator=generator).frames[0]
                imageio.mimwrite(f"{input_dir}/videos/{name}_{bg}_{seed:03}.mp4", frames, fps=7)
    else:
        # fixed seed: generate one video and also dump its individual frames
        if bg == 'orig':
            if 'rgba' in name:
                raise ValueError('cannot use the original background of an rgba input')
            image = load_image_w_bg(input_path, width, height)
        else:
            image = load_image(input_path, width, height, bg)
        if is_pad:
            image = resize_image(image, output_size=(width, height))
        generator = torch.manual_seed(seed)
        frames = pipe(image, height, width, generator=generator).frames[0]
        imageio.mimwrite(f"{input_dir}/{name}_generated.mp4", frames, fps=7)

        os.makedirs(f"{input_dir}/{name}_frames", exist_ok=True)
        for idx, img in enumerate(frames):
            if is_pad:
                # crop the horizontal padding back off before saving
                img = img.crop(((width - height) // 2, 0, width - (width - height) // 2, height))
            img.save(f"{input_dir}/{name}_frames/{idx:03}.png")
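

# Example invocation (hypothetical paths; assumes a CUDA GPU and access to the
# fp16 weights of stabilityai/stable-video-diffusion-img2vid):
#   python gen_vid.py --path data/anya_rgba.png --seed 42 --bg white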

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--bg", type=str, default='white')
    # note: `type=bool` would treat any non-empty string (even "False") as True,
    # so parse the flag value explicitly instead
    parser.add_argument("--is_pad", type=lambda s: s.lower() in ('true', '1'), default=False)
    args, extras = parser.parse_known_args()
    gen_vid(args.path, args.seed, args.bg, args.is_pad)