|
|
|
import os |
|
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence, |
|
Sized, TypeVar, Union) |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import BatchSampler, Dataset, Sampler |
|
|
|
ASPECT_RATIO_512 = { |
|
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], |
|
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], |
|
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], |
|
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], |
|
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], |
|
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], |
|
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], |
|
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], |
|
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], |
|
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] |
|
} |
|
ASPECT_RATIO_RANDOM_CROP_512 = { |
|
'0.42': [320.0, 768.0], '0.5': [352.0, 704.0], |
|
'0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0], |
|
'0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], |
|
'1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0], |
|
'2.0': [704.0, 352.0], '2.4': [768.0, 320.0] |
|
} |
|
ASPECT_RATIO_RANDOM_CROP_PROB = [ |
|
1, 2, |
|
4, 4, 4, 4, |
|
8, 8, 8, |
|
4, 4, 4, 4, |
|
2, 1 |
|
] |
|
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB) |
|
|
|
def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512): |
|
aspect_ratio = height / width |
|
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) |
|
return ratios[closest_ratio], float(closest_ratio) |
|
|
|
def get_image_size_without_loading(path): |
|
with Image.open(path) as img: |
|
return img.size |
|
|
|
class RandomSampler(Sampler[int]): |
|
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. |
|
|
|
If with replacement, then user can specify :attr:`num_samples` to draw. |
|
|
|
Args: |
|
data_source (Dataset): dataset to sample from |
|
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` |
|
num_samples (int): number of samples to draw, default=`len(dataset)`. |
|
generator (Generator): Generator used in sampling. |
|
""" |
|
|
|
data_source: Sized |
|
replacement: bool |
|
|
|
def __init__(self, data_source: Sized, replacement: bool = False, |
|
num_samples: Optional[int] = None, generator=None) -> None: |
|
self.data_source = data_source |
|
self.replacement = replacement |
|
self._num_samples = num_samples |
|
self.generator = generator |
|
self._pos_start = 0 |
|
|
|
if not isinstance(self.replacement, bool): |
|
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") |
|
|
|
if not isinstance(self.num_samples, int) or self.num_samples <= 0: |
|
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") |
|
|
|
@property |
|
def num_samples(self) -> int: |
|
|
|
if self._num_samples is None: |
|
return len(self.data_source) |
|
return self._num_samples |
|
|
|
def __iter__(self) -> Iterator[int]: |
|
n = len(self.data_source) |
|
if self.generator is None: |
|
seed = int(torch.empty((), dtype=torch.int64).random_().item()) |
|
generator = torch.Generator() |
|
generator.manual_seed(seed) |
|
else: |
|
generator = self.generator |
|
|
|
if self.replacement: |
|
for _ in range(self.num_samples // 32): |
|
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() |
|
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() |
|
else: |
|
for _ in range(self.num_samples // n): |
|
xx = torch.randperm(n, generator=generator).tolist() |
|
if self._pos_start >= n: |
|
self._pos_start = 0 |
|
print("xx top 10", xx[:10], self._pos_start) |
|
for idx in range(self._pos_start, n): |
|
yield xx[idx] |
|
self._pos_start = (self._pos_start + 1) % n |
|
self._pos_start = 0 |
|
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] |
|
|
|
def __len__(self) -> int: |
|
return self.num_samples |
|
|
|
class AspectRatioBatchImageSampler(BatchSampler): |
|
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch. |
|
|
|
Args: |
|
sampler (Sampler): Base sampler. |
|
dataset (Dataset): Dataset providing data information. |
|
batch_size (int): Size of mini-batch. |
|
drop_last (bool): If ``True``, the sampler will drop the last batch if |
|
its size would be less than ``batch_size``. |
|
aspect_ratios (dict): The predefined aspect ratios. |
|
""" |
|
def __init__( |
|
self, |
|
sampler: Sampler, |
|
dataset: Dataset, |
|
batch_size: int, |
|
train_folder: str = None, |
|
aspect_ratios: dict = ASPECT_RATIO_512, |
|
drop_last: bool = False, |
|
config=None, |
|
**kwargs |
|
) -> None: |
|
if not isinstance(sampler, Sampler): |
|
raise TypeError('sampler should be an instance of ``Sampler``, ' |
|
f'but got {sampler}') |
|
if not isinstance(batch_size, int) or batch_size <= 0: |
|
raise ValueError('batch_size should be a positive integer value, ' |
|
f'but got batch_size={batch_size}') |
|
self.sampler = sampler |
|
self.dataset = dataset |
|
self.train_folder = train_folder |
|
self.batch_size = batch_size |
|
self.aspect_ratios = aspect_ratios |
|
self.drop_last = drop_last |
|
self.config = config |
|
|
|
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} |
|
|
|
self.current_available_bucket_keys = list(aspect_ratios.keys()) |
|
|
|
def __iter__(self): |
|
for idx in self.sampler: |
|
try: |
|
image_dict = self.dataset[idx] |
|
|
|
width, height = image_dict.get("width", None), image_dict.get("height", None) |
|
if width is None or height is None: |
|
image_id, name = image_dict['file_path'], image_dict['text'] |
|
if self.train_folder is None: |
|
image_dir = image_id |
|
else: |
|
image_dir = os.path.join(self.train_folder, image_id) |
|
|
|
width, height = get_image_size_without_loading(image_dir) |
|
|
|
ratio = height / width |
|
else: |
|
height = int(height) |
|
width = int(width) |
|
ratio = height / width |
|
except Exception as e: |
|
print(e) |
|
continue |
|
|
|
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) |
|
if closest_ratio not in self.current_available_bucket_keys: |
|
continue |
|
bucket = self._aspect_ratio_buckets[closest_ratio] |
|
bucket.append(idx) |
|
|
|
if len(bucket) == self.batch_size: |
|
yield bucket[:] |
|
del bucket[:] |
|
|
|
class AspectRatioBatchSampler(BatchSampler): |
|
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch. |
|
|
|
Args: |
|
sampler (Sampler): Base sampler. |
|
dataset (Dataset): Dataset providing data information. |
|
batch_size (int): Size of mini-batch. |
|
drop_last (bool): If ``True``, the sampler will drop the last batch if |
|
its size would be less than ``batch_size``. |
|
aspect_ratios (dict): The predefined aspect ratios. |
|
""" |
|
def __init__( |
|
self, |
|
sampler: Sampler, |
|
dataset: Dataset, |
|
batch_size: int, |
|
video_folder: str = None, |
|
train_data_format: str = "webvid", |
|
aspect_ratios: dict = ASPECT_RATIO_512, |
|
drop_last: bool = False, |
|
config=None, |
|
**kwargs |
|
) -> None: |
|
if not isinstance(sampler, Sampler): |
|
raise TypeError('sampler should be an instance of ``Sampler``, ' |
|
f'but got {sampler}') |
|
if not isinstance(batch_size, int) or batch_size <= 0: |
|
raise ValueError('batch_size should be a positive integer value, ' |
|
f'but got batch_size={batch_size}') |
|
self.sampler = sampler |
|
self.dataset = dataset |
|
self.video_folder = video_folder |
|
self.train_data_format = train_data_format |
|
self.batch_size = batch_size |
|
self.aspect_ratios = aspect_ratios |
|
self.drop_last = drop_last |
|
self.config = config |
|
|
|
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} |
|
|
|
self.current_available_bucket_keys = list(aspect_ratios.keys()) |
|
|
|
def __iter__(self): |
|
for idx in self.sampler: |
|
try: |
|
video_dict = self.dataset[idx] |
|
width, more = video_dict.get("width", None), video_dict.get("height", None) |
|
|
|
if width is None or height is None: |
|
if self.train_data_format == "normal": |
|
video_id, name = video_dict['file_path'], video_dict['text'] |
|
if self.video_folder is None: |
|
video_dir = video_id |
|
else: |
|
video_dir = os.path.join(self.video_folder, video_id) |
|
else: |
|
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] |
|
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") |
|
cap = cv2.VideoCapture(video_dir) |
|
|
|
|
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
|
ratio = height / width |
|
else: |
|
height = int(height) |
|
width = int(width) |
|
ratio = height / width |
|
except Exception as e: |
|
print(e) |
|
continue |
|
|
|
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) |
|
if closest_ratio not in self.current_available_bucket_keys: |
|
continue |
|
bucket = self._aspect_ratio_buckets[closest_ratio] |
|
bucket.append(idx) |
|
|
|
if len(bucket) == self.batch_size: |
|
yield bucket[:] |
|
del bucket[:] |
|
|
|
class AspectRatioBatchImageVideoSampler(BatchSampler): |
|
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch. |
|
|
|
Args: |
|
sampler (Sampler): Base sampler. |
|
dataset (Dataset): Dataset providing data information. |
|
batch_size (int): Size of mini-batch. |
|
drop_last (bool): If ``True``, the sampler will drop the last batch if |
|
its size would be less than ``batch_size``. |
|
aspect_ratios (dict): The predefined aspect ratios. |
|
""" |
|
|
|
def __init__(self, |
|
sampler: Sampler, |
|
dataset: Dataset, |
|
batch_size: int, |
|
train_folder: str = None, |
|
aspect_ratios: dict = ASPECT_RATIO_512, |
|
drop_last: bool = False |
|
) -> None: |
|
if not isinstance(sampler, Sampler): |
|
raise TypeError('sampler should be an instance of ``Sampler``, ' |
|
f'but got {sampler}') |
|
if not isinstance(batch_size, int) or batch_size <= 0: |
|
raise ValueError('batch_size should be a positive integer value, ' |
|
f'but got batch_size={batch_size}') |
|
self.sampler = sampler |
|
self.dataset = dataset |
|
self.train_folder = train_folder |
|
self.batch_size = batch_size |
|
self.aspect_ratios = aspect_ratios |
|
self.drop_last = drop_last |
|
|
|
|
|
self.current_available_bucket_keys = list(aspect_ratios.keys()) |
|
self.bucket = { |
|
'image':{ratio: [] for ratio in aspect_ratios}, |
|
'video':{ratio: [] for ratio in aspect_ratios} |
|
} |
|
|
|
def __iter__(self): |
|
for idx in self.sampler: |
|
content_type = self.dataset[idx].get('type', 'image') |
|
if content_type == 'image': |
|
try: |
|
image_dict = self.dataset[idx] |
|
|
|
width, height = image_dict.get("width", None), image_dict.get("height", None) |
|
if width is None or height is None: |
|
image_id, name = image_dict['file_path'], image_dict['text'] |
|
if self.train_folder is None: |
|
image_dir = image_id |
|
else: |
|
image_dir = os.path.join(self.train_folder, image_id) |
|
|
|
width, height = get_image_size_without_loading(image_dir) |
|
|
|
ratio = height / width |
|
else: |
|
height = int(height) |
|
width = int(width) |
|
ratio = height / width |
|
except Exception as e: |
|
print(e) |
|
continue |
|
|
|
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) |
|
if closest_ratio not in self.current_available_bucket_keys: |
|
continue |
|
bucket = self.bucket['image'][closest_ratio] |
|
bucket.append(idx) |
|
|
|
if len(bucket) == self.batch_size: |
|
yield bucket[:] |
|
del bucket[:] |
|
else: |
|
try: |
|
video_dict = self.dataset[idx] |
|
width, height = video_dict.get("width", None), video_dict.get("height", None) |
|
|
|
if width is None or height is None: |
|
video_id, name = video_dict['file_path'], video_dict['text'] |
|
if self.train_folder is None: |
|
video_dir = video_id |
|
else: |
|
video_dir = os.path.join(self.train_folder, video_id) |
|
cap = cv2.VideoCapture(video_dir) |
|
|
|
|
|
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) |
|
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) |
|
|
|
ratio = height / width |
|
else: |
|
height = int(height) |
|
width = int(width) |
|
ratio = height / width |
|
except Exception as e: |
|
print(e) |
|
continue |
|
|
|
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) |
|
if closest_ratio not in self.current_available_bucket_keys: |
|
continue |
|
bucket = self.bucket['video'][closest_ratio] |
|
bucket.append(idx) |
|
|
|
if len(bucket) == self.batch_size: |
|
yield bucket[:] |
|
del bucket[:] |