Upload push_to_hub.py with huggingface_hub
Browse files- push_to_hub.py +76 -0
push_to_hub.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import huggingface_hub
|
4 |
+
import torch
|
5 |
+
from vqmodel.configuration_vqmodel import VQModelConfig
|
6 |
+
from vqmodel.image_processing_vqmodel import VQModelImageProcessor
|
7 |
+
from vqmodel.modeling_vqmodel import VQModel
|
8 |
+
|
9 |
+
VQModelConfig.register_for_auto_class()
|
10 |
+
VQModel.register_for_auto_class()
|
11 |
+
VQModelImageProcessor.register_for_auto_class()
|
12 |
+
|
13 |
+
|
14 |
+
def main():
|
15 |
+
args = parse_args()
|
16 |
+
config = VQModelConfig(yaml_path=args.yaml_path)
|
17 |
+
model = VQModel(config)
|
18 |
+
load_model_weights(model, args.ckpt_path)
|
19 |
+
|
20 |
+
# Define image processor
|
21 |
+
ddconfig = model.vq_cfg.model.params.ddconfig
|
22 |
+
image_processor = VQModelImageProcessor(
|
23 |
+
size=ddconfig.resolution,
|
24 |
+
convert_rgb=ddconfig.in_channels == 3,
|
25 |
+
)
|
26 |
+
|
27 |
+
# Edit config
|
28 |
+
model.config.repo_id = args.repo_id
|
29 |
+
model.config.yaml_path = "config.yaml"
|
30 |
+
|
31 |
+
# Push to hub
|
32 |
+
model.push_to_hub(args.repo_id, private=True)
|
33 |
+
image_processor.push_to_hub(args.repo_id, private=True)
|
34 |
+
api = huggingface_hub.HfApi()
|
35 |
+
api.upload_file(
|
36 |
+
path_or_fileobj=args.yaml_path,
|
37 |
+
path_in_repo="config.yaml",
|
38 |
+
repo_id=args.repo_id,
|
39 |
+
)
|
40 |
+
api.upload_file(
|
41 |
+
path_or_fileobj=__file__,
|
42 |
+
path_in_repo="push_to_hub.py",
|
43 |
+
repo_id=args.repo_id,
|
44 |
+
)
|
45 |
+
api.upload_file(
|
46 |
+
path_or_fileobj="requirements.txt",
|
47 |
+
path_in_repo="requirements.txt",
|
48 |
+
repo_id=args.repo_id,
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
def parse_args():
|
53 |
+
parser = argparse.ArgumentParser()
|
54 |
+
parser.add_argument("--repo_id", type=str, required=True, help="Repository ID")
|
55 |
+
parser.add_argument(
|
56 |
+
"--yaml_path", type=str, required=True, help="Path to YAML file"
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--ckpt_path", type=str, required=True, help="Path to checkpoint file"
|
60 |
+
)
|
61 |
+
return parser.parse_args()
|
62 |
+
|
63 |
+
|
64 |
+
def load_model_weights(model, ckpt_path):
|
65 |
+
# Load checkpoint
|
66 |
+
ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
67 |
+
|
68 |
+
# Remove loss related states
|
69 |
+
for key in list(ckpt.keys()):
|
70 |
+
if key.startswith("loss."):
|
71 |
+
del ckpt[key]
|
72 |
+
model.model.load_state_dict(ckpt)
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
main()
|