# Spaces: Paused Paused  (web page header captured with the file; not part of the module)
"""
Holds the interface between the gradio app and the medusa training script
"""
import multiprocessing as mp
import os

import torch
import torch.distributed.run as distributed_run
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Value substituted for `{output_dir}` in FIXED_TRAINING_ARGS; `run` also uses it
# as the prefix of the folder it uploads from (`{OUTPUT_DIR}_medusa_{model_name}`).
OUTPUT_DIR = "medusa_heads"

# Dataset name. Not referenced by the functions in this module; presumably the
# app passes it as the `dataset` argument of `run` -- TODO confirm.
DATASET = "vicuna"

# These can't be changed (e.g. they control the output path).
# Template with `{model_id}`, `{output_dir}` and `{dataset}` placeholders; the
# first token is the training script handed to `torch.distributed.run`.
FIXED_TRAINING_ARGS = """src/medusa_training_script.py
--model_name_or_path {model_id}
--output_dir {output_dir}
--run_name {model_id}-medusa-{dataset}
--dataset {dataset}"""

# These can be freely changed.
# One `--flag value` pair per line; `train_medusa_heads` appends the
# (possibly user-edited) text after the fixed arguments.
DEFAULT_TRAINING_ARGS = """--medusa_num_heads 3
--medusa_num_layers 1
--model_max_length 2048
--bf16 True
--num_train_epochs 1
--per_device_train_batch_size 64
--per_device_eval_batch_size 64
--gradient_accumulation_steps 8
--evaluation_strategy no
--save_strategy no
--weight_decay 0.0
--warmup_ratio 0.1
--lr_scheduler_type cosine
--logging_steps 10
--tf32 True
--auto_find_batch_size True
--learning_rate 1e-3"""
def train_medusa_heads(model_id: str, training_args: str, dataset: str) -> None:
    """Launch the medusa training script through ``torch.distributed.run``.

    The fixed arguments are rendered from ``FIXED_TRAINING_ARGS`` (script path,
    model, output dir, run name, dataset) and ``training_args`` is appended
    after them. The combined text is tokenized and handed to the
    ``torch.distributed.run`` argument parser.

    Args:
        model_id: Base model identifier, substituted into ``--model_name_or_path``
            and ``--run_name``.
        training_args: Extra command-line arguments as whitespace-separated
            text (typically one ``--flag value`` pair per line).
        dataset: Dataset name, substituted into ``--dataset`` and ``--run_name``.
    """
    all_training_args = FIXED_TRAINING_ARGS.format(
        model_id=model_id, output_dir=OUTPUT_DIR, dataset=dataset,
    ) + "\n" + training_args

    # Split on any run of whitespace. Splitting on a single " " per line would
    # emit empty-string tokens for blank lines, doubled or trailing spaces, and
    # keep "\r" glued to values when the text uses CRLF line endings.
    all_training_arg_list = all_training_args.split()
    print("Full argument list:", all_training_arg_list)

    parser = distributed_run.get_args_parser()
    args = parser.parse_args(all_training_arg_list)
    distributed_run.run(args)
def run(model_id: str, training_args: str, dataset: str) -> str:
    """Validate the inputs, train medusa heads for ``model_id`` and upload them.

    Steps, in order:
      1. Reject an empty ``model_id`` or an already existing target repo.
      2. Check that the base model loads with the ``Auto*`` classes.
      3. Run ``train_medusa_heads`` in a separate process and wait for it.
      4. Upload the produced folder to the Hub.

    Args:
        model_id: Hub identifier of the base model.
        training_args: Extra training arguments, forwarded to
            ``train_medusa_heads`` unchanged.
        dataset: Dataset name; used in the target repo name and forwarded to
            ``train_medusa_heads``.

    Returns:
        A Markdown-formatted status message describing success or the failure
        that occurred. Errors are reported through this message rather than
        raised.
    """
    print(f"\n\n\nNEW RUN: {model_id}")

    # Input validation. Whitespace-only input is treated as empty so that it
    # never reaches the Hub calls below with a malformed repo id.
    if not model_id.strip():
        return "\n### Invalid input π\nPlease fill a model_id.\n"

    api = HfApi()
    model_name = model_id.split("/")[-1]
    # NOTE: the target namespace is hard-coded.
    repo_id = f"joaogante/{model_name}-medusa-{dataset}"

    if api.repo_exists(repo_id):
        return (
            "\n### Invalid input π\n"
            f"{repo_id} already exists, which means that {model_id} has already been used to create medusa heads.\n"
        )
    print(f"Valid inputs β \nValidating model_id: {model_id}")

    # Attempt to load the base model; the objects are only needed to prove that
    # loading works, so they are released immediately.
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        del config, tokenizer, model
    except Exception as e:
        return f"\n### {model_id} can't be loaded with AutoClasses π\n{e}\n"
    print(f"{model_id} can be loaded β \nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation in a child process and block until it ends.
    try:
        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, dataset))
        proc.start()
        proc.join()
        print("Medusa heads training process completed (it might have crashed!)")
        if proc.exitcode != 0:
            # Not fatal by itself: the upload step below decides based on
            # whether model files were actually produced.
            print(f"Medusa heads training process exit code: {proc.exitcode}")
    except Exception as e:
        print("Error β\n", e)
        return f"\n### Error π’π’π’\n{e}\n"

    # Upload the medusa heads to the Hub
    hub_touched = False  # becomes True once this run may have created the repo
    try:
        # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
        folder_path = f"{OUTPUT_DIR}_medusa_{model_name}"
        # A missing folder is reported the same way as a folder without
        # weights, instead of surfacing a raw FileNotFoundError.
        has_weights = os.path.isdir(folder_path) and any(
            file_name.endswith(".pt") for file_name in os.listdir(folder_path)
        )
        if not has_weights:
            raise Exception(
                "No model data in the expected model folder, the traning run probably failed. Check the logs for more "
                "information."
            )

        hub_touched = True
        api.create_repo(
            repo_id=repo_id,
            exist_ok=True,
        )
        api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
        )
        print("Medusa heads upload success β \n Uploaded to: ", repo_id)
        return (
            "\n### Success π₯\n"
            f"Yay! Medusa heads were successfully created and uploaded to the following repo: {repo_id}\n"
        )
    except Exception as e:
        print("Error β\n", e)
        # Only clean up a repo this run may have created; a failure before the
        # Hub was touched must not delete a repo that appeared concurrently.
        if hub_touched:
            try:
                api.delete_repo(repo_id)
            except RepositoryNotFoundError:
                pass
            except Exception as cleanup_error:
                # Keep reporting the original error even if cleanup fails.
                print("Could not delete the repo after the failed upload:", cleanup_error)
        return f"\n### Error π’π’π’\n{e}\n"