#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Download a model archive and its tokens file from a Hugging Face Hub repo.

Files are stored under ``<pretrained_model_dir>/huggingface/<repo_id>/``.
"""
import argparse
import os
from pathlib import Path
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
# Make the repository root importable so `project_settings` resolves.
sys.path.append(os.path.join(pwd, "../../"))

import huggingface_hub

from project_settings import project_path


def get_args():
    """Parse command-line arguments for the download script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--repo_id",
        default="csukuangfj/wenet-chinese-model",
        type=str,
        help="Hugging Face repo id, e.g. csukuangfj/wenet-chinese-model "
             "or csukuangfj/wenet-english-model",
    )
    parser.add_argument("--model_filename", default="final.zip", type=str)
    parser.add_argument("--tokens_filename", default="units.txt", type=str)
    parser.add_argument(
        "--pretrained_model_dir",
        default=(project_path / "pretrained_models").as_posix(),
        type=str,
    )
    args = parser.parse_args()
    return args


def _download(repo_id: str, filename: str, local_dir: Path) -> str:
    """Download ``filename`` from ``repo_id`` into ``local_dir``; return the local path."""
    return huggingface_hub.hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        # NOTE(review): subfolder="." is kept from the original; the files seem
        # to live at the repo root, so this may be droppable — confirm.
        subfolder=".",
        local_dir=local_dir.as_posix(),
    )


def main():
    """Download the model file and the tokens file, printing each local path."""
    args = get_args()

    pretrained_model_dir = Path(args.pretrained_model_dir)
    pretrained_model_dir.mkdir(parents=True, exist_ok=True)

    # repo_id has the form "<owner>/<name>", so the target is several levels
    # deep; parents=True creates "huggingface/<owner>/" as well (without it,
    # mkdir raises FileNotFoundError on a fresh checkout).
    model_dir = pretrained_model_dir / "huggingface" / args.repo_id
    model_dir.mkdir(parents=True, exist_ok=True)

    print("download model")
    model_filename = _download(args.repo_id, args.model_filename, model_dir)
    print(model_filename)

    print("download tokens")
    tokens_filename = _download(args.repo_id, args.tokens_filename, model_dir)
    print(tokens_filename)


if __name__ == "__main__":
    main()