|
|
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
|
|
import gradio as gr |
|
|
|
from constants import UploadTarget |
|
from inference import InferencePipeline |
|
from trainer import Trainer |
|
|
|
|
|
def create_training_demo(
    trainer: Trainer, pipe: InferencePipeline | None = None, disable_run_button: bool = False
) -> gr.Blocks:
    """Build the Gradio UI used to configure and start a training run.

    Args:
        trainer: Object whose ``run`` method is invoked with the form values
            when "Start Training" is clicked and whose ``log_file`` is
            tailed in the "Log" box.
        pipe: Optional inference pipeline. When given, its ``clear`` method
            is also called on every click of the run button.
        disable_run_button: When true, the "Start Training" button is
            rendered non-interactive.

    Returns:
        The assembled ``gr.Blocks`` app (not yet queued or launched).
    """
    from collections import deque

    # Number of trailing log lines shown; also used as the log textbox height.
    num_log_lines = 10

    def read_log() -> str:
        """Return the last ``num_log_lines`` lines of the trainer's log file.

        This is polled once per second by the "Log" textbox, so it must not
        raise: a log file that does not exist yet yields an empty string, and
        undecodable bytes (e.g. a multi-byte character caught mid-write) are
        replaced rather than raising ``UnicodeDecodeError``.
        """
        try:
            with open(trainer.log_file, errors="replace") as f:
                # A bounded deque keeps only the tail while streaming the
                # file, instead of materialising every line of a growing log.
                return "".join(deque(f, maxlen=num_log_lines))
        except FileNotFoundError:
            return ""

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown("Training Data")
                    training_video = gr.File(label="Training video")
                    training_prompt = gr.Textbox(label="Training prompt", max_lines=1, placeholder="A man is surfing")
                    gr.Markdown(
                        """
                        - Upload a video and write a `Training Prompt` that describes the video.
                        """
                    )

            with gr.Column():
                with gr.Box():
                    gr.Markdown("Training Parameters")
                    with gr.Row():
                        base_model = gr.Text(label="Base Model", value="CompVis/stable-diffusion-v1-4", max_lines=1)
                        resolution = gr.Dropdown(
                            choices=["512", "768"], value="512", label="Resolution", visible=False
                        )

                    # The token field is only shown when no HF_TOKEN is set in the environment.
                    hf_token = gr.Text(
                        label="Hugging Face Write Token", type="password", visible=os.getenv("HF_TOKEN") is None
                    )
                    with gr.Accordion(label="Advanced options", open=False):
                        num_training_steps = gr.Number(label="Number of Training Steps", value=300, precision=0)
                        learning_rate = gr.Number(label="Learning Rate", value=0.000035)
                        gradient_accumulation = gr.Number(
                            label="Number of Gradient Accumulation", value=1, precision=0
                        )
                        seed = gr.Slider(label="Seed", minimum=0, maximum=100000, step=1, randomize=True, value=0)
                        fp16 = gr.Checkbox(label="FP16", value=True)
                        use_8bit_adam = gr.Checkbox(label="Use 8bit Adam", value=False)
                        checkpointing_steps = gr.Number(label="Checkpointing Steps", value=1000, precision=0)
                        validation_epochs = gr.Number(label="Validation Epochs", value=100, precision=0)
                    gr.Markdown(
                        """
                        - The base model must be a Stable Diffusion model compatible with [diffusers](https://github.com/huggingface/diffusers) library.
                        - Expected time to train a model for 300 steps: ~20 minutes with T4
                        - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
                        """
                    )

        with gr.Row():
            with gr.Column():
                gr.Markdown("Output Model")
                output_model_name = gr.Text(label="Name of your model", placeholder="The surfer man", max_lines=1)
                validation_prompt = gr.Text(
                    label="Validation Prompt", placeholder="prompt to test the model, e.g: a dog is surfing"
                )
            with gr.Column():
                gr.Markdown("Upload Settings")
                with gr.Row():
                    upload_to_hub = gr.Checkbox(label="Upload model to Hub", value=True)
                    use_private_repo = gr.Checkbox(label="Private", value=True)
                    delete_existing_repo = gr.Checkbox(label="Delete existing repo of the same name", value=False)
                    upload_to = gr.Radio(
                        label="Upload to",
                        choices=[_.value for _ in UploadTarget],
                        value=UploadTarget.MODEL_LIBRARY.value,
                    )

        # Hidden control; it is only interactive when SPACE_ID is set in the environment.
        pause_space_after_training = gr.Checkbox(
            label="Pause this Space after training",
            value=False,
            interactive=bool(os.getenv("SPACE_ID")),
            visible=False,
        )
        run_button = gr.Button("Start Training", interactive=not disable_run_button)

        with gr.Box():
            # Passing the callable (not its result) together with every=1
            # makes the textbox re-read the log tail once per second.
            gr.Text(label="Log", value=read_log, lines=num_log_lines, max_lines=num_log_lines, every=1)

        if pipe is not None:
            run_button.click(fn=pipe.clear)
        run_button.click(
            fn=trainer.run,
            # Values are passed positionally, so this order must match the
            # parameter order of trainer.run.
            inputs=[
                training_video,
                training_prompt,
                output_model_name,
                # NOTE(review): delete_existing_repo is passed twice (here and
                # again below). Presumably trainer.run has two separate
                # "overwrite"/"delete" parameters fed by the same checkbox;
                # confirm against Trainer.run before changing either entry.
                delete_existing_repo,
                validation_prompt,
                base_model,
                resolution,
                num_training_steps,
                learning_rate,
                gradient_accumulation,
                seed,
                fp16,
                use_8bit_adam,
                checkpointing_steps,
                validation_epochs,
                upload_to_hub,
                use_private_repo,
                delete_existing_repo,
                upload_to,
                pause_space_after_training,
                hf_token,
            ],
        )
    return demo
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the training UI around a fresh Trainer and serve it.
    trainer = Trainer()
    demo = create_training_demo(trainer)
    # Enable the request queue (max_size=1, api_open=False) before launching;
    # queue() configures the same Blocks instance that is launched below.
    demo.queue(max_size=1, api_open=False)
    demo.launch()
|
|