Dragreal / demo.py
BasicNp's picture
Upload 1672 files
e8aa256 verified
import os
import torch
import datetime
import numpy as np
from PIL import Image
from pipeline.pipeline_svd_DragAnything import StableVideoDiffusionPipeline
from models.DragAnything import DragAnythingSDVModel
from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
import cv2
import re
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as T
import torch.nn.functional as F
from utils.dift_util import DIFT_Demo, SDFeaturizer
from torchvision.transforms import PILToTensor
import json
def save_gifs_side_by_side(batch_output, validation_control_images,output_folder,name = 'none', target_size=(512 , 512),duration=200):
flattened_batch_output = batch_output
def create_gif(image_list, gif_path, duration=100):
pil_images = [validate_and_convert_image(img,target_size=target_size) for img in image_list]
pil_images = [img for img in pil_images if img is not None]
if pil_images:
pil_images[0].save(gif_path, save_all=True, append_images=pil_images[1:], loop=0, duration=duration)
# Creating GIFs for each image list
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
gif_paths = []
# validation_control_images = validation_control_images*255 validation_images,
for idx, image_list in enumerate([validation_control_images, flattened_batch_output]):
# if idx==0:
# continue
gif_path = os.path.join(output_folder, f"temp_{idx}_{timestamp}.gif")
create_gif(image_list, gif_path)
gif_paths.append(gif_path)
# Function to combine GIFs side by side
def combine_gifs_side_by_side(gif_paths, output_path):
print(gif_paths)
gifs = [Image.open(gif) for gif in gif_paths]
# Assuming all gifs have the same frame count and duration
frames = []
for frame_idx in range(gifs[0].n_frames):
combined_frame = None
for gif in gifs:
gif.seek(frame_idx)
if combined_frame is None:
combined_frame = gif.copy()
else:
combined_frame = get_concat_h(combined_frame, gif.copy())
frames.append(combined_frame)
print(gifs[0].info['duration'])
frames[0].save(output_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)
# Helper function to concatenate images horizontally
def get_concat_h(im1, im2):
dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
dst.paste(im1, (0, 0))
dst.paste(im2, (im1.width, 0))
return dst
# Combine the GIFs into a single file
combined_gif_path = os.path.join(output_folder, f"combined_frames_{name}_{timestamp}.gif")
combine_gifs_side_by_side(gif_paths, combined_gif_path)
# Clean up temporary GIFs
for gif_path in gif_paths:
os.remove(gif_path)
return combined_gif_path
# Define functions
def validate_and_convert_image(image, target_size=(512 , 512)):
if image is None:
print("Encountered a None image")
return None
if isinstance(image, torch.Tensor):
# Convert PyTorch tensor to PIL Image
if image.ndim == 3 and image.shape[0] in [1, 3]: # Check for CxHxW format
if image.shape[0] == 1: # Convert single-channel grayscale to RGB
image = image.repeat(3, 1, 1)
image = image.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
image = Image.fromarray(image)
else:
print(f"Invalid image tensor shape: {image.shape}")
return None
elif isinstance(image, Image.Image):
# Resize PIL Image
image = image.resize(target_size)
else:
print("Image is not a PIL Image or a PyTorch tensor")
return None
return image
def create_image_grid(images, rows, cols, target_size=(512 , 512)):
valid_images = [validate_and_convert_image(img, target_size) for img in images]
valid_images = [img for img in valid_images if img is not None]
if not valid_images:
print("No valid images to create a grid")
return None
w, h = target_size
grid = Image.new('RGB', size=(cols * w, rows * h))
for i, image in enumerate(valid_images):
grid.paste(image, box=((i % cols) * w, (i // cols) * h))
return grid
def tensor_to_pil(tensor):
""" Convert a PyTorch tensor to a PIL Image. """
# Convert tensor to numpy array
if len(tensor.shape) == 4: # batch of images
images = [Image.fromarray(img.numpy().transpose(1, 2, 0)) for img in tensor]
else: # single image
images = Image.fromarray(tensor.numpy().transpose(1, 2, 0))
return images
def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder):
# Flatten batch_output to a list of PIL Images
flattened_batch_output = [img for sublist in batch_output for img in sublist]
# Convert tensors in lists to PIL Images
validation_images = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in validation_images]
validation_control_images = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in validation_control_images]
flattened_batch_output = [tensor_to_pil(img) if torch.is_tensor(img) else img for img in batch_output]
# Flatten lists if they contain sublists (for tensors converted to multiple images)
validation_images = [img for sublist in validation_images for img in (sublist if isinstance(sublist, list) else [sublist])]
validation_control_images = [img for sublist in validation_control_images for img in (sublist if isinstance(sublist, list) else [sublist])]
flattened_batch_output = [img for sublist in flattened_batch_output for img in (sublist if isinstance(sublist, list) else [sublist])]
# Combine frames into a list
combined_frames = validation_images + validation_control_images + flattened_batch_output
# Calculate rows and columns for the grid
num_images = len(combined_frames)
cols = 3
rows = (num_images + cols - 1) // cols
# Create and save the grid image
grid = create_image_grid(combined_frames, rows, cols, target_size=(512, 512))
if grid is not None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"combined_frames_{timestamp}.png"
output_path = os.path.join(output_folder, filename)
grid.save(output_path)
else:
print("Failed to create image grid")
def load_images_from_folder(folder):
images = []
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"} # Add or remove extensions as needed
# Function to extract frame number from the filename
def frame_number(filename):
matches = re.findall(r'\d+', filename) # Find all sequences of digits in the filename
if matches:
if matches[-1] == '0000' and len(matches) > 1:
return int(matches[-2]) # Return the second-to-last sequence if the last is '0000'
return int(matches[-1]) # Otherwise, return the last sequence
return float('inf') # Return 'inf'
# Sorting files based on frame number
sorted_files = sorted(os.listdir(folder), key=frame_number)
# Load images in sorted order
for filename in sorted_files:
ext = os.path.splitext(filename)[1].lower()
if ext in valid_extensions:
img = Image.open(os.path.join(folder, filename)).convert('RGB')
images.append(img)
return images
def gen_gaussian_heatmap(imgSize=200):
circle_img = np.zeros((imgSize, imgSize), np.float32)
circle_mask = cv2.circle(circle_img, (imgSize//2, imgSize//2), imgSize//2, 1, -1)
# print(circle_mask)
isotropicGrayscaleImage = np.zeros((imgSize, imgSize), np.float32)
# ็”Ÿๆˆ้ซ˜ๆ–ฏๅ›พ
for i in range(imgSize):
for j in range(imgSize):
isotropicGrayscaleImage[i, j] = 1 / 2 / np.pi / (40 ** 2) * np.exp(
-1 / 2 * ((i - imgSize / 2) ** 2 / (40 ** 2) + (j - imgSize / 2) ** 2 / (40 ** 2)))
# ๅฆ‚ๆžœ่ฆๅฏ่ง†ๅŒ–ๅฏนๆฏ”ๆญฃๆ–นๅฝขๅ’Œๆœ€ๅคงๅ†…ๅˆ‡ๅœ†้ซ˜ๆ–ฏๅ›พ็š„ๅŒบๅˆซ๏ผŒๆณจ้‡Šไธ‹้ข่ฟ™่กŒๅณๅฏ
isotropicGrayscaleImage = isotropicGrayscaleImage * circle_mask
isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)).astype(np.float32)
isotropicGrayscaleImage = (isotropicGrayscaleImage / np.max(isotropicGrayscaleImage)*255).astype(np.uint8)
# ๅฐ†ๅ›พๅƒ่ฐƒๆ•ดๅคงๅฐไธบ 50x50
# isotropicGrayscaleImage = cv2.resize(isotropicGrayscaleImage, (40, 40))
return isotropicGrayscaleImage
def infer_model(model, image):
transform = T.Compose([
T.Resize((196, 196)),
T.ToTensor(),
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
image = transform(image).unsqueeze(0).cuda()
# cls_token = model.forward_features(image)
cls_token = model(image, is_training=False)
return cls_token
def find_largest_inner_rectangle_coordinates(mask_gray):
refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5, cv2.DIST_LABEL_PIXEL)
_, maxVal, _, maxLoc = cv2.minMaxLoc(refine_dist)
radius = int(maxVal)
return maxLoc, radius
def get_ID(images_list,masks_list,dinov2):
ID_images = []
image = images_list
mask = masks_list
# try:
contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# ๆ‰พๅˆฐๆœ€ๅคง็š„่ฝฎๅป“
max_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(max_contour)
mask = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2RGB)
image = image * mask
image = image[y:y+h,x:x+w]
# import random
# cv2.imwrite("./{}.jpg".format(random.randint(1, 100)),image)
# except:
# pass
# print("cv2.findContours error")
image = Image.fromarray(image).convert('RGB')
img_embedding = infer_model(dinov2, image)
return img_embedding
def get_dift_ID(feature_map,mask):
# feature_map = feature_map * 0
new_feature = []
non_zero_coordinates = np.column_stack(np.where(mask != 0))
for coord in non_zero_coordinates:
# feature_map[:, coord[0], coord[1]] = 1
new_feature.append(feature_map[:, coord[0], coord[1]])
stacked_tensor = torch.stack(new_feature, dim=0)
# ๅœจ็ปดๅบฆ0ไธŠ่ฟ›่กŒๅนณๅ‡ๆฑ ๅŒ–
average_pooled_tensor = torch.mean(stacked_tensor, dim=0)
return average_pooled_tensor
def extract_dift_feature(image, dift_model):
if isinstance(image, Image.Image):
image = image
else:
image = Image.open(image).convert('RGB')
prompt = ''
img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2
dift_feature = dift_model.forward(img_tensor, prompt=prompt, up_ft_index=3,ensemble_size=8)
return dift_feature
# cloud
def get_condition(target_size=(512 , 512), original_size=(512 , 512), args="", first_frame=None, is_mask = False, side=20,model_id=None):
images = []
vis_images = []
heatmap = gen_gaussian_heatmap()
original_size = (original_size[1],original_size[0])
size = (target_size[1],target_size[0])
latent_size = (int(target_size[1]/8), int(target_size[0]/8))
dift_model = SDFeaturizer(sd_id=model_id)
keyframe_dift = extract_dift_feature(first_frame, dift_model=dift_model)
ID_images=[]
ids_list={}
with open(os.path.join(args["validation_image"],"demo.json"), 'r') as json_file:
trajectory_json = json.load(json_file)
mask_list = []
trajectory_list = []
radius_list = []
for index in trajectory_json:
ann = trajectory_json[index]
mask_name = ann["mask_name"]
trajectories = ann["trajectory"]
trajectories = [[int(i[0]/original_size[0]*size[0]),int(i[1]/original_size[1]*size[1])] for i in trajectories]
trajectory_list.append(trajectories)
#mask
first_mask = (cv2.imread(os.path.join(args["validation_image"],mask_name))/255).astype(np.uint8)
first_mask = cv2.cvtColor(first_mask.astype(np.uint8), cv2.COLOR_RGB2GRAY)
mask_list.append(first_mask)
mask_322 = cv2.resize(first_mask.astype(np.uint8),(int(target_size[1]), int(target_size[0])))
_, radius = find_largest_inner_rectangle_coordinates(mask_322)
radius_list.append(radius)
viss = 0
if viss:
mask_list_vis = [cv2.resize(i,(int(target_size[1]), int(target_size[0]))) for i in mask_list]
vis_first_mask = show_mask(cv2.resize(np.array(first_frame).astype(np.uint8),(int(target_size[1]), int(target_size[0]))), mask_list_vis)
vis_first_mask = cv2.cvtColor(vis_first_mask, cv2.COLOR_BGR2RGB)
cv2.imwrite("test.jpg",vis_first_mask)
assert False
for idxx,point in enumerate(trajectory_list[0]):
new_img = np.zeros(target_size, np.uint8)
vis_img = new_img.copy()
ids_embedding = torch.zeros((target_size[0], target_size[1], 320))
if idxx>= args["frame_number"]:
break
for cc,(mask,trajectory,radius) in enumerate(zip(mask_list,trajectory_list,radius_list)):
center_coordinate = trajectory[idxx]
trajectory_ = trajectory[:idxx]
side = min(radius,50)
# side = radius
# if cc>=1:
# continue
# ID embedding
if idxx == 0:
# diffusion feature
mask_32 = cv2.resize(mask.astype(np.uint8),latent_size)
if len(np.column_stack(np.where(mask_32 != 0)))==0:
continue
ids_list[cc] = get_dift_ID(keyframe_dift[0],mask_32)
id_feature = ids_list[cc]
else:
id_feature = ids_list[cc]
circle_img = np.zeros((target_size[0], target_size[1]), np.float32)
circle_mask = cv2.circle(circle_img, (center_coordinate[0],center_coordinate[1]), side, 1, -1)
y1 = max(center_coordinate[1]-side,0)
y2 = min(center_coordinate[1]+side,target_size[0]-1)
x1 = max(center_coordinate[0]-side,0)
x2 = min(center_coordinate[0]+side,target_size[1]-1)
if x2-x1>3 and y2-y1>3:
need_map = cv2.resize(heatmap, (x2-x1, y2-y1))
new_img[y1:y2,x1:x2] = need_map.copy()
if cc>=0:
vis_img[y1:y2,x1:x2] = need_map.copy()
if len(trajectory_) == 1:
vis_img[trajectory_[0][1],trajectory_[0][0]] = 255
else:
for itt in range(len(trajectory_)-1):
cv2.line(vis_img,(trajectory_[itt][0],trajectory_[itt][1]),(trajectory_[itt+1][0],trajectory_[itt+1][1]),(255,255,255),3)
# ่Žทๅ–้ž้›ถๅƒ็ด ็š„ๅๆ ‡
non_zero_coordinates = np.column_stack(np.where(circle_mask != 0))
for coord in non_zero_coordinates:
ids_embedding[coord[0], coord[1]] = id_feature[0]
ids_embedding = F.avg_pool1d(ids_embedding, kernel_size=2, stride=2)
img = new_img
# Ensure all images are in RGB format
if len(img.shape) == 2: # Grayscale image
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_GRAY2RGB)
elif len(img.shape) == 3 and img.shape[2] == 3: # Color image in BGR format
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
vis_img = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
# Convert the numpy array to a PIL image
pil_img = Image.fromarray(img)
images.append(pil_img)
vis_images.append(Image.fromarray(vis_img))
ID_images.append(ids_embedding)
return images,ID_images,vis_images
# Usage example
def convert_list_bgra_to_rgba(image_list):
"""
Convert a list of PIL Image objects from BGRA to RGBA format.
Parameters:
image_list (list of PIL.Image.Image): A list of images in BGRA format.
Returns:
list of PIL.Image.Image: The list of images converted to RGBA format.
"""
rgba_images = []
for image in image_list:
if image.mode == 'RGBA' or image.mode == 'BGRA':
# Split the image into its components
b, g, r, a = image.split()
# Re-merge in RGBA order
converted_image = Image.merge("RGBA", (r, g, b, a))
else:
# For non-alpha images, assume they are BGR and convert to RGB
b, g, r = image.split()
converted_image = Image.merge("RGB", (r, g, b))
rgba_images.append(converted_image)
return rgba_images
def show_mask(image, masks, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3)], axis=0)
h, w = mask.shape[:2]
color_a = np.concatenate([np.random.random(3)*255], axis=0)
mask_image = mask.reshape(h, w, 1) * color_a.reshape(1, 1, -1)
else:
h, w = masks[0].shape[:2]
# mask_image = mask1.reshape(h, w, 1) * np.array([30, 144, 255])
mask_image = 0
for idx,mask in enumerate(masks):
if idx!=1 and idx!=0:
continue
color = np.concatenate([np.random.random(3)*255], axis=0)
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + mask_image
return (np.array(image).copy()*0.4+mask_image*0.6).astype(np.uint8)
# Main script
if __name__ == "__main__":
args = {
"pretrained_model_name_or_path": "stabilityai/stable-video-diffusion-img2vid",
"DragAnything":"./model_out/DragAnything",
"model_DIFT":"./utils/pretrained_models/chilloutmix",
"validation_image": "./validation_demo/Demo/ship_@",
"output_dir": "./validation_demo",
"height": 320,
"width": 576,
"frame_number": 20
# cant be bothered to add the args in myself, just use notepad
}
# Load and set up the pipeline
controlnet = controlnet = DragAnythingSDVModel.from_pretrained(args["DragAnything"])
unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(args["pretrained_model_name_or_path"],subfolder="unet")
pipeline = StableVideoDiffusionPipeline.from_pretrained(args["pretrained_model_name_or_path"],controlnet=controlnet,unet=unet)
pipeline.enable_model_cpu_offload()
validation_image = Image.open(os.path.join(args["validation_image"],"demo.jpg")).convert('RGB')
width, height = validation_image.size
validation_image = validation_image.resize((args["width"], args["height"]))
validation_control_images,ids_embedding,vis_images = get_condition(target_size=(args["height"] , args["width"]),
original_size=(height , width),
args = args,first_frame = validation_image,
side=100,model_id=args["model_DIFT"])
ids_embedding = torch.stack(ids_embedding, dim=0).permute(0, 3, 1, 2)
# Additional pipeline configurations can be added here
#pipeline.enable_xformers_memory_efficient_attention()
# Create output directory if it doesn't exist
val_save_dir = os.path.join(args["output_dir"], "saved_video")
os.makedirs(val_save_dir, exist_ok=True)
# Inference and saving loop
video_frames = pipeline(validation_image, validation_control_images[:args["frame_number"]], decode_chunk_size=8,num_frames=args["frame_number"],motion_bucket_id=180,controlnet_cond_scale=1.0,height=args["height"],width=args["width"],ids_embedding=ids_embedding[:args["frame_number"]]).frames
vis_images = [cv2.applyColorMap(np.array(img).astype(np.uint8), cv2.COLORMAP_JET) for img in vis_images]
vis_images = [cv2.cvtColor(np.array(img).astype(np.uint8), cv2.COLOR_BGR2RGB) for img in vis_images]
vis_images = [Image.fromarray(img) for img in vis_images]
video_frames = [img for sublist in video_frames for img in sublist]
save_gifs_side_by_side(video_frames, vis_images[:args["frame_number"]],val_save_dir,target_size=(width,height),duration=110)