File size: 4,103 Bytes
7c3ff16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import json
import random
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
class HumanDanceDataset(Dataset):
    """Dataset of (reference frame, target frame, target pose) triples.

    Each item comes from one entry of the meta JSON files:
    - a random reference frame from the video at ``"video_path"``;
    - a target frame from the same video, preferably at least
      ``sample_margin`` frames away;
    - the frame at the target's index from the pose video at ``"kps_path"``.

    Args:
        img_size: Size forwarded to ``transforms.Resize`` for both the RGB and
            the pose pipelines (whatever ``Resize`` accepts).
        img_scale: Stored but currently unused. It is the ``scale`` argument
            of a ``RandomResizedCrop`` step that is disabled.
        img_ratio: Stored but currently unused. It is the ``ratio`` argument
            of the same disabled ``RandomResizedCrop`` step.
        drop_ratio: Stored on the instance; not read anywhere in this class.
        data_meta_paths: Paths of JSON files. Each file holds a list of dicts
            with at least ``"video_path"`` and ``"kps_path"`` keys.
        sample_margin: Desired minimum frame distance between the reference
            and target frames.
    """

    def __init__(
        self,
        img_size,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        # Tuple rather than list: a list default is one object shared by every call.
        data_meta_paths=("./data/fahsion_meta.json",),
        sample_margin=30,
    ):
        super().__init__()
        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin

        # -----
        # vid_meta format:
        # [{'video_path': , 'kps_path': , 'other':},
        #  {'video_path': , 'kps_path': , 'other':}]
        # -----
        vid_meta = []
        for data_meta_path in data_meta_paths:
            # The context manager closes the handle; a bare open() would leak it.
            with open(data_meta_path, "r", encoding="utf-8") as meta_file:
                vid_meta.extend(json.load(meta_file))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()
        # NOTE: a RandomResizedCrop(img_size, scale=img_scale, ratio=img_ratio,
        # BILINEAR) step before Resize is disabled in both pipelines. That is
        # why img_scale / img_ratio are unused and both pipelines are
        # currently deterministic.
        self.transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                # Shifts ToTensor's [0, 1] output to [-1, 1].
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # The pose pipeline skips Normalize, so its output stays in ToTensor's range.
        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )
        self.drop_ratio = drop_ratio

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``, first restoring ``state`` if given.

        Restoring the same torch RNG state before each call makes random
        torch-driven transforms draw identical parameters for every image.
        This keeps paired images (frame and pose) spatially aligned.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def _sample_target_index(self, ref_img_idx, video_length):
        """Pick a target frame index at least ``margin`` frames from the reference.

        ``margin`` is ``min(self.sample_margin, video_length)``. A frame after
        the reference is preferred, then one before it. If neither side has
        room, any frame is returned, possibly the reference itself.
        """
        margin = min(self.sample_margin, video_length)
        if ref_img_idx + margin < video_length:
            return random.randint(ref_img_idx + margin, video_length - 1)
        # Use ">= 0", not "> 0". Frame 0 is a valid target when it lies exactly
        # `margin` frames before the reference. This mirrors the branch above,
        # which allows an exact-margin distance.
        if ref_img_idx - margin >= 0:
            return random.randint(0, ref_img_idx - margin)
        return random.randint(0, video_length - 1)

    def __getitem__(self, index):
        """Build one training sample from the ``index``-th meta entry.

        Returns:
            dict with keys:
              ``video_dir``   -- the entry's ``video_path`` value;
              ``img``         -- target frame passed through ``self.transform``;
              ``tgt_pose``    -- pose frame at the target index, passed through
                                 ``self.cond_transform``;
              ``ref_img``     -- reference frame passed through ``self.transform``;
              ``clip_images`` -- reference frame preprocessed by the CLIP image
                                 processor (first element of ``pixel_values``).

        Raises:
            AssertionError: if the RGB and pose videos differ in frame count.
                This check is skipped when Python runs with ``-O``.
            ValueError: if the video has no frames.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)
        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)
        if video_length == 0:
            # Without this check, random.randint(0, -1) would raise an opaque ValueError.
            raise ValueError(f"no frames to sample in {video_path}")
        ref_img_idx = random.randint(0, video_length - 1)
        tgt_img_idx = self._sample_target_index(ref_img_idx, video_length)

        ref_img_pil = Image.fromarray(video_reader[ref_img_idx].asnumpy())
        tgt_img_pil = Image.fromarray(video_reader[tgt_img_idx].asnumpy())
        tgt_pose_pil = Image.fromarray(kps_reader[tgt_img_idx].asnumpy())

        # Snapshot the torch RNG once and replay it for every transform, so any
        # random transform treats the three images identically.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
        clip_image = self.clip_image_processor(
            images=ref_img_pil, return_tensors="pt"
        ).pixel_values[0]

        return dict(
            video_dir=video_path,
            img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img_vae,
            clip_images=clip_image,
        )

    def __len__(self):
        """Return the number of meta entries loaded across all meta files."""
        return len(self.vid_meta)
|