File size: 2,251 Bytes
c19ca42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from collections import defaultdict
import torch


class MemUsageMonitor():
    """Collects CUDA memory statistics for a single torch device.

    On construction the monitor probes the device; if CUDA is unavailable or
    the device rejects the required queries, the monitor marks itself
    ``disabled`` and every method degrades to a harmless no-op.
    """

    device = None     # torch.device being monitored
    disabled = False  # True when CUDA stats cannot be collected
    opts = None       # not used in this file — presumably set externally; TODO confirm
    data = None       # defaultdict(int) holding the latest readings

    def __init__(self, name, device):
        """Probe ``device`` for memory-stat support; disable self on failure.

        Args:
            name: human-readable label for this monitor.
            device: torch.device whose memory usage is tracked.
        """
        self.name = name
        self.device = device
        self.data = defaultdict(int)
        if not torch.cuda.is_available():
            self.disabled = True
        else:
            try:
                # Probe both APIs up front so later calls can assume they work.
                torch.cuda.mem_get_info(self._device_index())
                torch.cuda.memory_stats(self.device)
            except Exception:
                # Device is visible but does not answer these queries — run disabled.
                self.disabled = True

    def _device_index(self):
        """Return the device's explicit index, falling back to the current CUDA device."""
        return self.device.index if self.device.index is not None else torch.cuda.current_device()

    def cuda_mem_get_info(self): # legacy for extensions only
        """Return ``(free, total)`` bytes for the device, or ``(0, 0)`` when disabled."""
        if self.disabled:
            return 0, 0
        return torch.cuda.mem_get_info(self._device_index())

    def reset(self):
        """Reset peak-memory statistics plus the retry/OOM counters (best effort)."""
        if not self.disabled:
            try:
                torch.cuda.reset_peak_memory_stats(self.device)
                self.data['retries'] = 0
                self.data['oom'] = 0
                # torch.cuda.reset_accumulated_memory_stats(self.device)
                # torch.cuda.reset_max_memory_allocated(self.device)
                # torch.cuda.reset_max_memory_cached(self.device)
            except Exception:
                pass  # deliberate: leave the previous stats in place on failure

    def read(self):
        """Refresh and return the stats dict.

        Populates free/total/used bytes plus active, active_peak, reserved,
        reserved_peak, retries and oom counters from torch's allocator stats.
        Any query failure permanently disables the monitor; the last-known
        ``data`` is returned either way.
        """
        if not self.disabled:
            try:
                self.data["free"], self.data["total"] = torch.cuda.mem_get_info(self._device_index())
                torch_stats = torch.cuda.memory_stats(self.device)
                self.data["active"] = torch_stats["active.all.current"]
                self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
                self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
                self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
                self.data['retries'] = torch_stats["num_alloc_retries"]
                self.data['oom'] = torch_stats["num_ooms"]
                self.data["used"] = self.data["total"] - self.data["free"]
            except Exception:
                self.disabled = True
        return self.data