Spaces:
Sleeping
Sleeping
import torch | |
from torch import Tensor | |
from torchmetrics import Metric, Accuracy | |
class AccuracyMine(Accuracy):
    """Accuracy metric that also accepts soft (Mixup) targets.

    When ``target`` is a floating-point tensor with a trailing class
    dimension (e.g. mixed one-hot labels from Mixup), it is reduced to hard
    class indices with ``argmax`` over the last dimension before being
    handed to the parent ``Accuracy.update``. Integer targets, and
    floating-point targets without a class dimension, are forwarded to the
    parent unchanged.

    NOTE(review): in torchmetrics >= 0.11 ``Accuracy`` is a task wrapper
    whose ``__new__`` returns a task-specific metric (e.g.
    ``MulticlassAccuracy``); in that case this override is never invoked.
    Confirm the torchmetrics version this project pins.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update metric state, converting soft targets to class indices.

        Args:
            preds: Model predictions, passed through unchanged.
            target: Either integer class labels, or floating-point soft
                labels whose last dimension indexes classes.
        """
        # A 1-D float tensor has no class axis: argmax over dim=-1 would
        # collapse the entire batch into one scalar index, so only reduce
        # when a class dimension is present.
        if target.is_floating_point() and target.ndim > 1:
            target = target.argmax(dim=-1)
        super().update(preds, target)