# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from pathlib import Path
from typing import List

import numpy as np
import torch
from torch.cuda.amp import GradScaler

from .utils import (
    MODEL_NAME,
    OPTIMIZER_NAME,
    RNG_STATE_NAME,
    SCALER_NAME,
    SCHEDULER_NAME,
    get_pretty_name,
    is_tpu_available,
    is_xpu_available,
    save,
)


if is_tpu_available(check_device=False):
    import torch_xla.core.xla_model as xm

from .logging import get_logger
from .state import PartialState


logger = get_logger(__name__)


def _indexed_file_name(base_name: str, index: int) -> str:
    """Return the checkpoint file name for the `index`-th object of a kind.

    The first object (index 0) uses `{base_name}.bin`; later ones get an index suffix, `{base_name}_{index}.bin`.
    """
    return f"{base_name}.bin" if index == 0 else f"{base_name}_{index}.bin"


def save_accelerator_state(
    output_dir: str,
    model_states: List[dict],
    optimizers: list,
    schedulers: list,
    process_index: int,
    scaler: GradScaler = None,
):
    """
    Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory.

    Args:
        output_dir (`str` or `os.PathLike`):
            The name of the folder to save all relevant weights and states.
        model_states (`List[dict]`):
            A list of model state dicts. Each one is written as-is to its own file.
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        process_index (`int`):
            The current process index in the Accelerator state. It is used to name this process's RNG state file.
        scaler (`torch.cuda.amp.GradScaler`, *optional*):
            An optional gradient scaler instance to save

    Returns:
        The `output_dir` that was passed in.
    """
    # Model states
    for i, state in enumerate(model_states):
        output_model_file = os.path.join(output_dir, _indexed_file_name(MODEL_NAME, i))
        save(state, output_model_file)
        logger.info(f"Model weights saved in {output_model_file}")
    # Optimizer states
    for i, opt in enumerate(optimizers):
        output_optimizer_file = os.path.join(output_dir, _indexed_file_name(OPTIMIZER_NAME, i))
        save(opt.state_dict(), output_optimizer_file)
        logger.info(f"Optimizer state saved in {output_optimizer_file}")
    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        output_scheduler_file = os.path.join(output_dir, _indexed_file_name(SCHEDULER_NAME, i))
        save(scheduler.state_dict(), output_scheduler_file)
        logger.info(f"Scheduler state saved in {output_scheduler_file}")
    # GradScaler state
    if scaler is not None:
        output_scaler_file = os.path.join(output_dir, SCALER_NAME)
        torch.save(scaler.state_dict(), output_scaler_file)
        logger.info(f"Gradient scaler state saved in {output_scaler_file}")
    # Random number generator states: each process writes its own file (named with `process_index`), through
    # `torch.save` rather than the shared `save` helper used above.
    states = {
        "random_state": random.getstate(),
        "numpy_random_seed": np.random.get_state(),
        "torch_manual_seed": torch.get_rng_state(),
    }
    if is_xpu_available():
        states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
    else:
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    if is_tpu_available():
        states["xm_seed"] = xm.get_rng_state()
    output_states_file = os.path.join(output_dir, f"{RNG_STATE_NAME}_{process_index}.pkl")
    torch.save(states, output_states_file)
    logger.info(f"Random states saved in {output_states_file}")
    return output_dir


def load_accelerator_state(
    input_dir,
    models,
    optimizers,
    schedulers,
    process_index,
    scaler=None,
    map_location=None,
    **load_model_func_kwargs,
):
    """
    Loads states of the models, optimizers, scaler, and RNG generators from a given directory.

    Args:
        input_dir (`str` or `os.PathLike`):
            The name of the folder to load all relevant weights and states.
        models (`List[torch.nn.Module]`):
            A list of model instances
        optimizers (`List[torch.optim.Optimizer]`):
            A list of optimizer instances
        schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`):
            A list of learning rate schedulers
        process_index (`int`):
            The current process index in the Accelerator state. It selects which RNG state file to read.
        scaler (`torch.cuda.amp.GradScaler`, *optional*):
            An optional *GradScaler* instance to load
        map_location (`str`, *optional*):
            The device to load the model and optimizer states onto. Must be one of `None` (treated as "cpu"),
            "cpu", or "on_device" (the current `PartialState().device`).
        load_model_func_kwargs (`dict`, *optional*):
            Additional arguments that can be passed to the model's `load_state_dict` method.

    Raises:
        TypeError: If `map_location` is not one of the supported values.
    """
    if map_location not in [None, "cpu", "on_device"]:
        raise TypeError(
            "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`"
        )
    if map_location is None:
        map_location = "cpu"
    elif map_location == "on_device":
        map_location = PartialState().device
    # Model states
    for i, model in enumerate(models):
        input_model_file = os.path.join(input_dir, _indexed_file_name(MODEL_NAME, i))
        model.load_state_dict(torch.load(input_model_file, map_location=map_location), **load_model_func_kwargs)
    logger.info("All model weights loaded successfully")

    # Optimizer states
    for i, opt in enumerate(optimizers):
        input_optimizer_file = os.path.join(input_dir, _indexed_file_name(OPTIMIZER_NAME, i))
        opt.load_state_dict(torch.load(input_optimizer_file, map_location=map_location))
    logger.info("All optimizer states loaded successfully")

    # Scheduler states
    for i, scheduler in enumerate(schedulers):
        input_scheduler_file = os.path.join(input_dir, _indexed_file_name(SCHEDULER_NAME, i))
        scheduler.load_state_dict(torch.load(input_scheduler_file))
    logger.info("All scheduler states loaded successfully")

    # GradScaler state
    if scaler is not None:
        input_scaler_file = os.path.join(input_dir, SCALER_NAME)
        scaler.load_state_dict(torch.load(input_scaler_file))
        logger.info("GradScaler state loaded successfully")

    # Random states: restoring them is best-effort. A missing or incompatible RNG file (e.g. saved on a different
    # device type) must not abort resuming, so failures are logged together with their cause instead of raised.
    try:
        states = torch.load(os.path.join(input_dir, f"{RNG_STATE_NAME}_{process_index}.pkl"))
        random.setstate(states["random_state"])
        np.random.set_state(states["numpy_random_seed"])
        torch.set_rng_state(states["torch_manual_seed"])
        if is_xpu_available():
            torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
        else:
            torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
        if is_tpu_available():
            xm.set_rng_state(states["xm_seed"])
        logger.info("All random states loaded successfully")
    except Exception as e:
        logger.info(f"Could not load random states: {e!r}")


def save_custom_state(obj, path, index: int = 0):
    """
    Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl`.

    `obj` must expose a `state_dict()` method; its return value is what gets written.
    """
    save_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}")
    torch.save(obj.state_dict(), save_location)


def load_custom_state(obj, path, index: int = 0):
    """
    Loads the state of `obj` from `{path}/custom_checkpoint_{index}.pkl`.

    The stored state is loaded onto the CPU and passed to `obj.load_state_dict`.
    """
    # Build the path the same way `save_custom_state` does so the two always agree.
    load_location = Path(path) / f"custom_checkpoint_{index}.pkl"
    logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}")
    obj.load_state_dict(torch.load(load_location, map_location="cpu"))