# Precompute per-object ID embeddings for each video: a DIFT (Stable
# Diffusion) feature map is extracted from the first annotated frame,
# average-pooled over every object's mask, and cached to disk.

import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms import PILToTensor
from tqdm import tqdm

from dift_util import SDFeaturizer

def load_dinov2():
    """Load the pretrained DINOv2 ViT-L/14 backbone on GPU in eval mode."""
    dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14').cuda()
    dinov2_vitl14.eval()
    return dinov2_vitl14

def infer_model(model, image):
    """Embed a PIL image with DINOv2 and return its CLS token."""
    transform = T.Compose([
        T.Resize((196, 196)),  # multiple of the ViT-L/14 patch size
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    image = transform(image).unsqueeze(0).cuda()

    cls_token = model(image, is_training=False)
    return cls_token
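# Example (hypothetical frame path): infer_model(dinov2, Image.open('frame.jpg'))
# returns a (1, 1024) CLS embedding for the ViT-L/14 backbone.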

def sort_frames(frame_name):
    """Sort key: the integer frame index encoded in the file name."""
    return int(frame_name.split('.')[0])
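# e.g. sort_frames("00012.jpg") -> 12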

def find_largest_inner_rectangle_coordinates(mask_gray):
    """Find the centre and radius of the largest circle inscribed in the
    non-zero region of a binary mask, via the distance transform."""
    # Distance of each foreground pixel to the nearest zero pixel.
    refine_dist = cv2.distanceTransform(mask_gray.astype(np.uint8), cv2.DIST_L2, 5)
    _, maxVal, _, maxLoc = cv2.minMaxLoc(refine_dist)
    radius = int(maxVal)

    return maxLoc, radius
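# For a solid blob away from the image border, maxLoc is the (x, y) centre of
# the blob's incircle and radius its incircle radius, e.g.:
#   centre, radius = find_largest_inner_rectangle_coordinates(blob_mask)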

def pil_image_to_numpy(image, is_mask=False, index=1):  # `index` is unused
    """Convert a PIL image to a 256x256 NumPy array."""
    if is_mask:
        # Nearest-neighbour resampling keeps integer object IDs intact.
        image = image.resize((256, 256), resample=Image.NEAREST)
        return np.array(image)
    else:
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize((256, 256))
        return np.array(image)

def get_ID(images_list, masks_list, dinov2):
    """Crop the first frame to the masked object and return its DINOv2 CLS
    embedding."""
    image = images_list[0]
    mask = masks_list

    try:
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        x, y, w, h = cv2.boundingRect(contours[0])

        # Zero out the background, then crop to the object's bounding box.
        mask = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2RGB)
        image = image * mask
        image = image[y:y + h, x:x + w]
    except Exception:
        print("cv2.findContours error")

    image = Image.fromarray(image).convert('RGB')
    img_embedding = infer_model(dinov2, image)
    return img_embedding
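# NOTE: get_ID (the DINOv2 path) is not called in the script body below, which
# saves DIFT embeddings computed by get_dift_ID instead.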

def get_dift_ID(feature_map, mask):
    """Average the (C, H, W) DIFT feature map over the mask's non-zero pixels,
    producing a single C-dimensional embedding for the object."""
    new_feature = []
    non_zero_coordinates = np.column_stack(np.where(mask != 0))
    for coord in non_zero_coordinates:
        new_feature.append(feature_map[:, coord[0], coord[1]])

    stacked_tensor = torch.stack(new_feature, dim=0)
    average_pooled_tensor = torch.mean(stacked_tensor, dim=0)

    return average_pooled_tensor
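# An equivalent vectorised sketch (assuming `mask` is a NumPy array):
#   keep = torch.from_numpy(mask != 0).to(feature_map.device)
#   id_feature = feature_map[:, keep].mean(dim=1)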

def extract_dift_feature(image, dift_model):
    """Extract a DIFT feature map from a PIL image (or a path to one)."""
    if not isinstance(image, Image.Image):
        image = Image.open(image).convert('RGB')

    prompt = ''
    img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2  # scale to [-1, 1]
    dift_feature = dift_model.forward(img_tensor, prompt=prompt, up_ft_index=3, ensemble_size=8)
    return dift_feature
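# With SD-v1.x weights, up_ft_index=3 selects the last U-Net up-block, whose
# output has 320 channels at 1/8 of the input resolution; this matches the
# 320-dim ids_embedding and the dataset_size/8 mask resizing used below.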

dinov2 = load_dinov2()
dinov2.requires_grad_(False)

model_id = 'pretrained_models/chilloutmix'
dift_model = SDFeaturizer(sd_id=model_id)


dataset_type = "ref-youtube-vos"

if dataset_type == "ref-youtube-vos":
    video_folder = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/JPEGImages"
    ann_folder = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/Annotations"
    save_p = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/ref-youtube-vos/train/embedding_SD_512_once"
else:
    video_folder = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/VIPSeg/imgs"
    ann_folder = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/VIPSeg/panomasks"
    save_p = "/mmu-ocr/weijiawu/MovieDiffusion/ShowAnything/data/VIPSeg/embedding_SD_512_once"

dataset_size = 512

dataset = os.listdir(ann_folder)
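os.makedirs(save_p, exist_ok=True)  # torch.save below expects this root to exist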

for videoid in dataset:
    video_dir_1 = os.path.join(video_folder, videoid)
    ann_folder_1 = os.path.join(ann_folder, videoid)
    save_embedding = os.path.join(save_p, videoid)
    save_embedding_once = os.path.join(save_p, videoid + ".pth")

    image_files = sorted(os.listdir(video_dir_1), key=sort_frames)
    depth_files = sorted(os.listdir(ann_folder_1), key=sort_frames)

    # DIFT features are extracted once, from the first (key) frame.
    keyframe_image = Image.open(os.path.join(video_dir_1, image_files[0])).convert('RGB')
    keyframe_image = keyframe_image.resize((dataset_size, dataset_size))
    keyframe_dift = extract_dift_feature(keyframe_image, dift_model=dift_model)

    # Object IDs are the unique label values in the first-frame annotation.
    mask = np.array(Image.open(os.path.join(ann_folder_1, depth_files[0])))
    ids = list(np.unique(mask))

    numpy_depth_images = np.array([pil_image_to_numpy(Image.open(os.path.join(ann_folder_1, df)), True, ids) for df in depth_files])

    ids_list = {}
    for index_mask, mask in tqdm(enumerate(numpy_depth_images)):
        ids_embedding = torch.ones((dataset_size, dataset_size, 320))

        for index in ids:
            mask_array = (np.array(mask) == index) * 1

            try:
                center_coordinate, _ = find_largest_inner_rectangle_coordinates(mask_array)
            except Exception:
                print("find_largest_inner_rectangle_coordinates error")
                continue

            # Filled disc of radius 20 at the incircle centre (not used further here).
            circle_img = np.zeros((dataset_size, dataset_size), np.float32)
            circle_mask = cv2.circle(circle_img, (center_coordinate[0], center_coordinate[1]), 20, 1, -1)

            if index_mask == 0:
                # Pool the keyframe DIFT feature over the mask, downsampled to
                # the feature map's 1/8 resolution.
                mask_32 = cv2.resize(mask_array.astype(np.uint8), (int(dataset_size / 8), int(dataset_size / 8)))
                if len(np.column_stack(np.where(mask_32 != 0))) == 0:
                    continue

                id_feature = get_dift_ID(keyframe_dift[0], mask_32)
                ids_list[index] = id_feature
            else:
                try:
                    id_feature = ids_list[index]
                except KeyError:
                    print("index error")
                    continue

        # Save one dict of per-object embeddings per video, then stop: only
        # the keyframe is needed.
        torch.save(ids_list, save_embedding_once)

        break

        # Unreachable after the break above: per-frame dense embeddings.
        ids_embedding = F.avg_pool1d(ids_embedding, kernel_size=2, stride=2)
        torch.save(ids_embedding, os.path.join(save_embedding, '{}.pth'.format(index_mask)))