# Copyright 2024 NVIDIA CORPORATION & AFFILIATES (authored by @Lyken17) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # SPDX-License-Identifier: Apache-2.0 import argparse import os import os.path as osp import re import time from hashlib import sha1, sha256 from huggingface_hub import HfApi from huggingface_hub.hf_api import CommitOperationAdd from termcolor import colored MAX_UPLOAD_FILES_PER_COMMIT = 64 MAX_UPLOAD_SIZE_PER_COMMIT = 64 * 1024 * 1024 * 1024 # 64 GiB MAX_UPLOAD_SIZE_PER_SINGLE_FILE = 45 * 1024 * 1024 * 1024 # 45 GiB def compute_git_hash(filename): with open(filename, "rb") as f: data = f.read() s = "blob %u\0" % len(data) h = sha1() h.update(s.encode("utf-8")) h.update(data) return h.hexdigest() def compute_lfs_hash(filename): with open(filename, "rb") as f: data = f.read() h = sha256() h.update(data) return h.hexdigest() def main(): import os os.environ["CURL_CA_BUNDLE"] = "" parser = argparse.ArgumentParser() parser.add_argument("local_folder", type=str) parser.add_argument("--model-name", type=str, default=None) parser.add_argument("--repo-type", type=str, choices=["model", "dataset"]) parser.add_argument("--repo-org", type=str, default="Efficient-Large-Model") parser.add_argument("--repo-id", type=str, default=None) parser.add_argument("--root-dir", type=str, default=None) parser.add_argument("--token", type=str, default=None) parser.add_argument("-e", "--exclude", action="append", default=[r"checkpoint-[\d]*/.*", ".git/.*", "wandb/.*"]) parser.add_argument("--fast-check", action="store_true") parser.add_argument("--sleep-on-error", action="store_true") args = parser.parse_args() if args.token is None: api = HfApi() else: print("initing using token from cmd args.") api = HfApi(token=args.token) repo_type = args.repo_type local_folder = args.local_folder if args.repo_id is not None: repo = args.repo_id else: # remove last / if local_folder[-1] == "/": local_folder = local_folder[:-1] if args.model_name is None: model_name = osp.basename(local_folder).replace("+", "-") else: model_name = args.model_name repo = osp.join(args.repo_org, model_name) local_folder = os.path.expanduser(local_folder) root_dir = local_folder if args.root_dir is None else args.root_dir print(f"uploading {local_folder} to {repo}") if not api.repo_exists(repo, repo_type=repo_type): api.create_repo( repo_id=repo, private=True, repo_type=repo_type, ) BASE_URL = "https://hf.co" if args.repo_type == "dataset": BASE_URL = "https://hf.co/datasets" print(colored(f"Uploading {osp.join(BASE_URL, repo)}", "green")) ops = [] commit_description = "" commit_size = 0 for root, dirs, files in os.walk(local_folder, topdown=True): dirs.sort() for name in files: fpath = osp.join(root, name) rpath = osp.relpath(fpath, osp.abspath(root_dir)) exclude_flag = False matched_pattern = None for pattern in args.exclude: if re.search(pattern, rpath): exclude_flag = True matched_pattern = pattern if exclude_flag: print(colored(f"""[regex filter-out][{matched_pattern}]: {rpath}, skipping""", "yellow")) continue if osp.getsize(fpath) > MAX_UPLOAD_SIZE_PER_SINGLE_FILE: print( colored( f"Huggingface only supports filesize less than {MAX_UPLOAD_SIZE_PER_SINGLE_FILE}, skipping", "red", ) ) continue if api.file_exists(repo_id=repo, filename=rpath, repo_type=repo_type): if args.fast_check: print( colored( f"Already uploaded {rpath}, fast check pass, skipping", "green", ) ) continue else: hf_meta = api.get_paths_info(repo_id=repo, paths=rpath, repo_type=repo_type)[0] if hf_meta.lfs is not None: hash_type = "lfs-sha256" hf_hash = hf_meta.lfs["sha256"] git_hash = compute_lfs_hash(fpath) else: hash_type = "git-sha1" hf_hash = hf_meta.blob_id git_hash = compute_git_hash(fpath) if hf_hash == git_hash: print( colored( f"Already uploaded {rpath}, hash check pass, skipping", "green", ) ) continue else: print( colored( f"{rpath} is not same as local version, re-uploading...", "red", ) ) operation = CommitOperationAdd( path_or_fileobj=fpath, path_in_repo=rpath, ) print(colored(f"Commiting {rpath}", "green")) ops.append(operation) commit_size += operation.upload_info.size commit_description += f"Upload {rpath}\n" if len(ops) <= MAX_UPLOAD_FILES_PER_COMMIT and commit_size <= MAX_UPLOAD_SIZE_PER_COMMIT: continue commit_message = "Upload files with vila-upload." result = None while result is None: try: commit_info = api.create_commit( repo_id=repo, repo_type=repo_type, operations=ops, commit_message=commit_message, commit_description=commit_description, ) except RuntimeError as e: print(e) if not args.sleep_on_error: raise e else: print("sleeping for one hour then re-try") time.sleep(1800) continue result = "success" commit_description = "" ops = [] commit_size = 0 print(colored(f"Finish {commit_info}", "yellow")) # upload residuals commit_message = "Upload files with `sana-upload`." commit_info = api.create_commit( repo_id=repo, repo_type=repo_type, operations=ops, commit_message=commit_message, commit_description=commit_description, ) if __name__ == "__main__": main()