Do0rMaMu's picture
Upload folder using huggingface_hub
e45d058 verified
import torch
from torch import Tensor
from torchmetrics import Metric, Accuracy
class AccuracyMine(Accuracy):
    """Accuracy metric that also accepts floating-point targets (e.g. Mixup).

    When ``target`` is a floating-point tensor it is reduced with ``argmax``
    over its last dimension before being handed to ``torchmetrics.Accuracy``.
    Targets of any other dtype are forwarded unchanged.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate metric state from one batch.

        Args:
            preds: Predictions, forwarded untouched to the parent ``update``.
            target: Ground-truth tensor. If floating-point, it is replaced by
                the index of its largest value along the last dimension;
                otherwise it is used as given.
        """
        if target.is_floating_point():
            # Float targets (the Mixup case): collapse the last axis to the
            # position of its maximum before delegating.
            target = target.argmax(dim=-1)
        super().update(preds, target)