Do0rMaMu's picture
Upload folder using huggingface_hub
e45d058 verified
import inspect

import torch
from timm.data import Mixup
from timm.data.mixup import mixup_target
# Some timm releases accept a ``device`` argument in ``mixup_target`` and
# others do not. Detect which variant is installed once, at import time, so
# ``TimmMixup`` works with either instead of failing with a TypeError.
_MIXUP_TARGET_ACCEPTS_DEVICE = 'device' in inspect.signature(mixup_target).parameters


def _mixup_target_on(target, num_classes, lam, smoothing, device):
    """Call timm's ``mixup_target`` so the mixed targets are built for ``device``.

    Args:
        target: Integer class labels for the batch.
        num_classes: Number of classes for the one-hot encoding.
        lam: Mixing coefficient(s) returned by one of ``Mixup``'s ``_mix_*`` helpers.
        smoothing: Label-smoothing amount forwarded to ``mixup_target``.
        device: Device the mixed targets should live on.

    Returns:
        Whatever the installed ``mixup_target`` returns for these inputs.
    """
    if _MIXUP_TARGET_ACCEPTS_DEVICE:
        return mixup_target(target, num_classes, lam, smoothing, device=device)
    # This variant takes no ``device``; move the labels first so the output is
    # built alongside them (assumed to follow ``target.device`` -- verify
    # against the installed timm version).
    return mixup_target(target.to(device), num_classes, lam, smoothing)


class TimmMixup(Mixup):
    """Variant of ``timm.data.Mixup`` that tolerates odd batch sizes.

    timm's ``Mixup.__call__`` asserts that the batch size is even before
    dispatching on ``mode``. This subclass performs that check only in
    ``'pair'`` mode; ``'elem'`` mode and the default batch-wise mode skip it.
    """

    def __call__(self, x: torch.Tensor, target: torch.Tensor):
        """Apply the configured mixing to ``x`` and build matching soft targets.

        Args:
            x: Input batch; its length is the batch size and its ``device`` is
                where the mixed targets are placed.
            target: Class labels for the batch, passed to ``mixup_target``.

        Returns:
            A ``(x, target)`` tuple: the batch returned after the ``_mix_*``
            helper ran on it, and the mixed targets from ``mixup_target``.

        Raises:
            AssertionError: In ``'pair'`` mode, when ``len(x)`` is odd.
        """
        if self.mode == 'elem':
            lam = self._mix_elem(x)
        elif self.mode == 'pair':
            # The even-batch check lives here rather than at the top of
            # ``__call__`` so the other modes accept odd batches. It is raised
            # explicitly (still an AssertionError, matching the original
            # ``assert``) so running under ``python -O`` cannot strip it.
            if len(x) % 2 != 0:
                raise AssertionError('Batch size should be even when using this')
            lam = self._mix_pair(x)
        else:
            lam = self._mix_batch(x)
        target = _mixup_target_on(target, self.num_classes, lam, self.label_smoothing, x.device)
        return x, target