Spaces:
Sleeping
Sleeping
import torch | |
from timm.data import Mixup | |
from timm.data.mixup import mixup_target | |
class TimmMixup(Mixup):
    """Variant of ``timm.data.Mixup`` that relaxes the even-batch-size requirement.

    Only ``'pair'`` mode checks that the batch size is even. ``'elem'`` mode and
    every other mode (all of which go through ``_mix_batch``) accept any batch
    size.
    """

    def __call__(self, x, target):
        """Mix the inputs in ``x`` and build the matching mixed targets.

        Args:
            x: batch of inputs. ``len(x)`` is used as the batch size and
                ``x.device`` is forwarded to ``mixup_target``. ``x`` is returned
                unchanged by reference after the ``_mix_*`` call, so the mixing
                presumably happens in place (TODO confirm against the installed
                timm version).
            target: labels passed to ``mixup_target``.

        Returns:
            Tuple ``(x, target)`` where ``target`` is
            ``mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)``
            and ``lam`` is the value returned by the selected ``_mix_*`` helper.

        Raises:
            AssertionError: in ``'pair'`` mode when ``len(x)`` is odd.
        """
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            # The parent performs the even-batch check at the start of __call__
            # for every mode; here it is applied only to 'pair' mode. It is an
            # explicit raise rather than an `assert` so it still fires under
            # `python -O`. AssertionError is kept so existing handlers still match.
            if len(x) % 2 != 0:
                raise AssertionError('Batch size should be even when using this')
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        # NOTE(review): x.device is passed positionally as the fifth argument;
        # confirm the installed timm's mixup_target signature still accepts it.
        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
        return x, target