""" Checkpoint Averaging Script

This script averages all model weights for checkpoints in the specified path that match
the specified filter wildcard. All checkpoints must be from the exact same model.

For any hope of decent results, the checkpoints should be from the same or a child
(via resumes) training session. This can be viewed as similar to maintaining a running
EMA (exponential moving average) of the model weights or performing SWA (stochastic
weight averaging), but post-training.

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
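# Example invocation (a minimal sketch; the directory, run name, and script filename
# below are hypothetical, only the flags are the ones defined in this script):
#
#   python avg_checkpoints.py --input ./output/train/my_run --filter 'checkpoint-*.pth.tar' \
#       -n 5 --output ./my_run_averaged.pth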
import torch
import argparse
import os
import glob
import hashlib
from timm.models import load_state_dict
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Do not use the EMA version of the weights, even if present')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select checkpoints by metric; makes the "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
parser.add_argument('--safetensors', action='store_true',
                    help='Save weights using safetensors instead of the default torch way (pickle).')


def checkpoint_metric(checkpoint_path):
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        # return None (not an empty dict) so the caller's `is not None` check skips it
        return None
    print("=> Extracting metric from checkpoint '{}'".format(checkpoint_path))
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    metric = None
    if 'metric' in checkpoint:
        metric = checkpoint['metric']
    elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
        metrics = checkpoint['metrics']
        print(metrics)
        metric = metrics[checkpoint['metric_name']]
    return metric


def main():
    args = parser.parse_args()

    # by default, use the EMA weights (if present) and sort checkpoints by metric
    args.use_ema = not args.no_use_ema
    args.sort = not args.no_sort

    if args.safetensors and args.output == DEFAULT_OUTPUT:
        # default output path changes when saving as safetensors
        args.output = DEFAULT_SAFE_OUTPUT

    output, output_ext = os.path.splitext(args.output)
    if not output_ext:
        output_ext = ('.safetensors' if args.safetensors else '.pth')
    output = output + output_ext

    if args.safetensors and not output_ext == ".safetensors":
        print(
            "Warning: saving weights as safetensors but output file extension is not "
            f"set to '.safetensors': {args.output}"
        )

    if os.path.exists(output):
        print("Error: Output filename ({}) already exists.".format(output))
        exit(1)

    # build the glob pattern from the input folder and the filter wildcard
    pattern = args.input
    if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
        pattern += os.path.sep
    pattern += args.filter
    checkpoints = glob.glob(pattern, recursive=True)

    if args.sort:
        # sort ascending by metric and keep the top n checkpoints
        checkpoint_metrics = []
        for c in checkpoints:
            metric = checkpoint_metric(c)
            if metric is not None:
                checkpoint_metrics.append((metric, c))
        checkpoint_metrics = sorted(checkpoint_metrics)
        checkpoint_metrics = checkpoint_metrics[-args.n:]
        if checkpoint_metrics:
            print("Selected checkpoints:")
            for m, c in checkpoint_metrics:
                print(m, c)
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
        if avg_checkpoints:
            print("Selected checkpoints:")
            for c in avg_checkpoints:
                print(c)

    if not avg_checkpoints:
        print('Error: No checkpoints found to average.')
        exit(1)

    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
            print(f"Error: Checkpoint ({c}) is empty or could not be loaded, skipping.")
            continue
        # accumulate sums in float64 to reduce rounding error across many checkpoints
        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
                avg_counts[k] = 1
            else:
                avg_state_dict[k] += v.to(dtype=torch.float64)
                avg_counts[k] += 1

    for k, v in avg_state_dict.items():
        v.div_(avg_counts[k])

    # clamp to the float32 range and downcast the averaged weights before saving
    float32_info = torch.finfo(torch.float32)
    final_state_dict = {}
    for k, v in avg_state_dict.items():
        v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

    if args.safetensors:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(final_state_dict, output)
    else:
        torch.save(final_state_dict, output)

    with open(output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
    print(f"=> Saved state_dict to '{output}', SHA256: {sha_hash}")


if __name__ == '__main__':
    main()