# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F

from src.efficientvit.models.utils import torch_random_choices

__all__ = [
    "RRSController",
    "get_interpolate",
    "MyRandomResizedCrop",
]
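
# RRSController holds the global random-resolution-sampling (RRS) state:
# the list of candidate training resolutions and the resolution that is
# currently active for the batch being prepared.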
class RRSController:
    ACTIVE_SIZE = (224, 224)
    IMAGE_SIZE_LIST = [(224, 224)]

    CHOICE_LIST = None
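    # Return a copy of the candidate resolutions so callers cannot mutate
    # the shared class-level list.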
    @staticmethod
    def get_candidates() -> list[tuple[int, int]]:
        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
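    # Make the pre-sampled resolution for batch `batch_id` the active one;
    # the training loop is expected to call this before the batch's images
    # are transformed.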
    @staticmethod
    def sample_resolution(batch_id: int) -> None:
        RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
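    # Pre-sample one resolution per batch for the whole epoch. Seeding the
    # generator with the epoch index makes the schedule reproducible and
    # identical across processes that share the same epoch counter.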
    @staticmethod
    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
        g = torch.Generator()
        g.manual_seed(epoch)
        RRSController.CHOICE_LIST = torch_random_choices(
            RRSController.get_candidates(),
            g,
            batch_per_epoch,
        )
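
# Map an interpolation name to the corresponding torchvision InterpolationMode;
# "random" draws one of the supported modes each time it is called.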
def get_interpolate(name: str) -> F.InterpolationMode:
    mapping = {
        "nearest": F.InterpolationMode.NEAREST,
        "bilinear": F.InterpolationMode.BILINEAR,
        "bicubic": F.InterpolationMode.BICUBIC,
        "box": F.InterpolationMode.BOX,
        "hamming": F.InterpolationMode.HAMMING,
        "lanczos": F.InterpolationMode.LANCZOS,
    }
    if name in mapping:
        return mapping[name]
    elif name == "random":
        return torch_random_choices(
            [
                F.InterpolationMode.NEAREST,
                F.InterpolationMode.BILINEAR,
                F.InterpolationMode.BICUBIC,
                F.InterpolationMode.BOX,
                F.InterpolationMode.HAMMING,
                F.InterpolationMode.LANCZOS,
            ],
        )
    else:
        raise NotImplementedError
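
# RandomResizedCrop variant whose output size is not fixed at construction
# time but read from RRSController.ACTIVE_SIZE at call time, so the crop
# resolution can change from batch to batch.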
class MyRandomResizedCrop(transforms.RandomResizedCrop):
    def __init__(
        self,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "random",
    ):
        super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
        self.interpolation = interpolation
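    # Sample crop parameters exactly as RandomResizedCrop does, but resize
    # the crop to the currently active RRS resolution instead of the fixed
    # size passed to the parent constructor.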
    def forward(self, img: torch.Tensor) -> torch.Tensor:
        i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
        target_size = RRSController.ACTIVE_SIZE
        return F.resized_crop(
            img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
        )
    def __repr__(self) -> str:
        format_string = self.__class__.__name__
        format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
        format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
        format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
        format_string += f"\tinterpolation={self.interpolation})"
        return format_string
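

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): one way to wire
# RRSController and MyRandomResizedCrop into a training loop. It assumes a
# single-process DataLoader (num_workers=0) so the class-level state set here
# is visible to the transform; the candidate sizes, FakeData dataset, and
# loader settings below are placeholder assumptions.
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision.datasets import FakeData

    RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]

    transform = transforms.Compose(
        [
            MyRandomResizedCrop(interpolation="random"),
            transforms.ToTensor(),
        ]
    )
    dataset = FakeData(size=32, image_size=(3, 256, 256), transform=transform)
    loader = DataLoader(dataset, batch_size=8, num_workers=0)

    for epoch in range(2):
        # Pre-sample one resolution per batch for this epoch.
        RRSController.set_epoch(epoch, batch_per_epoch=len(loader))
        batches = iter(loader)
        for batch_id in range(len(loader)):
            # Activate this batch's resolution before its samples are loaded.
            RRSController.sample_resolution(batch_id)
            images, _ = next(batches)
            print(f"epoch {epoch} batch {batch_id}: {tuple(images.shape[-2:])}")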