File size: 536 Bytes
7ee3434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def check_nan(logger, loss, y_pred, y_gt):
    """Log diagnostics and terminate the process if ``loss`` contains NaN.

    Args:
        logger: Logger used for the diagnostics; messages use ``%``-style
            lazy formatting as expected by ``logging.Logger`` /
            ``logging.LoggerAdapter``.
        loss: Loss tensor checked element-wise for NaN.
        y_pred: Prediction tensor; its NaN status and values are logged.
        y_gt: Ground-truth tensor; its NaN status and values are logged.

    Returns:
        None when ``loss`` contains no NaN.

    Raises:
        SystemExit: when any element of ``loss`` is NaN.
    """
    if not torch.any(torch.isnan(loss)):
        return

    import sys

    # logging merges extra positional args into the message with ``%``;
    # print-style calls without placeholders fail to format and drop the data.
    logger.info("out has nan: %s", torch.any(torch.isnan(y_pred)).item())
    logger.info("y_gt has nan: %s", torch.any(torch.isnan(y_gt)).item())
    logger.info("out: %s", y_pred)
    logger.info("y_gt: %s", y_gt)
    # ``.item()`` only works for single-element tensors; log others as-is.
    if loss.numel() == 1:
        logger.info("loss = %.4f\n", loss.item())
    else:
        logger.info("loss = %s\n", loss)
    # ``sys.exit`` is always available, unlike the ``site``-injected ``exit``.
    sys.exit()