import gradio as gr from typing import Tuple from .custom_logging import setup_logging # Set up logging log = setup_logging() class BasicTraining: """ This class configures and initializes the basic training settings for a machine learning model, including options for SDXL, learning rate, learning rate scheduler, and training epochs. Attributes: sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training. learning_rate_value (str): Initial learning rate value. lr_scheduler_value (str): Initial learning rate scheduler value. lr_warmup_value (str): Initial learning rate warmup value. finetuning (bool): If True, enables fine-tuning of the model. dreambooth (bool): If True, enables Dreambooth training. """ def __init__( self, sdxl_checkbox: gr.Checkbox, learning_rate_value: float = "1e-6", lr_scheduler_value: str = "constant", lr_warmup_value: float = "0", finetuning: bool = False, dreambooth: bool = False, config: dict = {}, ) -> None: """ Initializes the BasicTraining object with the given parameters. Args: sdxl_checkbox (gr.Checkbox): Checkbox to enable SDXL training. learning_rate_value (str): Initial learning rate value. lr_scheduler_value (str): Initial learning rate scheduler value. lr_warmup_value (str): Initial learning rate warmup value. finetuning (bool): If True, enables fine-tuning of the model. dreambooth (bool): If True, enables Dreambooth training. """ self.sdxl_checkbox = sdxl_checkbox self.learning_rate_value = learning_rate_value self.lr_scheduler_value = lr_scheduler_value self.lr_warmup_value = lr_warmup_value self.finetuning = finetuning self.dreambooth = dreambooth self.config = config self.old_lr_warmup = 0 # Initialize the UI components self.initialize_ui_components() def initialize_ui_components(self) -> None: """ Initializes the UI components for the training settings. """ # Initialize the training controls self.init_training_controls() # Initialize the precision and resources controls self.init_precision_and_resources_controls() # Initialize the learning rate and optimizer controls self.init_lr_and_optimizer_controls() # Initialize the gradient and learning rate controls self.init_grad_and_lr_controls() # Initialize the learning rate controls self.init_learning_rate_controls() # Initialize the scheduler controls self.init_scheduler_controls() # Initialize the resolution and bucket controls self.init_resolution_and_bucket_controls() # Setup the behavior of the SDXL checkbox self.setup_sdxl_checkbox_behavior() def init_training_controls(self) -> None: """ Initializes the training controls for the model. """ # Create a row for the training controls with gr.Row(): # Initialize the train batch size slider self.train_batch_size = gr.Slider( minimum=1, maximum=64, label="Train batch size", value=1, step=self.config.get("basic.train_batch_size", 1), ) # Initialize the epoch number input self.epoch = gr.Number( label="Epoch", value=self.config.get("basic.epoch", 1), precision=0 ) # Initialize the maximum train epochs input self.max_train_epochs = gr.Number( label="Max train epoch", info="training epochs (overrides max_train_steps). 0 = no override", step=1, # precision=0, minimum=0, value=self.config.get("basic.max_train_epochs", 0), ) # Initialize the maximum train steps input self.max_train_steps = gr.Number( label="Max train steps", info="Overrides # training steps. 0 = no override", step=1, # precision=0, value=self.config.get("basic.max_train_steps", 1600), ) # Initialize the save every N epochs input self.save_every_n_epochs = gr.Number( label="Save every N epochs", value=self.config.get("basic.save_every_n_epochs", 1), precision=0, ) # Initialize the caption extension input self.caption_extension = gr.Dropdown( label="Caption file extension", choices=["", ".cap", ".caption", ".txt"], value=".txt", interactive=True, ) def init_precision_and_resources_controls(self) -> None: """ Initializes the precision and resources controls for the model. """ with gr.Row(): # Initialize the seed textbox self.seed = gr.Number( label="Seed", # precision=0, step=1, minimum=0, value=self.config.get("basic.seed", 0), info="Set to 0 to make random", ) # Initialize the cache latents checkbox self.cache_latents = gr.Checkbox( label="Cache latents", value=self.config.get("basic.cache_latents", True), ) # Initialize the cache latents to disk checkbox self.cache_latents_to_disk = gr.Checkbox( label="Cache latents to disk", value=self.config.get("basic.cache_latents_to_disk", False), ) def init_lr_and_optimizer_controls(self) -> None: """ Initializes the learning rate and optimizer controls for the model. """ with gr.Row(): # Initialize the learning rate scheduler dropdown self.lr_scheduler = gr.Dropdown( label="LR Scheduler", choices=[ "adafactor", "constant", "constant_with_warmup", "cosine", "cosine_with_restarts", "linear", "polynomial", ], value=self.config.get("basic.lr_scheduler", self.lr_scheduler_value), ) # Initialize the optimizer dropdown self.optimizer = gr.Dropdown( label="Optimizer", choices=[ "AdamW", "AdamW8bit", "Adafactor", "DAdaptation", "DAdaptAdaGrad", "DAdaptAdam", "DAdaptAdan", "DAdaptAdanIP", "DAdaptAdamPreprint", "DAdaptLion", "DAdaptSGD", "Lion", "Lion8bit", "PagedAdamW8bit", "PagedAdamW32bit", "PagedLion8bit", "Prodigy", "SGDNesterov", "SGDNesterov8bit", ], value=self.config.get("basic.optimizer", "AdamW8bit"), interactive=True, ) def init_grad_and_lr_controls(self) -> None: """ Initializes the gradient and learning rate controls for the model. """ with gr.Row(): # Initialize the maximum gradient norm slider self.max_grad_norm = gr.Slider( label="Max grad norm", value=self.config.get("basic.max_grad_norm", 1.0), minimum=0.0, maximum=1.0, interactive=True, ) # Initialize the learning rate scheduler extra arguments textbox self.lr_scheduler_args = gr.Textbox( label="LR scheduler extra arguments", lines=2, placeholder="(Optional) eg: milestones=[1,10,30,50] gamma=0.1", value=self.config.get("basic.lr_scheduler_args", ""), ) # Initialize the optimizer extra arguments textbox self.optimizer_args = gr.Textbox( label="Optimizer extra arguments", lines=2, placeholder="(Optional) eg: relative_step=True scale_parameter=True warmup_init=True", value=self.config.get("basic.optimizer_args", ""), ) def init_learning_rate_controls(self) -> None: """ Initializes the learning rate controls for the model. """ with gr.Row(): # Adjust visibility based on training modes lr_label = ( "Learning rate Unet" if self.finetuning or self.dreambooth else "Learning rate" ) # Initialize the learning rate number input self.learning_rate = gr.Number( label=lr_label, value=self.config.get("basic.learning_rate", self.learning_rate_value), minimum=0, maximum=1, info="Set to 0 to not train the Unet", ) # Initialize the learning rate TE number input self.learning_rate_te = gr.Number( label="Learning rate TE", value=self.config.get( "basic.learning_rate_te", self.learning_rate_value ), visible=self.finetuning or self.dreambooth, minimum=0, maximum=1, info="Set to 0 to not train the Text Encoder", ) # Initialize the learning rate TE1 number input self.learning_rate_te1 = gr.Number( label="Learning rate TE1", value=self.config.get( "basic.learning_rate_te1", self.learning_rate_value ), visible=False, minimum=0, maximum=1, info="Set to 0 to not train the Text Encoder 1", ) # Initialize the learning rate TE2 number input self.learning_rate_te2 = gr.Number( label="Learning rate TE2", value=self.config.get( "basic.learning_rate_te2", self.learning_rate_value ), visible=False, minimum=0, maximum=1, info="Set to 0 to not train the Text Encoder 2", ) # Initialize the learning rate warmup slider self.lr_warmup = gr.Slider( label="LR warmup (% of total steps)", value=self.config.get("basic.lr_warmup", self.lr_warmup_value), minimum=0, maximum=100, step=1, ) def lr_scheduler_changed(scheduler, value): if scheduler == "constant": self.old_lr_warmup = value value = 0 interactive=False info="Can't use LR warmup with LR Scheduler constant... setting to 0 and disabling field..." else: if self.old_lr_warmup != 0: value = self.old_lr_warmup self.old_lr_warmup = 0 interactive=True info="" return gr.Slider(value=value, interactive=interactive, info=info) self.lr_scheduler.change( lr_scheduler_changed, inputs=[self.lr_scheduler, self.lr_warmup], outputs=self.lr_warmup, ) def init_scheduler_controls(self) -> None: """ Initializes the scheduler controls for the model. """ with gr.Row(visible=not self.finetuning): # Initialize the learning rate scheduler number of cycles textbox self.lr_scheduler_num_cycles = gr.Number( label="LR # cycles", minimum=1, # precision=0, # round to nearest integer step=1, # Increment value by 1 info="Number of restarts for cosine scheduler with restarts", value=self.config.get("basic.lr_scheduler_num_cycles", 1), ) # Initialize the learning rate scheduler power textbox self.lr_scheduler_power = gr.Number( label="LR power", minimum=0.0, step=0.01, info="Polynomial power for polynomial scheduler", value=self.config.get("basic.lr_scheduler_power", 1.0), ) def init_resolution_and_bucket_controls(self) -> None: """ Initializes the resolution and bucket controls for the model. """ with gr.Row(visible=not self.finetuning): # Initialize the maximum resolution textbox self.max_resolution = gr.Textbox( label="Max resolution", value=self.config.get("basic.max_resolution", "512,512"), placeholder="512,512", ) # Initialize the stop text encoder training slider self.stop_text_encoder_training = gr.Slider( minimum=-1, maximum=100, value=self.config.get("basic.stop_text_encoder_training", 0), step=1, label="Stop TE (% of total steps)", ) # Initialize the enable buckets checkbox self.enable_bucket = gr.Checkbox( label="Enable buckets", value=self.config.get("basic.enable_bucket", True), ) # Initialize the minimum bucket resolution slider self.min_bucket_reso = gr.Slider( label="Minimum bucket resolution", value=self.config.get("basic.min_bucket_reso", 256), minimum=64, maximum=4096, step=64, info="Minimum size in pixel a bucket can be (>= 64)", ) # Initialize the maximum bucket resolution slider self.max_bucket_reso = gr.Slider( label="Maximum bucket resolution", value=self.config.get("basic.max_bucket_reso", 2048), minimum=64, maximum=4096, step=64, info="Maximum size in pixel a bucket can be (>= 64)", ) def setup_sdxl_checkbox_behavior(self) -> None: """ Sets up the behavior of the SDXL checkbox based on the finetuning and dreambooth flags. """ self.sdxl_checkbox.change( self.update_learning_rate_te, inputs=[ self.sdxl_checkbox, gr.Checkbox(value=self.finetuning, visible=False), gr.Checkbox(value=self.dreambooth, visible=False), ], outputs=[ self.learning_rate_te, self.learning_rate_te1, self.learning_rate_te2, ], ) def update_learning_rate_te( self, sdxl_checkbox: gr.Checkbox, finetuning: bool, dreambooth: bool, ) -> Tuple[gr.Number, gr.Number, gr.Number]: """ Updates the visibility of the learning rate TE, TE1, and TE2 based on the SDXL checkbox and finetuning/dreambooth flags. Args: sdxl_checkbox (gr.Checkbox): The SDXL checkbox. finetuning (bool): Whether finetuning is enabled. dreambooth (bool): Whether dreambooth is enabled. Returns: Tuple[gr.Number, gr.Number, gr.Number]: A tuple containing the updated visibility for learning rate TE, TE1, and TE2. """ # Determine the visibility condition based on finetuning and dreambooth flags visibility_condition = finetuning or dreambooth # Return a tuple of gr.Number instances with updated visibility return ( gr.Number(visible=(not sdxl_checkbox and visibility_condition)), gr.Number(visible=(sdxl_checkbox and visibility_condition)), gr.Number(visible=(sdxl_checkbox and visibility_condition)), )