Spaces:
Running
Running
# Copyright (c) Facebook, Inc. and its affiliates. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import ast | |
import collections | |
import contextlib | |
import inspect | |
import logging | |
import os | |
import re | |
import time | |
import traceback | |
from collections import OrderedDict | |
from pathlib import Path | |
from typing import Any, Dict, Optional, Union | |
import numpy as np | |
import torch | |
from fairseq.data import data_utils | |
from fairseq.dataclass.configs import CheckpointConfig | |
from fairseq.dataclass.utils import ( | |
convert_namespace_to_omegaconf, | |
overwrite_args_by_name, | |
) | |
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP | |
from fairseq.file_io import PathManager | |
from fairseq.models import FairseqDecoder, FairseqEncoder | |
from omegaconf import DictConfig, OmegaConf, open_dict | |
# Module-level logger shared by the checkpoint helpers in this file.
logger = logging.getLogger(name=__name__)
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    """Save checkpoints for the current epoch/update and prune old ones.

    The trainer writes one checkpoint file: the first name in
    ``checkpoint_conds`` whose condition holds. Every other selected name is
    produced by copying that file. Afterwards, interval-update, epoch and
    "best N" checkpoints beyond the configured retention limits are deleted.

    The running best validation score is stored on the function object itself
    as ``save_checkpoint.best``.

    Args:
        cfg: checkpoint configuration (save_dir, retention limits, ...).
        trainer: the trainer whose state is saved.
        epoch_itr: epoch iterator. Its ``state_dict()`` is stored under
            ``extra_state["train_iterator"]``.
        val_loss: validation score for this checkpoint, or ``None``.
    """
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        # (path, score) pairs sorted by score in descending order. The score is
        # the regex-captured number, so it parses correctly even when the
        # checkpoint suffix itself contains "_".
        scored_chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
            keep_match=True,
        )
        if len(scored_chkpts) > 0:
            # the weakest retained score: smallest when maximizing, largest otherwise
            _, worst_best = (
                scored_chkpts[-1]
                if cfg.maximize_best_checkpoint_metric
                else scored_chkpts[0]
            )
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        # write once, then copy to the remaining selected names
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            # checkpoints whose update count is a multiple of the pattern are
            # exempt from pruning
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]
        _remove_checkpoint_files(checkpoints[cfg.keep_interval_updates :])

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        _remove_checkpoint_files(checkpoints[cfg.keep_last_epochs :])

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            # lower is better: put the best (smallest) scores first
            checkpoints = checkpoints[::-1]
        _remove_checkpoint_files(checkpoints[cfg.keep_best_checkpoints :])


def _remove_checkpoint_files(paths):
    """Delete each checkpoint in *paths*: via ``os`` if it exists locally, else via PathManager."""
    for old_chk in paths:
        if os.path.lexists(old_chk):
            os.remove(old_chk)
        elif PathManager.exists(old_chk):
            PathManager.rm(old_chk)
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Restore trainer state from a checkpoint and build the training iterator.

    Checkpoint selection:
      * default ``restore_file`` (``checkpoint_last.pt``): the suffixed
        ``checkpoint_last`` inside ``save_dir``. If that file does not exist
        yet (first launch), ``continue_once`` is used when set. Otherwise
        ``finetune_from_model`` is used when set, and optimizer, lr scheduler,
        meters and dataloader are all reset.
      * any other ``restore_file``: used as given, with ``.pt`` replaced by
        ``<suffix>.pt`` when the trainer has a suffix.

    *passthrough_args* are forwarded to ``trainer.get_train_iterator``.

    Returns:
        ``(extra_state, epoch_itr)``, where ``extra_state`` is the value
        returned by ``trainer.load_checkpoint``.

    Raises:
        ValueError: if ``finetune_from_model`` is combined with a reset flag or
            a non-default ``restore_file``, or points to a missing file.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    wants_reset = (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    )
    if cfg.finetune_from_model is not None and wants_reset:
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    # default value of restore_file is 'checkpoint_last.pt'
    uses_default_restore = cfg.restore_file == "checkpoint_last.pt"

    if not uses_default_restore:
        checkpoint_path = (
            cfg.restore_file
            if suffix is None
            else cfg.restore_file.replace(".pt", suffix + ".pt")
        )
    else:
        checkpoint_path = os.path.join(cfg.save_dir, f"checkpoint_last{suffix}.pt")
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif first_launch and cfg.finetune_from_model is not None:
            # Nothing to resume from yet: start the finetune from the
            # pretrained model. Later restarts resume from checkpoint_last.
            if not PathManager.exists(cfg.finetune_from_model):
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
            checkpoint_path = cfg.finetune_from_model
            reset_optimizer = reset_lr_scheduler = True
            reset_meters = reset_dataloader = True
            logger.info(
                f"loading pretrained model from {checkpoint_path}: "
                "optimizer, lr scheduler, meters, dataloader will be reset"
            )

    if not uses_default_restore and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if extra_state is not None and "best" in extra_state:
        if not reset_optimizer and not reset_meters:
            save_checkpoint.best = extra_state["best"]

    if extra_state is None or reset_dataloader:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )
    else:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)

    trainer.lr_step(epoch_itr.epoch)
    return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.

    Args:
        path: checkpoint location (a local path or anything PathManager handles).
        arg_overrides: optional mapping applied on top of the stored ``args``
            Namespace and/or ``cfg`` config.
        load_on_all_ranks: see above.

    Returns:
        The checkpoint dict, after ``_upgrade_state_dict``.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    # NOTE(security): torch.load unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        # Compare numerically: a plain string comparison would treat e.g.
        # "2.10" as older than "2.2".
        oc_major_minor = tuple(int(v) for v in re.findall(r"\d+", oc_version)[:2])
        if oc_major_minor < (2, 2):
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True
            try:
                state["cfg"] = OmegaConf.create(state["cfg"])
            finally:
                # always undo the process-wide monkeypatch, even if create() fails
                _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Thin wrapper around ``load_model_ensemble_and_task`` that drops the task
    from its result.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
        strict (bool): strict state-dict loading; must be False when
            ``num_shards > 1``
        suffix (str): checkpoint filename suffix
        num_shards (int): number of checkpoint shards per model
        state (dict, optional): pre-loaded checkpoint state

    Returns:
        ``(ensemble, cfg)``: the list of loaded models and their config.
    """
    assert (
        not strict or num_shards <= 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    models, model_cfg, _ = load_model_ensemble_and_task(
        filenames,
        arg_overrides=arg_overrides,
        task=task,
        strict=strict,
        suffix=suffix,
        num_shards=num_shards,
        state=state,
    )
    return models, model_cfg
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    """Resolve the file holding shard ``shard_idx`` of a checkpoint.

    Returns the FSDP shard file (``<name><suffix>-shard<idx>.pt``) if it
    exists. Otherwise, when ``num_shards > 1``, returns the model-parallel part
    file (``<name>_part<idx>.pt``). Otherwise returns the suffixed filename
    (``<name><suffix>.pt``).
    """
    orig_filename = filename
    if filename.endswith(".pt"):
        # Insert the suffix only before the trailing extension. str.replace
        # would also rewrite any ".pt" appearing earlier in the path.
        filename = filename[: -len(".pt")] + suffix + ".pt"
    else:
        # keep the legacy behavior for names without a ".pt" extension
        filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    if num_shards > 1:
        return model_parallel_filename
    return filename
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load an ensemble of models together with their config and task.

    Args:
        filenames (List[str]): checkpoint files to load, one model per entry.
        arg_overrides (Dict[str,Any], optional): overrides passed to
            ``load_checkpoint_to_cpu`` for the stored args/cfg.
        task (fairseq.tasks.FairseqTask, optional): task used to build the
            models. If ``None``, it is set up from the first checkpoint's config.
        strict (bool): passed to ``model.load_state_dict``. Must be False when
            ``num_shards > 1``.
        suffix (str): checkpoint suffix used to locate each shard file.
        num_shards (int): number of shard files per model.
        state (dict, optional): an already-loaded checkpoint dict, used for the
            first shard instead of reading it from disk. Only allowed with a
            single filename.

    Returns:
        ``(ensemble, cfg, task)``: the list of models, the config of the last
        checkpoint shard loaded, and the task.

    Raises:
        IOError: if a shard file does not exist.
        RuntimeError: if a checkpoint has neither ``args`` nor ``cfg``.
        ImportError: if an FSDP-sharded checkpoint is loaded without fairscale.
    """
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        # keep the unsharded name; `filename` is rebound for every shard below
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            # a caller-supplied `state` is used for the first shard only (it is
            # reset to None at the end of each iteration)
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                # FSDP-sharded checkpoint: collect every shard, then consolidate
                # and build the model only once the last shard has been read
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                argspec = inspect.getfullargspec(task.build_model)
                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                # NOTE(review): shard_idx + 1 shards have been loaded at this
                # point, but the message reports shard_idx.
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        # NOTE(review): on the non-FSDP path with num_shards > 1 a model is
        # built per shard and only the last one is kept; confirm this is the
        # intended model-parallel behavior.
        ensemble.append(model)
    return ensemble, cfg, task
def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    """Download a model snapshot from the Hugging Face Hub and load it.

    Args:
        model_id: Hub repository id passed to ``snapshot_download``.
        cache_dir: download cache directory. Defaults to ``~/.cache/fairseq``.
        arg_overrides: overrides for the stored model args. This dict is not
            modified; a copy gets ``data`` set to the snapshot directory.
        **kwargs: extra keyword arguments forwarded to ``snapshot_download``.

    Returns:
        ``(ensemble, cfg, task)`` as returned by
        ``load_model_ensemble_and_task`` for every ``*.pt`` file in the
        snapshot, loaded in sorted filename order.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        ) from err

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    # copy so the "data" override below does not leak into the caller's dict
    _arg_overrides = dict(arg_overrides or {})
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        # sort for a deterministic ensemble order (glob order is filesystem-dependent)
        sorted(p.as_posix() for p in Path(cache_dir).glob("*.pt")),
        arg_overrides=_arg_overrides,
    )
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    A file counts as a checkpoint when its whole name matches ``pattern``. If
    the pattern has a capture group, results are ordered by the first group
    (parsed as a float) in descending order. Otherwise they are ordered by the
    file's position in the directory listing, also descending.

    Args:
        path: directory to list (via PathManager).
        pattern: regular expression matched against each file name.
        keep_match: if True, return ``(full_path, sort_key)`` pairs instead of
            bare paths.
    """
    matcher = re.compile(pattern)
    scored = []
    for position, name in enumerate(PathManager.ls(path)):
        found = matcher.fullmatch(name)
        if found is None:
            continue
        sort_key = float(found.group(1)) if found.groups() else position
        scored.append((sort_key, found.group(0)))
    scored.sort(reverse=True)
    if keep_match:
        return [(os.path.join(path, name), key) for key, name in scored]
    return [os.path.join(path, name) for _, name in scored]
def torch_persistent_save(obj, filename, async_write: bool = False):
    """Serialize ``obj`` to ``filename`` with ``torch.save``.

    With ``async_write`` the object is written through ``PathManager.opena``.
    Otherwise, on backends that support rename, it is first written to
    ``<filename>.tmp`` and then renamed over ``filename``. Other backends fall
    back to writing ``filename`` directly.
    """
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
        return

    if not PathManager.supports_rename(filename):
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            _torch_persistent_save(obj, f)
        return

    # do atomic save
    tmp_filename = filename + ".tmp"
    with PathManager.open(tmp_filename, "wb") as f:
        _torch_persistent_save(obj, f)
    PathManager.rename(tmp_filename, filename)
def _torch_persistent_save(obj, f):
    """Write ``obj`` to ``f`` with ``torch.save``, retrying up to 3 times.

    Args:
        obj: object to serialize.
        f: a writable binary file object, or a path string that is opened
            with ``PathManager.open(f, "wb")``.

    Raises:
        Exception: whatever ``torch.save`` raised on the final attempt (the
            traceback is also logged).
    """
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            # recurse into this helper: the public torch_persistent_save
            # expects a filename, not an open handle
            _torch_persistent_save(obj, h)
        return

    # Remember where the payload starts so a retry can discard a partially
    # written attempt instead of appending a second copy after it.
    try:
        start = f.tell() if f.seekable() else None
    except (AttributeError, OSError, ValueError):
        # not a standard seekable stream; retries simply append
        start = None

    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise
            if start is not None:
                f.seek(start)
                f.truncate()
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Rewrites legacy checkpoint layouts and option names in ``state`` into the
    current format. The dict is modified in place and also returned.
    """

    # add optimizer_history
    # (legacy layout: a single top-level "optimizer" and "best_loss" are
    # wrapped into a one-entry optimizer history)
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        # NOTE(review): assumes every history entry has an "optimizer" key
        # whenever the last one does; otherwise this raises KeyError.
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    # NOTE(review): assumes state has "extra_state" by now (stored or created
    # above); a checkpoint with neither "extra_state" nor "epoch" raises KeyError.
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        # derive the structured config from the (now upgraded) Namespace
        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        # open_dict lets the upgrades below modify the config
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            # NOTE(review): unlike the generation.print_alignment upgrade
            # below, both True and False are mapped to "hard" here — confirm
            # that is intended.
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    Layers are kept according to the comma-separated ``encoder_layers_to_keep``
    / ``decoder_layers_to_keep`` fields of ``model_cfg``, and the kept layers
    are renumbered contiguously from 0. Parameters of a component with no
    pruning configured (e.g. the decoder when only ``encoder_layers_to_keep``
    is set) are kept unchanged. As a side effect, both ``*_layers_to_keep``
    fields are reset to ``None`` on ``model_cfg``.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        """Build the old->new layer-index mapping and key regex for one component."""
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {str(old): str(new) for new, old in enumerate(keep_layers)}

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        handled_by_pass = False
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            substitution_match = pruning_pass["substitution_regex"].search(layer_name)
            if substitution_match is None:
                continue
            handled_by_pass = True
            new_layer_number = pruning_pass["mapping_dict"].get(original_layer_number)
            if new_layer_number is None:
                # this layer index was not selected: drop it
                continue
            new_state_key = (
                layer_name[: substitution_match.start(1)]
                + new_layer_number
                + layer_name[substitution_match.end(1) :]
            )
            new_state_dict[new_state_key] = state_dict[layer_name]

        if not handled_by_pass:
            # Belongs to a component with no pruning configured (e.g. the
            # decoder when only encoder layers are pruned): keep it as is.
            new_state_dict[layer_name] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.

    Only keys under ``encoder.`` (for an encoder) or ``decoder.`` (for a
    decoder) are used, with that prefix stripped.

    Returns:
        The same ``component``, with the weights loaded.

    Raises:
        IOError: if ``checkpoint`` does not exist.
        ValueError: if ``component`` is neither an encoder nor a decoder.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    # Match on "<type>." so that unrelated keys such as "encoder_proj.weight"
    # are not picked up and mangled by the prefix slice.
    prefix = component_type + "."
    component_state_dict = OrderedDict(
        # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
        (key[len(prefix) :], value)
        for key, value in state["model"].items()
        if key.startswith(prefix)
    )
    component.load_state_dict(component_state_dict, strict=strict)
    return component
def verify_checkpoint_directory(save_dir: str) -> None:
    """Ensure ``save_dir`` exists and is writable.

    Creates the directory if needed, then creates and deletes a uniquely named
    probe file inside it.

    Raises:
        OSError: if the directory cannot be created or written to. A warning
            is logged when the probe file cannot be created.
    """
    import tempfile

    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    try:
        # A unique probe name avoids clobbering (and deleting) a user file that
        # happens to be called "dummy", and avoids races when several ranks
        # verify the same directory at once. The file is removed on close.
        with tempfile.NamedTemporaryFile(dir=save_dir):
            pass
    except OSError:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise
def save_ema_as_checkpoint(src_path, dst_path):
    """Write a copy of the checkpoint at ``src_path`` to ``dst_path``, with its
    model weights replaced by the EMA weights (see ``load_ema_from_checkpoint``)."""
    ema_state = load_ema_from_checkpoint(src_path)
    return torch_persistent_save(ema_state, dst_path)
def load_ema_from_checkpoint(fpath):
    """Load a checkpoint and swap its model weights for the stored EMA weights.

    Args:
        fpath: A string path of checkpoint to load from.

    Returns:
        The checkpoint dict (string keys mapping to various values). Its
        'model' entry is replaced by an OrderedDict mapping parameter names to
        the EMA tensors read from ``extra_state["ema"]``. CPU half-precision
        tensors are upcast to float32, and every tensor is cloned.

    Raises:
        ValueError: if the checkpoint holds no EMA weights.
    """
    # NOTE(security): torch.load unpickles arbitrary objects; only use this on
    # trusted checkpoints.
    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

    # EMA model is stored in a separate "extra state". Missing or empty EMA
    # weights are reported via the ValueError below instead of surfacing as a
    # KeyError/TypeError.
    extra_state = new_state.get("extra_state") or {}
    model_params = extra_state.get("ema") or {}

    params_dict = collections.OrderedDict()
    for key, p in model_params.items():
        if isinstance(p, torch.HalfTensor):
            p = p.float()
        # clone() is needed in case p is a shared parameter
        params_dict[key] = p.clone()

    if len(params_dict) == 0:
        raise ValueError(
            f"Input checkpoint path '{fpath}' does not contain "
            "ema model weights, is this model trained with EMA?"
        )

    new_state["model"] = params_dict
    return new_state