glenn-jocher commited on
Commit
4de8b24
1 Parent(s): de9c25b

Suppress `torch` AMP-CPU warnings (#6706)

Browse files

This is a torch bug, but they seem unable or unwilling to fix it so I'm creating a suppression in YOLOv5.

Resolves https://github.com/ultralytics/yolov5/issues/6692

Files changed (1) hide show
  1. utils/torch_utils.py +7 -7
utils/torch_utils.py CHANGED
@@ -9,6 +9,7 @@ import os
9
  import platform
10
  import subprocess
11
  import time
 
12
  from contextlib import contextmanager
13
  from copy import deepcopy
14
  from pathlib import Path
@@ -25,6 +26,9 @@ try:
25
  except ImportError:
26
  thop = None
27
 
 
 
 
28
 
29
  @contextmanager
30
  def torch_distributed_zero_first(local_rank: int):
@@ -293,13 +297,9 @@ class EarlyStopping:
293
 
294
 
295
  class ModelEMA:
296
- """ Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
297
- Keep a moving average of everything in the model state_dict (parameters and buffers).
298
- This is intended to allow functionality like
299
- https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
300
- A smoothed version of the weights is necessary for some training schemes to perform well.
301
- This class is sensitive where it is initialized in the sequence of model init,
302
- GPU assignment and distributed training wrappers.
303
  """
304
 
305
  def __init__(self, model, decay=0.9999, updates=0):
 
9
  import platform
10
  import subprocess
11
  import time
12
+ import warnings
13
  from contextlib import contextmanager
14
  from copy import deepcopy
15
  from pathlib import Path
 
26
  except ImportError:
27
  thop = None
28
 
29
+ # Suppress PyTorch warnings
30
+ warnings.filterwarnings('ignore', message='User provided device_type of \'cuda\', but CUDA is not available. Disabling')
31
+
32
 
33
  @contextmanager
34
  def torch_distributed_zero_first(local_rank: int):
 
297
 
298
 
299
  class ModelEMA:
300
+ """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
301
+ Keeps a moving average of everything in the model state_dict (parameters and buffers)
302
+ For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
 
 
 
 
303
  """
304
 
305
  def __init__(self, model, decay=0.9999, updates=0):