from dataclasses import dataclass
from typing import Any, Optional

import lightning as L
import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler

import fish_speech.utils as utils
from fish_speech.models.text2semantic.llama import NaiveTransformer

log = utils.RankedLogger(__name__, rank_zero_only=True)


@dataclass
class LoraConfig:
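    """LoRA hyper-parameters (see Hu et al., 2021): `r` is the rank of the
    low-rank update, `lora_alpha` scales it (loralib applies lora_alpha / r
    as the effective scale), and `lora_dropout` is dropout on the LoRA path."""
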
r: int
lora_alpha: float
lora_dropout: float = 0.0


class TextToSemantic(L.LightningModule):
def __init__(
self,
model: NaiveTransformer,
optimizer: Any,
lr_scheduler: Any,
lora_config: Optional[LoraConfig] = None,
save_lora_only: bool = False,
use_dpo: bool = False,
dpo_beta: float = 0.2,
):
super().__init__()
self.model = model
self.optimizer_builder = optimizer
self.lr_scheduler_builder = lr_scheduler
self.lora_config = lora_config
self.save_lora_only = save_lora_only
self.use_dpo = use_dpo # We don't support reference model yet
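        # dpo_beta scales the chosen-vs-rejected log-prob margin inside the DPO sigmoid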
self.dpo_beta = dpo_beta
if self.lora_config is not None:
self.setup_lora()

    def setup_lora(self):
# Replace the embedding layer with a LoRA layer
self.model.embeddings = lora.Embedding(
num_embeddings=self.model.embeddings.num_embeddings,
embedding_dim=self.model.embeddings.embedding_dim,
padding_idx=self.model.embeddings.padding_idx,
r=self.lora_config.r,
lora_alpha=self.lora_config.lora_alpha,
)
# Replace output layer with a LoRA layer
linears = [(self.model, "output")]
# Replace all linear layers with LoRA layers
for layer in self.model.layers:
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
linears.extend(
[
(layer.feed_forward, "w1"),
(layer.feed_forward, "w2"),
(layer.feed_forward, "w3"),
]
)
if hasattr(self.model, "fast_layers"):
# Dual-AR model
linears.extend([(self.model, "fast_output")])
for layer in self.model.fast_layers:
linears.extend([(layer.attention, "wqkv"), (layer.attention, "wo")])
linears.extend(
[
(layer.feed_forward, "w1"),
(layer.feed_forward, "w2"),
(layer.feed_forward, "w3"),
]
)
        for module, layer in linears:
            # lora.Linear (like nn.Linear) expects a boolean bias flag,
            # not the bias tensor itself
            updated_linear = lora.Linear(
                in_features=getattr(module, layer).in_features,
                out_features=getattr(module, layer).out_features,
                bias=getattr(module, layer).bias is not None,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            setattr(module, layer, updated_linear)
# Mark only the LoRA layers as trainable
lora.mark_only_lora_as_trainable(self.model, bias="lora_only")
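        # Note: lora.Embedding / lora.Linear come freshly initialized, so the
        # pretrained base weights must be loaded *after* this setup runs
        # (e.g. when Lightning restores the checkpoint state dict).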

    def forward(self, x):
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        if self.lora_config is None or not self.save_lora_only:
            return
# Save only LoRA parameters
state_dict = checkpoint["state_dict"]
for name in list(state_dict.keys()):
if "lora" not in name:
state_dict.pop(name)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        # Split parameters: biases, norm weights, and embeddings are
        # conventionally excluded from weight decay
weight_decay_parameters, other_parameters = [], []
for name, param in self.named_parameters():
if ".bias" in name or "norm.weight" in name or ".embeddings." in name:
other_parameters.append(param)
else:
weight_decay_parameters.append(param)
optimizer = self.optimizer_builder(
[
{"params": weight_decay_parameters},
{"params": other_parameters, "weight_decay": 0.0},
]
)
        # Log each parameter group's size and its weight decay
for i in optimizer.param_groups:
log.info(
f"Set weight decay: {i['weight_decay']} for {len(i['params'])} parameters"
)
lr_scheduler = self.lr_scheduler_builder(optimizer)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
},
}

    # Copied from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
def get_batch_logps(
self,
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits.
Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, codebook_size, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length, codebook_size)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
"""
assert logits.shape[:-1] == labels.shape
labels = labels.clone()
loss_mask = labels != -100
# dummy token; we'll ignore the losses on these tokens later
labels[labels == -100] = 0
per_token_logps = torch.gather(
logits.log_softmax(-1), dim=-1, index=labels.unsqueeze(-1)
).squeeze(-1)
if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)

    def _step(self, batch, batch_idx, stage: str):
is_train = stage == "train"
# Do positive and negative samples in the same batch to speed up training
labels = batch["labels"]
outputs = self.model(
inp=batch["inputs"],
key_padding_mask=batch["attention_masks"],
)
token_logits = outputs.token_logits
codebook_logits = outputs.codebook_logits
if self.use_dpo:
            # First half is positive, second half is negative
token_logits, negative_token_logits = token_logits.chunk(2)
codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
labels, negative_labels = labels.chunk(2)
        # Token-level loss for the base transformer (first row of the labels)
base_loss = F.cross_entropy(
token_logits.reshape(-1, token_logits.size(-1)),
labels[:, 0].reshape(-1),
ignore_index=-100,
)
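        # Label rows 1..num_codebooks hold the codebook targets; .mT swaps the
        # last two dims to (batch, seq_len, num_codebooks) so they line up with
        # codebook_logits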
codebook_labels = labels[:, 1 : 1 + self.model.config.num_codebooks].mT
semantic_loss = F.cross_entropy(
codebook_logits.reshape(-1, codebook_logits.size(-1)),
codebook_labels.reshape(-1),
ignore_index=-100,
)
loss = base_loss + semantic_loss
        # If DPO is enabled, add a preference loss on the codebook logits
if self.use_dpo:
negative_codebook_labels = negative_labels[
:, 1 : 1 + self.model.config.num_codebooks
].mT
positive_codebook_logps = self.get_batch_logps(
codebook_logits, codebook_labels
)
negative_codebook_logps = self.get_batch_logps(
negative_codebook_logits, negative_codebook_labels
)
# TODO: implement the reference model, avoid screwing up the gradients
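            # Without a reference model the reference log-ratio is implicitly
            # zero, so this is -log(sigmoid(beta * (logp_chosen - logp_rejected)))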
dpo_loss = -F.logsigmoid(
(positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
).mean()
chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()
chosen_rewards, rejected_rewards = (
chosen_rewards.mean(),
rejected_rewards.mean(),
)
loss = loss + dpo_loss
self.log(
f"{stage}/dpo_loss",
dpo_loss,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
self.log(
f"{stage}/chosen_rewards",
chosen_rewards,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
self.log(
f"{stage}/rejected_rewards",
rejected_rewards,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
self.log(
f"{stage}/reward_accuracy",
reward_accuracy,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
self.log(
f"{stage}/loss",
loss,
on_step=is_train,
on_epoch=not is_train,
prog_bar=True,
logger=True,
)
self.log(
f"{stage}/base_loss",
base_loss,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
self.log(
f"{stage}/semantic_loss",
semantic_loss,
on_step=is_train,
on_epoch=not is_train,
prog_bar=False,
logger=True,
)
# Top-5 accuracy
accuracy = self.get_accuracy(codebook_logits, codebook_labels)
self.log(
f"{stage}/top_5_accuracy",
accuracy,
on_step=is_train,
on_epoch=not is_train,
prog_bar=True,
logger=True,
)
if self.model.config.num_codebooks != self.model.config.num_in_codebooks:
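            # Dual-AR: the model predicts more codebooks than it consumes as
            # input; also report accuracy restricted to the input-codebook slice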
accuracy = self.get_accuracy(
codebook_logits[:, :, : self.model.config.num_in_codebooks],
codebook_labels[:, :, : self.model.config.num_in_codebooks],
)
self.log(
f"{stage}/top_5_accuracy_in",
accuracy,
on_step=is_train,
on_epoch=not is_train,
prog_bar=True,
logger=True,
)
return loss

    def get_accuracy(self, logits, labels):
        # Top-5 accuracy over positions whose label is not the ignore index (-100)
        _, indices = logits.topk(5, dim=-1)
        correct = indices.eq(labels.unsqueeze(-1))
        correct[labels == -100] = 0
        correct = correct.sum()
        accuracy = correct / (labels != -100).sum()
        return accuracy

    def training_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")
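
# A minimal usage sketch (hypothetical wiring; the actual project builds all of
# this from Hydra configs, and the NaiveTransformer construction is assumed):
#
#     import functools
#     import torch
#
#     module = TextToSemantic(
#         model=NaiveTransformer(config),  # `config` assumed from fish_speech
#         optimizer=functools.partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.01),
#         lr_scheduler=functools.partial(
#             torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda step: 1.0
#         ),
#         lora_config=LoraConfig(r=8, lora_alpha=16.0),
#         save_lora_only=True,
#     )
#     trainer = L.Trainer(max_steps=10_000)
#     trainer.fit(module, datamodule)  # `datamodule`: an assumed LightningDataModule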