from dataclasses import dataclass
from typing import Any, Optional

import lightning as L
import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


@dataclass
class LoraConfig:
    r: int
    lora_alpha: float
    lora_dropout: float = 0.0


class TextToSemantic(L.LightningModule):
    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
        save_lora_only: bool = False,
        use_dpo: bool = False,
        dpo_beta: float = 0.2,
    ):
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config
        self.save_lora_only = save_lora_only
        self.use_dpo = use_dpo  # We don't support reference model yet
        self.dpo_beta = dpo_beta

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        # Replace the embedding layer with a LoRA layer
        self.model.embeddings = lora.Embedding(
            num_embeddings=self.model.embeddings.num_embeddings,
            embedding_dim=self.model.embeddings.embedding_dim,
            padding_idx=self.model.embeddings.padding_idx,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.lora_alpha,
        )

        # Replace output layer with a LoRA layer
        linears = [(self.model, "output")]

        # Replace all linear layers with LoRA layers
        for layer in self.model.layers:
            linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
            linears.extend(
                [
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )

        if hasattr(self.model, "fast_layers"):
            # Dual-AR model
            linears.extend([(self.model, "fast_output")])

            for layer in self.model.fast_layers:
                linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
                linears.extend(
                    [
                        (layer.feed_forward, "w1"),
                        (layer.feed_forward, "w2"),
                        (layer.feed_forward, "w3"),
                    ]
                )

        for module, layer in linears:
            updated_linear = lora.Linear(
                in_features=getattr(module, layer).in_features,
                out_features=getattr(module, layer).out_features,
                bias=getattr(module, layer).bias,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            setattr(module, layer, updated_linear)

        # Mark only the LoRA layers as trainable
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        if self.lora_config is None or self.save_lora_only is False:
            return

        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        for name in list(state_dict.keys()):
            if "lora" not in name:
                state_dict.pop(name)
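
    # Parameters are split into two optimizer groups: weight matrices receive
    # weight decay, while biases, norm weights, and embeddings are placed in a
    # zero-decay group.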
    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Get weight decay parameters
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for i in optimizer.param_groups:
            log.info(
                f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of the given labels under the given logits.

        Args:
            logits: Logits of the model (unnormalized).
                Shape: (batch_size, sequence_length, codebook_size, vocab_size)
            labels: Labels for which to compute the log probabilities. Label tokens
                with a value of -100 are ignored.
                Shape: (batch_size, sequence_length, codebook_size)
            average_log_prob: If True, return the average log probability per
                (non-masked) token. Otherwise, return the sum of the log
                probabilities of the (non-masked) tokens.

        Returns:
            A tensor of shape (batch_size,) containing the average/sum log
            probabilities of the given labels under the given logits.
        """
        assert logits.shape[:-1] == labels.shape

        labels = labels.clone()
        loss_mask = labels != -100

        # Dummy token; we'll ignore the losses on these tokens later
        labels[labels == -100] = 0

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
        ).squeeze(-1)

        if average_log_prob:
            return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
        else:
            return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
        is_train = stage == "train"

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        if self.use_dpo:
            # First half is positive, second half is negative
            token_logits, negative_token_logits = token_logits.chunk(2)
            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
            labels, negative_labels = labels.chunk(2)

        # Generate labels
        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        # If we use DPO
        if self.use_dpo:
            negative_codebook_labels = negative_labels[
                :, 1 : 1 + self.model.config.num_codebooks
            ].mT

            positive_codebook_logps = self.get_batch_logps(
                codebook_logits, codebook_labels
            )
            negative_codebook_logps = self.get_batch_logps(
                negative_codebook_logits, negative_codebook_labels
            )

            # TODO: implement the reference model, avoid screwing up the gradients
            dpo_loss = -F.logsigmoid(
                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
            ).mean()

            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
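            # Reduce per-sample rewards to scalar means for logging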
            chosen_rewards, rejected_rewards = (
                chosen_rewards.mean(),
                rejected_rewards.mean(),
            )

            loss = loss + dpo_loss

            self.log(
                f"{stage}/dpo_loss",
                dpo_loss,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/chosen_rewards",
                chosen_rewards,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/rejected_rewards",
                rejected_rewards,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

            self.log(
                f"{stage}/reward_accuracy",
                reward_accuracy,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=False,
                logger=True,
            )

        self.log(
            f"{stage}/loss",
            loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
        )

        self.log(
            f"{stage}/base_loss",
            base_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
        )

        self.log(
            f"{stage}/semantic_loss",
            semantic_loss,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=False,
            logger=True,
        )

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self.log(
            f"{stage}/top_5_accuracy",
            accuracy,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=True,
            logger=True,
        )

        if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
            accuracy = self.get_accuracy(
                codebook_logits[:, :, : self.model.config.num_in_codebooks],
                codebook_labels[:, :, : self.model.config.num_in_codebooks],
            )
            self.log(
                f"{stage}/top_5_accuracy_in",
                accuracy,
                on_step=is_train,
                on_epoch=not is_train,
                prog_bar=True,
                logger=True,
            )

        return loss

    def get_accuracy(self, logits, labels):
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[labels == -100] = 0
        correct = correct.sum()
        accuracy = correct / (labels != -100).sum()

        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
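

# A minimal construction sketch (illustrative only, not part of the training
# entrypoint): `optimizer` and `lr_scheduler` are builder callables, so partials
# work naturally. The hyperparameters below are assumptions; in practice these
# objects are normally injected from the Hydra config.
#
#   from functools import partial
#
#   module = TextToSemantic(
#       model=NaiveTransformer(config),  # `config` assumed to be a valid model config
#       optimizer=partial(torch.optim.AdamW, lr=3e-4, weight_decay=1e-2),
#       lr_scheduler=partial(
#           torch.optim.lr_scheduler.CosineAnnealingLR, T_max=100_000
#       ),
#       lora_config=LoraConfig(r=8, lora_alpha=16),
#       save_lora_only=True,
#   )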