"""Holds the interface between the gradio app and the medusa training script."""

import multiprocessing as mp
import os

import torch
import torch.distributed.run as distributed_run
from huggingface_hub import HfApi
from huggingface_hub.utils import RepositoryNotFoundError
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

OUTPUT_DIR = "medusa_heads"
DATASET = "vicuna"
# Hub user/org under which the trained medusa heads are uploaded.
HUB_NAMESPACE = "joaogante"

# These can't be changed (e.g. they control the output path)
FIXED_TRAINING_ARGS = """src/medusa_training_script.py
--model_name_or_path {model_id}
--output_dir {output_dir}
--run_name {model_id}-medusa-{dataset}
--dataset {dataset}"""

# These can be freely changed
DEFAULT_TRAINING_ARGS = """--medusa_num_heads 3
--medusa_num_layers 1
--model_max_length 2048
--bf16 True
--num_train_epochs 1
--per_device_train_batch_size 64
--per_device_eval_batch_size 64
--gradient_accumulation_steps 8
--evaluation_strategy no
--save_strategy no
--weight_decay 0.0
--warmup_ratio 0.1
--lr_scheduler_type cosine
--logging_steps 10
--tf32 True
--auto_find_batch_size True
--learning_rate 1e-3"""


def _markdown_message(title: str, body: str) -> str:
    """Format a status message for the app: a Markdown heading followed by a body paragraph."""
    return f"### {title}\n\n{body}\n"


def train_medusa_heads(model_id: str, training_args: str, dataset: str):
    """Launch the medusa training script through ``torch.distributed.run``.

    ``run`` uses this as the target of a child process.

    Args:
        model_id: Hub id of the base model the heads are trained for.
        training_args: user-editable, whitespace-separated training arguments
            (see ``DEFAULT_TRAINING_ARGS``), appended after ``FIXED_TRAINING_ARGS``.
        dataset: dataset name forwarded to the training script (also used in the run name).
    """
    all_training_args = FIXED_TRAINING_ARGS.format(
        model_id=model_id,
        output_dir=OUTPUT_DIR,
        dataset=dataset,
    ) + "\n" + training_args

    # `split()` without a separator splits on any whitespace run and drops empty tokens, so double
    # spaces, tabs, CRLF line endings, a trailing newline or empty `training_args` do not turn into
    # stray "" arguments for the training script.
    all_training_arg_list = all_training_args.split()
    print("Full argument list:", all_training_arg_list)

    # The first token is the script path; torchrun's parser treats it as the training script and
    # hands the remaining tokens to that script.
    parser = distributed_run.get_args_parser()
    args = parser.parse_args(all_training_arg_list)
    distributed_run.run(args)


def run(model_id: str, training_args: str, dataset: str) -> str:
    """Validate the inputs, train medusa heads for ``model_id`` and upload them to the Hub.

    Args:
        model_id: Hub id of the base model.
        training_args: user-editable training arguments (see ``DEFAULT_TRAINING_ARGS``).
        dataset: dataset name used for training and in the destination repo name.

    Returns:
        A Markdown message describing the outcome, for display in the gradio app. Validation,
        loading, training and upload failures are reported through this message.
    """
    print(f"\n\n\nNEW RUN: {model_id}")
    # Surrounding whitespace would otherwise end up in the repo id and break `from_pretrained`.
    model_id = (model_id or "").strip()

    # Input validation
    if not model_id:
        return _markdown_message("Invalid input 🐞", "Please fill a model_id.")

    api = HfApi()
    model_name = model_id.split("/")[-1]
    repo_id = f"{HUB_NAMESPACE}/{model_name}-medusa-{dataset}"
    if api.repo_exists(repo_id):
        return _markdown_message(
            "Invalid input 🐞",
            f"{repo_id} already exists, which means that {model_id} has already been used to create medusa heads.",
        )

    print(f"Valid inputs ✅\nValidating model_id: {model_id}")

    # Attempt to load the base model, so that unsupported models fail before the long training run.
    try:
        config = AutoConfig.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        # Only loadability matters here; drop the references before training starts.
        del config, tokenizer, model
    except Exception as e:
        return _markdown_message(f"{model_id} can't be loaded with AutoClasses 🐞", f"{e}")

    print(f"{model_id} can be loaded ✅\nCreating medusa heads (will take a few hours)")

    # Run the medusa heads creation in a child process. Exceptions raised inside the child do not
    # propagate here; a crash only shows up as a non-zero `exitcode`.
    try:
        proc = mp.Process(target=train_medusa_heads, args=(model_id, training_args, dataset))
        proc.start()
        proc.join()
    except Exception as e:
        print("Error ❌\n", e)
        return _markdown_message("Error 😢😢😢", f"{e}")
    print(
        f"Medusa heads training process completed with exit code {proc.exitcode} (it might have crashed!)"
    )

    # Upload the medusa heads to the Hub
    # Folder path from https://github.com/FasterDecoding/Medusa/blob/main/medusa/train/train.py#L399
    folder_path = f"{OUTPUT_DIR}_medusa_{model_name}"
    repo_created = False
    try:
        # A missing folder means the same thing as an empty one: training did not produce weights.
        has_weights = os.path.isdir(folder_path) and any(
            file_name.endswith(".pt") for file_name in os.listdir(folder_path)
        )
        if not has_weights:
            raise RuntimeError(
                f"No model data in the expected model folder ({folder_path}), the training run probably failed "
                f"(training process exit code: {proc.exitcode}). Check the logs for more information."
            )
        api.create_repo(
            repo_id=repo_id,
            exist_ok=True,
        )
        repo_created = True
        api.upload_folder(
            folder_path=folder_path,
            repo_id=repo_id,
        )
        print("Medusa heads upload success ✅\n Uploaded to: ", repo_id)
        return _markdown_message(
            "Success 🔥",
            f"Yay! Medusa heads were successfully created and uploaded to the following repo: {repo_id}",
        )
    except Exception as e:
        print("Error ❌\n", e)
        # Only clean up a repo this run created: a leftover repo would make the `repo_exists` check
        # above reject every retry for this model. Training takes hours, so a repo that appeared in
        # the meantime may belong to another run and must not be deleted.
        if repo_created:
            try:
                api.delete_repo(repo_id)
            except RepositoryNotFoundError:
                pass
        return _markdown_message("Error 😢😢😢", f"{e}")