File size: 2,303 Bytes
1870c14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch


class DeviceMap:
    """Build a module-name -> GPU-index map that spreads layers across GPUs.

    Modules outside the repeated layer stack (embeddings, final norm,
    ``lm_head``) are pinned to GPU 0.  The ``<top_layer>.layers.<i>`` entries
    are handed out to GPUs in index order, each GPU receiving a share of the
    layers proportional to its free memory.

    Supported ``model`` values are ``"LLaMA"`` and ``"ChatGLM"``; any other
    value produces the catch-all map ``{"": 0}``.
    """

    __top_layer: str
    __device_map: dict
    __total_layers: int
    __layers: int

    def __init__(self, model=None):
        """Select the module layout for ``model``.

        Args:
            model: ``"LLaMA"``, ``"ChatGLM"``, or anything else for the
                catch-all layout.
        """
        # ``__layers`` is the number of ``<top>.layers.<i>`` entries to place.
        # ``__total_layers`` is larger by the number of slots reserved on
        # GPU 0 for the pinned modules, so GPU 0 gets fewer stacked layers.
        # NOTE(review): the layer counts are hard-coded; confirm they match
        # the checkpoints actually in use.
        if model == "LLaMA":
            self.__top_layer = "model"
            self.__device_map = {
                "model.embed_tokens": 0,
                "model.norm": 0,
                "lm_head": 0,
            }
            self.__total_layers = 34
            self.__layers = 32

        elif model == "ChatGLM":
            self.__top_layer = "transformer"
            self.__device_map = {
                "transformer.word_embeddings": 0,
                "transformer.final_layernorm": 0,
                "lm_head": 0,
            }
            self.__total_layers = 30
            self.__layers = 28

        else:
            self.__top_layer = ""
            self.__device_map = {"": 0}
            self.__total_layers = 0
            self.__layers = 0

    @staticmethod
    def _query_free_gpu_memory():
        """Return the free memory of every visible CUDA device, by index."""
        # Pass the device index explicitly rather than calling
        # ``torch.cuda.set_device`` so the caller's current device is left
        # untouched.
        return [
            torch.cuda.mem_get_info(index)[0]
            for index in range(torch.cuda.device_count())
        ]

    @staticmethod
    def _layer_quotas(total_layers, free_gpu_mem):
        """Split ``total_layers`` slots across GPUs proportionally to memory."""
        total_mem = sum(free_gpu_mem)
        quotas = [
            int(round(total_layers * (mem / total_mem)))
            for mem in free_gpu_mem
        ]

        # Rounding can leave the quotas off by a few slots.  A shortfall is
        # given to the GPU with the most free memory; an excess is taken from
        # the GPU with the least (first index wins on ties).
        remainder = total_layers - sum(quotas)
        largest = free_gpu_mem.index(max(free_gpu_mem))
        smallest = free_gpu_mem.index(min(free_gpu_mem))
        quotas[largest if remainder > 0 else smallest] += remainder
        return quotas

    def get(self, free_gpu_mem=None):
        """Return a fresh device map for the configured model.

        Args:
            free_gpu_mem: Optional sequence with one free-memory figure per
                GPU (index = GPU id).  Only the ratios matter.  When omitted,
                the values are queried from CUDA for every visible device.

        Returns:
            A new ``dict`` mapping module names to GPU indices.  The
            instance's own state is not modified.

        Raises:
            ValueError: If there are no GPUs to distribute over, a memory
                figure is negative, or the total free memory is zero.
        """
        if free_gpu_mem is None:
            free_gpu_mem = self._query_free_gpu_memory()
        free_gpu_mem = list(free_gpu_mem)

        if not free_gpu_mem:
            raise ValueError("no GPUs available to build a device map")
        if min(free_gpu_mem) < 0:
            raise ValueError("free GPU memory values must be non-negative")
        if sum(free_gpu_mem) <= 0:
            raise ValueError("total free GPU memory must be positive")

        top_layer = self.__top_layer
        total_layers = self.__total_layers
        layers = self.__layers
        quotas = self._layer_quotas(total_layers, free_gpu_mem)

        # Work on a copy so repeated calls never alias or overwrite a map
        # that was returned earlier.
        device_map = dict(self.__device_map)

        # GPU 0 starts with the reserved slots already counted against it.
        used = total_layers - layers
        gpu_id = 0
        for i in range(layers):
            if used < quotas[gpu_id]:
                used += 1
            else:
                # Current GPU is full: move on and place this layer there.
                gpu_id += 1
                used = 1
            device_map[f"{top_layer}.layers.{i}"] = gpu_id

        return device_map

    def peft(self, free_gpu_mem=None):
        """Return the device map with keys prefixed by ``base_model.model``.

        The result also contains the catch-all entry ``"": 0``.
        NOTE(review): the prefix presumably matches how a PEFT wrapper nests
        the base model; confirm against the wrapper in use.

        Args:
            free_gpu_mem: Forwarded to :meth:`get`.

        Returns:
            A new ``dict`` mapping prefixed module names to GPU indices.

        Raises:
            ValueError: Propagated from :meth:`get`.
        """
        prefix = "base_model.model"
        peft_device_map = {"": 0}
        for name, gpu_id in self.get(free_gpu_mem).items():
            peft_device_map[f"{prefix}.{name}"] = gpu_id
        return peft_device_map