# Commit ecfdc8b by hysts (HF staff): "Migrate from yapf to black"
from __future__ import annotations
import datetime
import os
import pathlib
import shlex
import shutil
import subprocess
import sys
import slugify
import torch
from huggingface_hub import HfApi
from omegaconf import OmegaConf
from uploader import upload
from utils import save_model_card
# Make the vendored Tune-A-Video checkout importable. Like the other
# "Tune-A-Video/..." paths in this file, it is relative to the working directory.
sys.path.extend(["Tune-A-Video"])
class Trainer:
    """Fine-tune a Tune-A-Video model on one video and optionally upload it.

    Base models are cloned into ``checkpoints/``. Each run writes its config and
    training output to ``experiments/<output_model_name>`` next to this file.
    The training subprocess's combined stdout/stderr, followed by status
    messages, goes to ``log.txt``.
    """

    def __init__(self):
        self.checkpoint_dir = pathlib.Path("checkpoints")
        self.checkpoint_dir.mkdir(exist_ok=True)
        self.log_file = pathlib.Path("log.txt")
        self.log_file.touch(exist_ok=True)

    @staticmethod
    def _clone_command(base_model_id: str, target_dir: pathlib.Path) -> list[str]:
        """Return the argv that clones Hub repo *base_model_id* into *target_dir*."""
        # An explicit argv (no shlex.split) keeps a user-supplied id from
        # being split into extra git arguments.
        return [
            "git",
            "clone",
            f"https://huggingface.co/{base_model_id}",
            target_dir.as_posix(),
        ]

    def download_base_model(self, base_model_id: str) -> str:
        """Clone *base_model_id* from the Hub unless a local copy already exists.

        Args:
            base_model_id: Hub repo id, normally ``"<org>/<name>"``.

        Returns:
            POSIX path of the local checkout, ``<checkpoint_dir>/<base_model_id>``.

        Raises:
            subprocess.CalledProcessError: if ``git clone`` fails. Without this,
                a nonexistent path would be handed to the trainer.
        """
        model_dir = self.checkpoint_dir / base_model_id
        if not model_dir.exists():
            # Clone directly into model_dir so the returned path is the checkout
            # itself, including for ids that have no "<org>/" prefix.
            model_dir.parent.mkdir(parents=True, exist_ok=True)
            subprocess.run(self._clone_command(base_model_id, model_dir), check=True)
        return model_dir.as_posix()

    @staticmethod
    def _training_command(config_path: pathlib.Path) -> list[str]:
        """Return the accelerate argv that trains with the config at *config_path*."""
        # Explicit argv so that paths containing spaces stay a single argument.
        return [
            "accelerate",
            "launch",
            "Tune-A-Video/train_tuneavideo.py",
            "--config",
            str(config_path),
        ]

    def _run_training(self, command: list[str]) -> None:
        """Run *command* and write its combined stdout/stderr to the log file.

        The log file is truncated first.

        Raises:
            RuntimeError: if the command exits with a non-zero status. The
                failure message is also appended to the log file.
        """
        with open(self.log_file, "w") as f:
            proc = subprocess.run(command, stdout=f, stderr=subprocess.STDOUT, text=True)
        if proc.returncode != 0:
            message = f"Training failed with exit code {proc.returncode}."
            with open(self.log_file, "a") as f:
                f.write(message + "\n")
            raise RuntimeError(message)

    def run(
        self,
        training_video: str,
        training_prompt: str,
        output_model_name: str,
        overwrite_existing_model: bool,
        validation_prompt: str,
        base_model: str,
        resolution_s: str,
        n_steps: int,
        learning_rate: float,
        gradient_accumulation: int,
        seed: int,
        fp16: bool,
        use_8bit_adam: bool,
        checkpointing_steps: int,
        validation_epochs: int,
        upload_to_hub: bool,
        use_private_repo: bool,
        delete_existing_repo: bool,
        upload_to: str,
        pause_space_after_training: bool,
        hf_token: str,
    ) -> None:
        """Train a Tune-A-Video model, write a model card, and optionally upload it.

        Steps:
            1. Validate the inputs.
            2. Prepare ``experiments/<slugified output_model_name>``.
            3. Write a config derived from ``Tune-A-Video/configs/man-surfing.yaml``.
            4. Run training through ``accelerate``, logging to ``self.log_file``.
            5. Save a model card.
            6. Optionally upload the output directory and pause the Space.

        Args:
            training_video: Uploaded video. Only its ``.name`` attribute is read,
                as the local file path.
            output_model_name: Output name. An empty value yields
                ``tune-a-video-<timestamp>``; the name is always slugified.
            overwrite_existing_model: Delete an existing output directory first.
                This also happens whenever *upload_to_hub* is true.
            resolution_s: Square frame size, as a string.
            validation_epochs: Forwarded as the trainer's ``validation_steps``.

        Raises:
            RuntimeError: if CUDA is unavailable or training exits non-zero.
            ValueError: if the video or either prompt is missing, or
                *resolution_s* is not an integer.
            FileExistsError: if the output directory exists and is not cleared.
            subprocess.CalledProcessError: if cloning the base model fails.
        """
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available.")
        if training_video is None:
            raise ValueError("You need to upload a video.")
        if not training_prompt:
            raise ValueError("The training prompt is missing.")
        if not validation_prompt:
            raise ValueError("The validation prompt is missing.")

        resolution = int(resolution_s)

        if not output_model_name:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
            output_model_name = f"tune-a-video-{timestamp}"
        output_model_name = slugify.slugify(output_model_name)

        repo_dir = pathlib.Path(__file__).parent
        output_dir = repo_dir / "experiments" / output_model_name
        if overwrite_existing_model or upload_to_hub:
            shutil.rmtree(output_dir, ignore_errors=True)
        # Deliberately not exist_ok: refuse to silently reuse an existing run.
        output_dir.mkdir(parents=True)

        config = OmegaConf.load("Tune-A-Video/configs/man-surfing.yaml")
        config.pretrained_model_path = self.download_base_model(base_model)
        config.output_dir = output_dir.as_posix()

        config.train_data.video_path = training_video.name  # type: ignore
        config.train_data.prompt = training_prompt
        config.train_data.n_sample_frames = 8
        config.train_data.width = resolution
        config.train_data.height = resolution
        config.train_data.sample_start_idx = 0
        config.train_data.sample_frame_rate = 1

        config.validation_data.prompts = [validation_prompt]
        config.validation_data.video_length = 8
        config.validation_data.width = resolution
        config.validation_data.height = resolution
        config.validation_data.num_inference_steps = 50
        config.validation_data.guidance_scale = 7.5

        config.learning_rate = learning_rate
        config.gradient_accumulation_steps = gradient_accumulation
        config.train_batch_size = 1
        config.max_train_steps = n_steps
        config.checkpointing_steps = checkpointing_steps
        config.validation_steps = validation_epochs
        config.seed = seed
        # accelerate spells "no mixed precision" as "no"; "" is not an accepted choice.
        config.mixed_precision = "fp16" if fp16 else "no"
        config.use_8bit_adam = use_8bit_adam

        config_path = output_dir / "config.yaml"
        with open(config_path, "w") as f:
            OmegaConf.save(config, f)

        # Raises on a non-zero exit, so nothing below runs for a failed training run.
        self._run_training(self._training_command(config_path))

        save_model_card(
            save_dir=output_dir,
            base_model=base_model,
            training_prompt=training_prompt,
            test_prompt=validation_prompt,
            test_image_dir="samples",
        )

        with open(self.log_file, "a") as f:
            f.write("Training completed!\n")

        if upload_to_hub:
            upload_message = upload(
                local_folder_path=output_dir.as_posix(),
                target_repo_name=output_model_name,
                upload_to=upload_to,
                private=use_private_repo,
                delete_existing_repo=delete_existing_repo,
                hf_token=hf_token,
            )
            with open(self.log_file, "a") as f:
                f.write(upload_message)

        if pause_space_after_training:
            # SPACE_ID is only set when running inside a Space; elsewhere this is a no-op.
            if space_id := os.getenv("SPACE_ID"):
                api = HfApi(token=os.getenv("HF_TOKEN") or hf_token)
                api.pause_space(repo_id=space_id)