import math
import re
from pathlib import Path

import torch
from einops import rearrange


def load_checkpoint(path, device='cpu'):
    """Load a checkpoint from either a regular PyTorch .pt file or a DeepSpeed checkpoint directory."""
    path = Path(path).expanduser()
    is_deepspeed = False
    if path.is_dir():  # DeepSpeed checkpoint
        is_deepspeed = True
        latest_path = path / 'latest'
        if latest_path.is_file():
            with open(latest_path, 'r') as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        path /= f'{tag}/mp_rank_00_model_states.pt'
    state_dict = torch.load(path, map_location=device)
    if is_deepspeed:
        state_dict = state_dict['module']

        # Replace the names of some of the submodules: strip the 'module.model.' prefix
        def key_mapping(key):
            return re.sub(r'^module\.model\.', '', key)

        state_dict = {key_mapping(k): v for k, v in state_dict.items()}
    return state_dict


def blockdiag_to_dense_mlp_bert(state_dict):
    """Convert block-diagonal MLP weights in a BERT state dict to dense weights."""
    from src.ops.blockdiag_multiply import blockdiag_weight_to_dense_weight
    names = {name for name in state_dict
             if re.match(r'bert\.encoder\.layer\.(\d+)\.(mlp\.fc(1|2)|(intermediate|output)\.dense)\.weight',
                         name)}
    for name in names:
        state_dict[name] = blockdiag_weight_to_dense_weight(state_dict[name])
    return state_dict


def interpolate_pos_embedding(state_dict, out_seqlen, pos_embedding_name='model.pos_encoder.pe',
                              interleave=False):
    """Extend a positional embedding of shape (batch, seqlen, dim) to out_seqlen positions.

    If interleave is False, the embedding is tiled along the sequence dimension.
    If interleave is True, the original and target lengths must both be perfect squares;
    the embedding is viewed as a 2D grid and each position is repeated along both axes.
    """
    orig_emb = state_dict['state_dict'][pos_embedding_name]
    assert (out_seqlen % orig_emb.shape[1]) == 0, 'out_seqlen must be a multiple of the original sequence length'
    reps = [1 for _ in orig_emb.shape]
    reps[1] = out_seqlen // orig_emb.shape[1]

    if interleave:
        assert math.isqrt(orig_emb.shape[1]) ** 2 == orig_emb.shape[1], 'interleave only works for square lengths'
        assert math.isqrt(out_seqlen) ** 2 == out_seqlen, 'interleave only works for square lengths'
        assert math.isqrt(reps[1]) ** 2 == reps[1], 'out_seqlen / seqlen must be a perfect square'

        emb_square = rearrange(orig_emb, 'b (h w) d -> b h w d', h=math.isqrt(orig_emb.shape[1]))
        emb_square_expanded = (emb_square
                               .repeat_interleave(math.isqrt(reps[1]), dim=1)
                               .repeat_interleave(math.isqrt(reps[1]), dim=2))
        new_emb = rearrange(emb_square_expanded, 'b h w d -> b (h w) d')
        state_dict['state_dict'][pos_embedding_name] = new_emb
    else:
        state_dict['state_dict'][pos_embedding_name] = orig_emb.repeat(*reps)

    ret = remove_model_prefix(state_dict)

    # HACK: this is a hack for block-sparse flash attention: drop the cached attention
    # layout, since it depends on the sequence length.
    ret = {k: v for k, v in ret.items() if not k.endswith('inner_attn.layout')}
    return ret


def remove_model_prefix(state_dict):
    # HACK: this is a hack to get the model to load properly: get rid of the 'model.' prefix
    for key in list(state_dict['state_dict'].keys()):
        if key.startswith('model.'):
            new_key = key[len('model.'):]
            state_dict['state_dict'][new_key] = state_dict['state_dict'].pop(key)

    # HACK: something is wrong with the state dict being loaded...
    return state_dict['state_dict']
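

# -----------------------------------------------------------------------------
# Illustrative usage sketch (an assumption, not part of the original module):
# a toy checkpoint whose positional embedding covers 16 positions is stretched
# to 64 positions with the interleaved (2D grid) scheme. The key
# 'model.pos_encoder.pe' is the function's default; the tensor itself is fake.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    toy_ckpt = {'state_dict': {'model.pos_encoder.pe': torch.randn(1, 16, 8)}}
    resized = interpolate_pos_embedding(toy_ckpt, out_seqlen=64, interleave=True)
    # remove_model_prefix has already stripped the 'model.' prefix from the key.
    print(resized['pos_encoder.pe'].shape)  # torch.Size([1, 64, 8])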