import torch


class DeviceMap:
    """Builds a layer-to-GPU device map for loading a model across multiple GPUs."""

    __top_layer: str
    __device_map: dict
    __total_layers: int
    __layers: int

    def __init__(self, model=None):
        if model == "LLaMA":
            self.__top_layer = "model"
            self.__device_map = {
                "model.embed_tokens": 0,
                "model.norm": 0,
                "lm_head": 0,
            }
            self.__total_layers = 34
            self.__layers = 32
        elif model == "ChatGLM":
            self.__top_layer = "transformer"
            self.__device_map = {
                "transformer.word_embeddings": 0,
                "transformer.final_layernorm": 0,
                "lm_head": 0,
            }
            self.__total_layers = 30
            self.__layers = 28
        else:
            self.__top_layer = ""
            self.__device_map = {"": 0}
            self.__total_layers = 0
            self.__layers = 0
    def get(self):
        top_layer = self.__top_layer
        total_layers = self.__total_layers
        layers = self.__layers
        device_map = self.__device_map

        # Query the free memory on every visible GPU.
        world_size = torch.cuda.device_count()
        free_gpu_mem = []
        for i in range(world_size):
            torch.cuda.set_device(i)
            free_gpu_mem.append(torch.cuda.mem_get_info()[0])
        min_id = min(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
        max_id = max(enumerate(free_gpu_mem), key=lambda x: x[1])[0]
        total_mem = sum(free_gpu_mem)

        # Assign each GPU a share of layers proportional to its free memory.
        world_layers = {
            id: int(round(total_layers * (mem / total_mem)))
            for id, mem in enumerate(free_gpu_mem)
        }

        # Correct the rounding remainder: give extra layers to the GPU with the
        # most free memory, or take the surplus from the one with the least.
        diff = total_layers - sum(world_layers.values())
        world_layers[max_id if diff > 0 else min_id] += diff

        # Walk through the transformer layers, filling each GPU's quota in turn.
        cnt = total_layers - layers
        gpu_id = 0
        for i in range(layers):
            if cnt < world_layers[gpu_id]:
                cnt += 1
            else:
                gpu_id += 1
                cnt = 1
            device_map[f"{top_layer}.layers.{i}"] = gpu_id
        return device_map
    def peft(self):
        # Same map, but with the "base_model.model" prefix that PEFT wraps
        # around the underlying model's module names.
        prefix = "base_model.model"
        device_map = self.get()
        peft_device_map = {"": 0}
        for k, v in device_map.items():
            peft_device_map[f"{prefix}.{k}"] = v
        return peft_device_map
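

if __name__ == "__main__":
    # Usage sketch (added for illustration, not part of the original file):
    # the map built by DeviceMap is typically passed to transformers'
    # from_pretrained via its device_map argument. The checkpoint path below
    # is a hypothetical placeholder.
    from transformers import AutoModelForCausalLM

    device_map = DeviceMap("LLaMA").get()
    print(device_map)
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/llama-checkpoint",  # hypothetical path
        device_map=device_map,       # place each layer on its assigned GPU
    )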