|
"""
Upload model weights and/or the tokenizer to the Hugging Face Hub.

Usage:
python3 -m fastchat.model.upload_hub --model-path ~/model_weights/vicuna-13b --hub-repo-id lmsys/vicuna-13b-v1.3
"""
|
import argparse |
|
import tempfile |
|
|
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
def upload_hub(model_path, hub_repo_id, component, private):
    """Push a local checkpoint's model weights and/or tokenizer to a Hub repo.

    Args:
        model_path: Location passed to ``from_pretrained`` to load the model
            and the tokenizer.
        hub_repo_id: Target repository id, forwarded as ``repo_id``.
        component: ``"all"`` to upload both the model and the tokenizer, or
            ``"model"`` / ``"tokenizer"`` to upload only that one. Any other
            value matches neither branch and uploads nothing.
        private: Forwarded as ``private`` alongside ``push_to_hub=True``.
    """
    components = ["model", "tokenizer"] if component == "all" else [component]

    # Use the ``private`` parameter, not the script-level ``args`` namespace:
    # ``args`` only exists when this module is executed as ``__main__``, so
    # reading it here raised NameError for any programmatic caller.
    kwargs = {"push_to_hub": True, "repo_id": hub_repo_id, "private": private}

    if "model" in components:
        # Weights are loaded in float16 with low_cpu_mem_usage enabled.
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
        )
        # save_pretrained needs a local output directory; a temporary one is
        # used so nothing is left on disk once the context exits.
        with tempfile.TemporaryDirectory() as tmp_path:
            model.save_pretrained(tmp_path, **kwargs)

    if "tokenizer" in components:
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        with tempfile.TemporaryDirectory() as tmp_path:
            tokenizer.save_pretrained(tmp_path, **kwargs)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the upload.
    cli = argparse.ArgumentParser()

    # Both the source checkpoint and the destination repo are mandatory.
    for required_flag in ("--model-path", "--hub-repo-id"):
        cli.add_argument(required_flag, type=str, required=True)

    # Which part of the checkpoint to upload; defaults to everything.
    cli.add_argument(
        "--component",
        type=str,
        choices=["all", "model", "tokenizer"],
        default="all",
    )
    cli.add_argument("--private", action="store_true")

    # The parsed namespace stays bound to the module-level name ``args``.
    args = cli.parse_args()

    upload_hub(
        args.model_path,
        args.hub_repo_id,
        args.component,
        args.private,
    )
|
|