lIlBrother committed
Commit c251804 • Parent: f8d0c20

Update: model

config.json CHANGED
@@ -24,7 +24,7 @@
   "position_buckets": 256,
   "relative_attention": true,
   "share_att_key": true,
-  "transformers_version": "4.37.2",
+  "transformers_version": "4.38.2",
   "type_vocab_size": 0,
   "vocab_size": 64100
 }
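The only change in config.json is the `transformers_version` stamp, which records the library version that serialized the config; the model hyperparameters are untouched. A minimal sketch for inspecting the stamp in a local copy of the file (the path is illustrative):

import json

# Illustrative local path to the downloaded config from this repo.
with open("config.json") as f:
    config = json.load(f)

# The stamp only records which transformers release wrote the file.
print(config["transformers_version"])  # "4.38.2" after this commit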
{global_step23940 → global_step165430}/mp_rank_00_model_states.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7ad98480a5aab0db71e19ce24c4d3be4333da5cb1955de706eb15ed4018d8113
-size 1077570732
+oid sha256:5b04e95e0d0b4bb47fea41cd9de0af4ddfe4a9e3350ec64d4c1f8b77de9f9541
+size 1077570796
{global_step23940 → global_step165430}/zero_pp_rank_0_mp_rank_00_optim_states.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0139708571b55035fa4fe9a3f7be61ebaed589a1eb65bf985b3125b2e357d6ed
+oid sha256:4fe7887eab40a10ededf92b4403561338b9b7d9e0269623913f2dcbada17cc00
 size 808085192
{global_step23940 → global_step165430}/zero_pp_rank_1_mp_rank_00_optim_states.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c7ce8e14c3ef21ed302708f1f5b59b4f9332a2b92de0e8a06387fe6d7a7ef4ca
+oid sha256:9046b582e4329869f31dc4562371217f8ee62dd267816ac7ba57990bb808be1e
 size 808095752
{global_step23940 → global_step165430}/zero_pp_rank_2_mp_rank_00_optim_states.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bb94c45e8587eb0fc9d0e120933c4f24d7efe1d058c6cee66969aadcd02266d8
+oid sha256:091d4c4fd84894549e6e66fadb7699e8c700d6e532ebca6feba502eda8d5d047
 size 808085064
{global_step23940 → global_step165430}/zero_pp_rank_3_mp_rank_00_optim_states.pt RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4cf626224645155eb8a8e53724c0d07abeb725703096326e9b357ac1518d8080
+oid sha256:1470545a59094d545cbc35cc96811ce3765d7a602fc6f93739adf6295b514415
 size 808095496
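Each renamed checkpoint file above is a Git LFS pointer: the repo tracks only an `oid` (the SHA-256 of the blob) and its `size`, while the tensors themselves live in LFS storage. A minimal sketch, under the assumption that a blob has been downloaded locally, for checking it against its pointer (the path and helper name are illustrative):

import hashlib
import os

def verify_lfs_blob(blob_path, expected_oid, expected_size):
    """Compare a downloaded file against the oid/size from its LFS pointer."""
    if os.path.getsize(blob_path) != expected_size:
        return False
    sha = hashlib.sha256()
    with open(blob_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            sha.update(chunk)
    return sha.hexdigest() == expected_oid

# Values copied from the new pointer for mp_rank_00_model_states.pt above.
print(verify_lfs_blob(
    "global_step165430/mp_rank_00_model_states.pt",
    "5b04e95e0d0b4bb47fea41cd9de0af4ddfe4a9e3350ec64d4c1f8b77de9f9541",
    1077570796,
))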
latest CHANGED
@@ -1 +1 @@
-global_step24282
+global_step165430
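The one-line `latest` file names the default checkpoint tag; the `get_fp32_state_dict_from_zero_checkpoint` hunk below reads it whenever no explicit tag is passed. That resolution step in isolation:

import os

def resolve_tag(checkpoint_dir):
    """Read the default tag the way zero_to_fp32.py does when tag is None."""
    latest_path = os.path.join(checkpoint_dir, "latest")
    with open(latest_path, "r") as fd:
        return fd.read().strip()  # "global_step165430" after this commit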
zero_to_fp32.py CHANGED
@@ -24,9 +24,18 @@ from dataclasses import dataclass
 # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
 # DeepSpeed data structures it has to be available in the current python environment.
 from deepspeed.utils import logger
-from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
-                                            FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
-                                            FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
+from deepspeed.checkpoint.constants import (
+    DS_VERSION,
+    OPTIMIZER_STATE_DICT,
+    SINGLE_PARTITION_OF_FP32_GROUPS,
+    FP32_FLAT_GROUPS,
+    ZERO_STAGE,
+    PARTITION_COUNT,
+    PARAM_SHAPES,
+    BUFFER_NAMES,
+    FROZEN_PARAM_SHAPES,
+    FROZEN_PARAM_FRAGMENTS,
+)
 
 
 @dataclass
@@ -42,7 +51,7 @@ class zero_model_state:
 debug = 0
 
 # load to cpu
-device = torch.device('cpu')
+device = torch.device("cpu")
 
 
 def atoi(text):
@@ -50,12 +59,12 @@ def atoi(text):
 
 
 def natural_keys(text):
-    '''
+    """
     alist.sort(key=natural_keys) sorts in human order
     http://nedbatchelder.com/blog/200712/human_sorting.html
     (See Toothy's implementation in the comments)
-    '''
-    return [atoi(c) for c in re.split(r'(\d+)', text)]
+    """
+    return [atoi(c) for c in re.split(r"(\d+)", text)]
 
 
 def get_model_state_file(checkpoint_dir, zero_stage):
@@ -127,12 +136,14 @@ def parse_model_states(files):
 
         frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
 
-        z_model_state = zero_model_state(buffers=buffers,
-                                         param_shapes=param_shapes,
-                                         shared_params=shared_params,
-                                         ds_version=ds_version,
-                                         frozen_param_shapes=frozen_param_shapes,
-                                         frozen_param_fragments=frozen_param_fragments)
+        z_model_state = zero_model_state(
+            buffers=buffers,
+            param_shapes=param_shapes,
+            shared_params=shared_params,
+            ds_version=ds_version,
+            frozen_param_shapes=frozen_param_shapes,
+            frozen_param_fragments=frozen_param_fragments,
+        )
         zero_model_states.append(z_model_state)
 
     return zero_model_states
@@ -208,7 +219,7 @@ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
     model_files = get_model_state_files(ds_checkpoint_dir)
 
     zero_model_states = parse_model_states(model_files)
-    print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
+    print(f"Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}")
 
     if zero_stage <= 2:
         return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
@@ -225,13 +236,13 @@ def _zero2_merge_frozen_params(state_dict, zero_model_states):
 
     if debug:
         num_elem = sum(s.numel() for s in frozen_param_shapes.values())
-        print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+        print(f"rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}")
 
         wanted_params = len(frozen_param_shapes)
         wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
         avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
-        print(f'Frozen params: Have {avail_numel} numels to process.')
-        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+        print(f"Frozen params: Have {avail_numel} numels to process.")
+        print(f"Frozen params: Need {wanted_numel} numels in {wanted_params} params")
 
     total_params = 0
     total_numel = 0
@@ -273,7 +284,8 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
         full_single_fp32_vector = torch.cat(merged_partitions, 0)
         merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
     avail_numel = sum(
-        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
+        [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]
+    )
 
     if debug:
         wanted_params = sum([len(shapes) for shapes in param_shapes])
@@ -292,7 +304,7 @@ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
         avail_numel = full_single_fp32_vector.numel()
         for name, shape in shapes.items():
 
-            unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
+            unpartitioned_numel = shape.numel() if _has_callable(shape, "numel") else math.prod(shape)
             total_numel += unpartitioned_numel
             total_params += 1
 
@@ -361,14 +373,14 @@ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
     if debug:
         for i in range(world_size):
            num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
-            print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
+            print(f"rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}")
 
         frozen_param_shapes = zero_model_states[0].frozen_param_shapes
         wanted_params = len(frozen_param_shapes)
         wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
         avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
-        print(f'Frozen params: Have {avail_numel} numels to process.')
-        print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
+        print(f"Frozen params: Have {avail_numel} numels to process.")
+        print(f"Frozen params: Need {wanted_numel} numels in {wanted_params} params")
 
     total_params = 0
     total_numel = 0
@@ -430,9 +442,11 @@ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
             )
 
             # XXX: memory usage doubles here
-            state_dict[name] = torch.cat(
-                tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
-                0).narrow(0, 0, unpartitioned_numel).view(shape)
+            state_dict[name] = (
+                torch.cat(tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), 0)
+                .narrow(0, 0, unpartitioned_numel)
+                .view(shape)
+            )
             offset += partitioned_numel
 
         offset *= world_size
@@ -499,9 +513,9 @@ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
 
     """
    if tag is None:
-        latest_path = os.path.join(checkpoint_dir, 'latest')
+        latest_path = os.path.join(checkpoint_dir, "latest")
         if os.path.isfile(latest_path):
-            with open(latest_path, 'r') as fd:
+            with open(latest_path, "r") as fd:
                 tag = fd.read().strip()
         else:
             raise ValueError(f"Unable to find 'latest' file at {latest_path}")
@@ -572,19 +586,22 @@ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("checkpoint_dir",
-                        type=str,
-                        help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+    parser.add_argument(
+        "checkpoint_dir", type=str, help="path to the desired checkpoint folder, e.g., path/checkpoint-12"
+    )
     parser.add_argument(
         "output_file",
         type=str,
-        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
-    parser.add_argument("-t",
-                        "--tag",
-                        type=str,
-                        default=None,
-                        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
-    parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+        help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)",
+    )
+    parser.add_argument(
+        "-t",
+        "--tag",
+        type=str,
+        default=None,
+        help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1",
+    )
+    parser.add_argument("-d", "--debug", action="store_true", help="enable debug")
     args = parser.parse_args()
 
     debug = args.debug
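The script's entry points are unchanged by this commit; the Python diff is mechanical reformatting (double quotes, exploded call layouts) with no behavioral change. A usage sketch against the updated checkpoint, assuming the script sits in the repo root as here (the output path is illustrative):

# CLI: consolidate the ZeRO shards into a single fp32 state_dict file.
#   python zero_to_fp32.py . pytorch_model.bin -t global_step165430
# Or from Python, using the function shown in the hunks above:
from zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# tag=None falls back to the tag recorded in the "latest" file.
state_dict = get_fp32_state_dict_from_zero_checkpoint(".", tag="global_step165430")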