Fucius's picture
Upload 52 files
ad5354d verified
raw
history blame
3 kB
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
import copy
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from src.efficientvit.models.utils import torch_random_choices
# Names exported by a star-import of this module.
__all__ = [
"RRSController",
"get_interpolate",
"MyRandomResizedCrop",
]
class RRSController:
    """Class-level holder of the crop resolution used by ``MyRandomResizedCrop``.

    All state is stored on the class itself (no instances are needed):
    ``set_epoch`` builds a per-batch schedule of sizes, ``sample_resolution``
    selects one entry of that schedule as the active size, and
    ``MyRandomResizedCrop.forward`` reads ``ACTIVE_SIZE`` when resizing.
    """

    # Output size currently in effect; read by MyRandomResizedCrop.forward.
    ACTIVE_SIZE = (224, 224)
    # Pool of candidate sizes that set_epoch samples from.
    IMAGE_SIZE_LIST = [(224, 224)]
    # Per-batch schedule produced by set_epoch; None until set_epoch has run.
    CHOICE_LIST = None

    @staticmethod
    def get_candidates() -> list[tuple[int, int]]:
        """Return a deep copy of ``IMAGE_SIZE_LIST``.

        The copy keeps callers from mutating the shared class attribute.
        """
        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)

    @staticmethod
    def sample_resolution(batch_id: int) -> None:
        """Make the size scheduled for ``batch_id`` the active size.

        Args:
            batch_id: Index into the schedule built by ``set_epoch``.

        Raises:
            RuntimeError: If ``set_epoch`` has not been called yet, so there
                is no schedule to index.

        Indexing errors from the schedule itself (for example an
        out-of-range ``batch_id``) propagate unchanged.
        """
        schedule = RRSController.CHOICE_LIST
        if schedule is None:
            # Fail with an actionable message instead of the opaque
            # "'NoneType' object is not subscriptable".
            raise RuntimeError(
                "RRSController.set_epoch() must be called before "
                "RRSController.sample_resolution()"
            )
        RRSController.ACTIVE_SIZE = schedule[batch_id]

    @staticmethod
    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
        """Build the resolution schedule for one epoch.

        A fresh ``torch.Generator`` is seeded with ``epoch`` and handed to
        ``torch_random_choices`` together with the candidate sizes and
        ``batch_per_epoch``; the result is stored in ``CHOICE_LIST``.

        Args:
            epoch: Epoch number, used as the generator seed.
            batch_per_epoch: Number of draws requested from the helper.
        """
        g = torch.Generator()
        g.manual_seed(epoch)
        # NOTE(review): get_interpolate() in this file relies on
        # torch_random_choices returning a bare element when called with its
        # default count. If the helper also unwraps the result when the count
        # is 1, CHOICE_LIST would be a single size (not a list of sizes) for
        # batch_per_epoch == 1 -- confirm against the helper.
        RRSController.CHOICE_LIST = torch_random_choices(
            RRSController.get_candidates(),
            g,
            batch_per_epoch,
        )
def get_interpolate(name: str) -> F.InterpolationMode:
    """Translate an interpolation name into an ``InterpolationMode``.

    Args:
        name: One of ``"nearest"``, ``"bilinear"``, ``"bicubic"``, ``"box"``,
            ``"hamming"``, ``"lanczos"``, or ``"random"``. ``"random"`` asks
            ``torch_random_choices`` to pick among the six named modes
            (no generator is passed, so the helper's default source of
            randomness is used).

    Returns:
        The matching ``F.InterpolationMode`` member.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported names.
    """
    mapping = {
        "nearest": F.InterpolationMode.NEAREST,
        "bilinear": F.InterpolationMode.BILINEAR,
        "bicubic": F.InterpolationMode.BICUBIC,
        "box": F.InterpolationMode.BOX,
        "hamming": F.InterpolationMode.HAMMING,
        "lanczos": F.InterpolationMode.LANCZOS,
    }
    if name == "random":
        # dicts keep insertion order, so the candidate list is in the same
        # order as the table above and a given RNG state picks the same mode.
        return torch_random_choices(list(mapping.values()))
    try:
        return mapping[name]
    except KeyError:
        supported = ", ".join([*mapping, "random"])
        raise NotImplementedError(
            f"unsupported interpolation {name!r}; expected one of: {supported}"
        ) from None
class MyRandomResizedCrop(transforms.RandomResizedCrop):
    """Random resized crop whose output size follows ``RRSController``.

    The crop box is sampled with the inherited ``get_params``; the output
    size is whatever ``RRSController.ACTIVE_SIZE`` holds at call time, and
    the interpolation mode is resolved from its name on every call.
    """

    def __init__(
        self,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "random",
    ):
        # 224 is only passed to satisfy the parent constructor; forward()
        # takes the actual output size from RRSController.ACTIVE_SIZE.
        super().__init__(224, scale, ratio)
        # Assigned after the parent constructor so the name string is what
        # forward() and __repr__() see.
        self.interpolation = interpolation

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop a random region of ``img`` and resize it to the active size."""
        # Sample the crop box first, then resolve the interpolation mode, so
        # random draws happen in the same order on every call.
        top, left, height, width = self.get_params(
            img, list(self.scale), list(self.ratio)
        )
        out_size = list(RRSController.ACTIVE_SIZE)
        mode = get_interpolate(self.interpolation)
        return F.resized_crop(img, top, left, height, width, out_size, mode)

    def __repr__(self) -> str:
        """Describe the transform: candidate sizes, scale, ratio, interpolation."""
        pieces = [
            self.__class__.__name__,
            f"(\n\tsize={RRSController.get_candidates()},\n",
            f"\tscale={tuple(round(s, 4) for s in self.scale)},\n",
            f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n",
            f"\tinterpolation={self.interpolation})",
        ]
        return "".join(pieces)