File size: 5,004 Bytes
d2b7e94
02e90e4
d2b7e94
 
02e90e4
 
d2b7e94
02e90e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5458aa
02e90e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bed01bd
 
 
 
02e90e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21f455c
02e90e4
 
 
 
 
 
 
 
374f426
 
 
 
 
 
 
bed01bd
 
 
d5d0921
 
 
 
 
 
1df74c6
02e90e4
 
 
 
 
 
 
374f426
 
 
 
 
 
 
 
02e90e4
d5d0921
374f426
 
02e90e4
 
374f426
02e90e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import logging
import sys
from functools import lru_cache

import torch

from modules import config

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# mac_devices is only imported on macOS, so any code that touches it must
# check the platform first (has_mps() does this).
if sys.platform == "darwin":
    from modules.devices import mac_devices


def has_mps() -> bool:
    """
    Tell whether the MPS backend can be used.

    :return: False on every platform other than macOS; on macOS, the value of
        ``mac_devices.has_mps``.
    """
    # Short-circuit keeps mac_devices untouched off macOS, where it is never imported.
    return sys.platform == "darwin" and mac_devices.has_mps


def get_cuda_device_id():
    """
    Resolve the CUDA device index to use.

    The configured ``config.runtime_env_vars.device_id`` wins whenever it is a
    plain decimal number; otherwise the device currently selected in torch is
    used.

    :return: The CUDA device index as an int.
    """
    configured = config.runtime_env_vars.device_id
    if configured is not None:
        # str() tolerates an id that was stored as an int instead of a string.
        text = str(configured)
        # isdecimal() only accepts characters that int() can parse.
        if text.isdecimal():
            # An explicit "0" is a valid choice and must not fall through to
            # the current device the way a falsy `or` chain would.
            return int(text)
    return torch.cuda.current_device()


def get_cuda_device_string():
    """
    Build the torch device string for the configured CUDA device.

    :return: "cuda:<device_id>" when ``config.runtime_env_vars.device_id`` is
        set, plain "cuda" otherwise.
    """
    configured = config.runtime_env_vars.device_id
    return "cuda" if configured is None else f"cuda:{configured}"


def get_available_gpus() -> list[tuple[int, int]]:
    """
    List every CUDA device together with an estimate of its free memory.

    The estimate is the device's total memory minus the amount reported by
    ``torch.cuda.memory_reserved`` for that device.

    :return: A list of (GPU index, estimated free memory in bytes) tuples, one
        per device reported by ``torch.cuda.device_count()``.
    """

    def _unreserved_bytes(index: int) -> int:
        # Total capacity minus what torch reports as reserved on this device.
        total = torch.cuda.get_device_properties(index).total_memory
        return total - torch.cuda.memory_reserved(index)

    return [
        (index, _unreserved_bytes(index))
        for index in range(torch.cuda.device_count())
    ]


def get_memory_available_gpus(min_memory=2048):
    """
    Select the GPUs whose estimated free memory exceeds a threshold.

    :param min_memory: Threshold compared directly against the free-memory
        values from get_available_gpus(), which are in bytes.
        NOTE(review): the default of 2048 is therefore 2 KiB; confirm whether
        a larger unit was intended.
    :return: A list of GPU indices with more than ``min_memory`` free.
    """
    selected = []
    for index, free_bytes in get_available_gpus():
        if free_bytes > min_memory:
            selected.append(index)
    return selected


def get_target_device_id_or_memory_available_gpu():
    """
    Choose the CUDA device to run on, falling back when the target is unusable.

    If the target device (get_cuda_device_id) is among the GPUs returned by
    get_memory_available_gpus, it is used. Otherwise the first such GPU is
    written back to ``config.runtime_env_vars.device_id`` and used instead.
    When no GPU qualifies at all, "cpu" is returned.

    :return: A CUDA device string from get_cuda_device_string(), or "cpu".
    """
    usable = get_memory_available_gpus()
    device_id = get_cuda_device_id()

    if device_id in usable:
        return get_cuda_device_string()

    if not usable:
        logger.warning(
            f"Device {device_id} is not available or does not have enough memory. Using CPU instead."
        )
        return "cpu"

    logger.warning(
        f"Device {device_id} is not available or does not have enough memory. will try to use {usable}"
    )
    # Persist the fallback so get_cuda_device_string() picks it up.
    config.runtime_env_vars.device_id = str(usable[0])
    return get_cuda_device_string()


def get_optimal_device_name():
    """
    Pick the name of the device that computations should run on.

    Preference order: "cpu" when ``config.runtime_env_vars.use_cpu`` asks for
    CPU everywhere, then CUDA (resolved through
    get_target_device_id_or_memory_available_gpu, which may itself fall back
    to "cpu"), then "mps", and finally "cpu".

    :return: The device name as a string.
    """
    use_cpu = config.runtime_env_vars.use_cpu
    # The rest of this module treats use_cpu as a collection of task names
    # ('"all" in use_cpu'), so accept that form as well as the bare string.
    # Strings are compared by equality only, to avoid substring matches.
    force_cpu = use_cpu == "all" or (
        isinstance(use_cpu, (list, tuple, set, frozenset)) and "all" in use_cpu
    )
    if force_cpu:
        return "cpu"

    if torch.cuda.is_available():
        return get_target_device_id_or_memory_available_gpu()

    if has_mps():
        return "mps"

    return "cpu"


def get_optimal_device():
    """
    Return the optimal device as a ``torch.device``.

    :return: ``torch.device`` built from get_optimal_device_name().
    """
    device_name = get_optimal_device_name()
    return torch.device(device_name)


def get_device_for(task):
    """
    Return the device a given task should run on.

    :param task: Task name looked up in ``config.runtime_env_vars.use_cpu``.
    :return: The module-level ``cpu`` device when ``use_cpu`` contains the
        task or "all"; otherwise the result of get_optimal_device().
    """
    use_cpu = config.runtime_env_vars.use_cpu
    # use_cpu stays None until reset_device() normalises it; treat None as
    # "nothing forced onto the CPU" instead of raising TypeError on `in`.
    if use_cpu is not None and (task in use_cpu or "all" in use_cpu):
        return cpu

    return get_optimal_device()


def torch_gc():
    """
    Release cached accelerator memory on a best-effort basis.

    With CUDA available, empties the CUDA cache and runs IPC collection on the
    configured CUDA device. When has_mps() is true, runs the MPS cleanup from
    mac_devices as well. Any exception is logged with its traceback and
    swallowed, so this function does not raise.
    """
    try:
        if torch.cuda.is_available():
            with torch.cuda.device(get_cuda_device_string()):
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

        if has_mps():
            mac_devices.torch_mps_gc()
    except Exception:
        # Cleanup is best-effort: record the failure but never propagate it.
        logger.exception("Error in torch_gc")


# Shared device/dtype state for the module. reset_device() reassigns `device`
# and all of the dtype_* values; until then they keep the defaults below.
cpu: torch.device = torch.device("cpu")
device: torch.device = None  # stays None until reset_device() is called
dtype: torch.dtype
dtype_dvae: torch.dtype
dtype_vocos: torch.dtype
dtype_gpt: torch.dtype
dtype_decoder: torch.dtype
# Every dtype starts out as full precision.
dtype = dtype_dvae = dtype_vocos = dtype_gpt = dtype_decoder = torch.float32


def reset_device():
    """
    Recompute the module-level device and dtype globals from the runtime config.

    Steps, in order:
      * a ``use_cpu`` of None is replaced by an empty list in the config;
      * when "all" is in ``use_cpu``, ``no_half`` is forced to True in the
        config (with a warning if it was not already set);
      * all dtype globals become float16 unless ``no_half`` is set, in which
        case they become float32;
      * ``device`` becomes the CPU device when "all" is in ``use_cpu``,
        otherwise the result of get_optimal_device().
    """
    global device, dtype, dtype_dvae, dtype_vocos, dtype_gpt, dtype_decoder

    env = config.runtime_env_vars

    if env.use_cpu is None:
        env.use_cpu = []

    cpu_only = "all" in env.use_cpu

    if cpu_only and not env.no_half:
        logger.warning(
            "Cannot use half precision with CPU, using full precision instead"
        )
        env.no_half = True

    # Read no_half only after the CPU override above may have changed it.
    use_half = not env.no_half
    precision = torch.float16 if use_half else torch.float32
    dtype = dtype_dvae = dtype_vocos = dtype_gpt = dtype_decoder = precision

    if use_half:
        logger.info("Using half precision: torch.float16")
    else:
        logger.info("Using full precision: torch.float32")

    device = cpu if cpu_only else get_optimal_device()

    logger.info(f"Using device: {device}")


@lru_cache
def first_time_calculation():
    """
    Run a throwaway calculation through a couple of PyTorch layers.

    The first time this is done it allocates about 700MB of memory and spends
    about 2.7 seconds doing that, at least with NVidia. The lru_cache
    decorator means the body only runs on the first call.
    """
    # (input shape, layer factory) pairs: one Linear and one Conv2d forward pass.
    warmups = (
        ((1, 1), lambda: torch.nn.Linear(1, 1)),
        ((1, 1, 3, 3), lambda: torch.nn.Conv2d(1, 1, (3, 3))),
    )
    for shape, make_layer in warmups:
        sample = torch.zeros(shape).to(device, dtype)
        layer = make_layer().to(device, dtype)
        layer(sample)