|
import argparse |
|
from pathlib import Path |
|
|
|
import torch |
|
from safetensors.torch import save_file |
|
|
|
|
|
def convert(path: Path, half: bool = False, no_ema: bool = False):
    """Convert a single checkpoint into a ``.safetensors`` file next to it.

    The checkpoint is loaded on the CPU. If the loaded mapping contains a
    ``"state_dict"`` entry, that inner mapping is used instead. Only
    ``torch.Tensor`` values are kept; every other value is dropped.

    Args:
        path: Checkpoint file to convert.
        half: Cast floating-point tensors to ``torch.float16``. Integer and
            boolean tensors keep their original dtype.
        no_ema: Drop every entry whose key contains the substring ``"ema"``.

    Returns:
        Path of the written file: ``<stem>[-pruned][-fp16].safetensors`` in
        the same directory as ``path``.
    """
    # NOTE(security): torch.load unpickles the file, so only convert
    # checkpoints from trusted sources. Newer PyTorch releases default to
    # weights_only=True, which may reject checkpoints holding arbitrary
    # Python objects — TODO confirm against the torch version in use.
    state_dict = torch.load(path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    # Keep tensors only; with no_ema, any key containing "ema" (plain
    # substring match) is treated as an EMA weight and dropped.
    tensors = {
        key: value
        for key, value in state_dict.items()
        if isinstance(value, torch.Tensor) and not (no_ema and "ema" in key)
    }

    if half:
        # Downcast only floating-point tensors: calling .half() on integer or
        # bool tensors would silently change them into float16 values.
        tensors = {
            key: value.half() if value.is_floating_point() else value
            for key, value in tensors.items()
        }

    # safetensors refuses to serialize non-contiguous tensors; .contiguous()
    # returns the tensor itself when it is already contiguous.
    tensors = {key: value.contiguous() for key, value in tensors.items()}

    output_name = path.stem
    if no_ema:
        output_name += "-pruned"
    if half:
        output_name += "-fp16"
    output_path = path.parent / f"{output_name}.safetensors"

    save_file(tensors, output_path.as_posix())
    return output_path
|
|
|
|
|
def main(path: str, half: bool = False, no_ema: bool = False):
    """Convert a checkpoint file, or every ``*.ckpt`` file in a directory.

    Args:
        path: A checkpoint file, or a directory whose top-level ``*.ckpt``
            files are converted (subdirectories are not searched).
        half: Forwarded to ``convert``.
        no_ema: Forwarded to ``convert``.

    Raises:
        ValueError: If ``path`` does not exist.
    """
    target = Path(path).resolve()

    if not target.exists():
        raise ValueError(f"Invalid path: {path}")

    if target.is_file():
        to_convert = [target]
    else:
        # Path.glob yields entries in filesystem order; sort so that runs are
        # reproducible.
        to_convert = sorted(target.glob("*.ckpt"))
        if not to_convert:
            # Tell the user explicitly instead of exiting silently.
            print(f"No .ckpt files found in {target}")
            return

    for file in to_convert:
        print(f"Converting... {file}")
        convert(file, half, no_ema)
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the converter.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``path`` (str), ``half`` (bool) and
        ``no_ema`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Convert .ckpt checkpoints to .safetensors files."
    )
    parser.add_argument("path", type=str, help="Path to checkpoint file or directory.")
    parser.add_argument(
        "--half", action="store_true", help="Convert to half precision."
    )
    parser.add_argument("--no-ema", action="store_true", help="Ignore EMA weights.")
    return parser.parse_args(argv)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read the CLI options and run the conversion.
    cli = parse_args()
    main(cli.path, half=cli.half, no_ema=cli.no_ema)
|
|