Spaces:
Running
on
A10G
Running
on
A10G
File size: 11,518 Bytes
0a3525d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
from dataclasses import dataclass
from typing import Any, Optional
import lightning as L
import loralib as lora
import torch
import torch.nn.functional as F
from lightning.pytorch.utilities.types import OptimizerLRScheduler
import fish_speech.utils as utils
from fish_speech.models.text2semantic.llama import NaiveTransformer
# Module-level logger (used by configure_optimizers). rank_zero_only=True presumably
# restricts output to global rank 0 in distributed runs — confirm in fish_speech.utils.
log = utils.RankedLogger(__name__, rank_zero_only=True)
@dataclass
class LoraConfig:
    """Hyper-parameters used when swapping model layers for LoRA layers.

    Consumed by ``TextToSemantic.setup_lora``: ``r`` and ``lora_alpha`` are
    passed to both the LoRA embedding and the LoRA linear layers, while
    ``lora_dropout`` is only passed to the LoRA linear layers.
    """

    # Rank of the low-rank update matrices.
    r: int
    # LoRA scaling numerator (loralib scales the update by lora_alpha / r).
    lora_alpha: float
    # Dropout probability applied in the LoRA linear branches.
    lora_dropout: float = 0.0
class TextToSemantic(L.LightningModule):
    """Lightning wrapper that trains a ``NaiveTransformer`` text-to-semantic model.

    Adds optional LoRA fine-tuning (through ``loralib``) and an optional
    reference-free DPO loss computed on paired positive/negative samples.
    """

    def __init__(
        self,
        model: NaiveTransformer,
        optimizer: Any,
        lr_scheduler: Any,
        lora_config: Optional[LoraConfig] = None,
        save_lora_only: bool = False,
        use_dpo: bool = False,
        dpo_beta: float = 0.2,
    ):
        """
        Args:
            model: The transformer being trained.
            optimizer: Callable building an optimizer from a list of
                parameter-group dicts (invoked in ``configure_optimizers``).
            lr_scheduler: Callable building an LR scheduler from the optimizer.
            lora_config: When given, ``setup_lora`` is run immediately to swap
                the model's embedding/linear layers for LoRA layers.
            save_lora_only: With LoRA enabled, drop every non-LoRA entry from
                saved checkpoints (see ``on_save_checkpoint``).
            use_dpo: Treat the first half of each batch as positive and the
                second half as negative samples and add a DPO loss.
            dpo_beta: Scale applied to the log-prob difference in the DPO loss.
        """
        super().__init__()

        self.model = model
        self.optimizer_builder = optimizer
        self.lr_scheduler_builder = lr_scheduler
        self.lora_config = lora_config
        self.save_lora_only = save_lora_only
        self.use_dpo = use_dpo  # We don't support reference model yet
        self.dpo_beta = dpo_beta

        if self.lora_config is not None:
            self.setup_lora()

    def setup_lora(self):
        """Replace the model's embedding and linear layers with LoRA layers.

        Covers the token embedding, the ``output`` head, and the attention
        (``wqkv``, ``wo``) and feed-forward (``w1``, ``w2``, ``w3``) projections
        of every layer in ``model.layers``. If the model has ``fast_layers``
        (Dual-AR model), ``fast_output`` and the fast layers are covered too.
        The existing weights (and biases) are copied into each replacement, and
        finally ``lora.mark_only_lora_as_trainable(..., bias="lora_only")`` is
        applied to the model.
        """
        old_embeddings = self.model.embeddings
        self.model.embeddings = lora.Embedding(
            num_embeddings=old_embeddings.num_embeddings,
            embedding_dim=old_embeddings.embedding_dim,
            padding_idx=old_embeddings.padding_idx,
            r=self.lora_config.r,
            lora_alpha=self.lora_config.lora_alpha,
        )
        self._copy_weights(old_embeddings, self.model.embeddings)

        # (module, attribute name) pairs of every linear layer to replace.
        linears = [(self.model, "output")]
        linears.extend(self._block_linears(self.model.layers))

        if hasattr(self.model, "fast_layers"):
            # Dual-AR model
            linears.append((self.model, "fast_output"))
            linears.extend(self._block_linears(self.model.fast_layers))

        for module, attr_name in linears:
            old_linear = getattr(module, attr_name)
            new_linear = lora.Linear(
                in_features=old_linear.in_features,
                out_features=old_linear.out_features,
                # nn.Linear takes a *bool* here. Passing the bias tensor itself
                # only worked by accident when it was None, and raises for a
                # real multi-element bias.
                bias=old_linear.bias is not None,
                r=self.lora_config.r,
                lora_alpha=self.lora_config.lora_alpha,
                lora_dropout=self.lora_config.lora_dropout,
            )
            # The new layer is freshly initialised; keep the weights the model
            # already holds instead of silently discarding them.
            self._copy_weights(old_linear, new_linear)
            setattr(module, attr_name, new_linear)

        # Mark only the LoRA layers as trainable
        lora.mark_only_lora_as_trainable(self.model, bias="lora_only")

    @staticmethod
    def _block_linears(layers):
        """Return the (module, attribute) pairs of the LoRA-targeted linears in ``layers``."""
        pairs = []
        for layer in layers:
            pairs.extend(
                [
                    (layer.attention, "wqkv"),
                    (layer.attention, "wo"),
                    (layer.feed_forward, "w1"),
                    (layer.feed_forward, "w2"),
                    (layer.feed_forward, "w3"),
                ]
            )
        return pairs

    @staticmethod
    @torch.no_grad()
    def _copy_weights(source, target):
        """Copy ``weight`` (and ``bias`` when both layers have one) from ``source`` into ``target``."""
        target.weight.copy_(source.weight)
        source_bias = getattr(source, "bias", None)
        target_bias = getattr(target, "bias", None)
        if source_bias is not None and target_bias is not None:
            target_bias.copy_(source_bias)

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def on_save_checkpoint(self, checkpoint):
        """Lightning hook: optionally strip non-LoRA weights from the checkpoint.

        Only active when LoRA is configured and ``save_lora_only`` is set. In
        that case every key of ``checkpoint["state_dict"]`` that does not
        contain ``"lora"`` is removed in place.
        """
        if self.lora_config is None or not self.save_lora_only:
            return

        # Save only LoRA parameters
        state_dict = checkpoint["state_dict"]
        for name in [key for key in state_dict if "lora" not in key]:
            del state_dict[name]

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Build the optimizer and a scheduler stepped every optimizer step.

        Parameters whose name contains ``.bias``, ``norm.weight`` or
        ``.embeddings.`` go into a group with ``weight_decay=0.0``. All other
        parameters use the optimizer builder's default weight decay.
        Parameters with ``requires_grad=False`` (e.g. frozen by LoRA) are still
        passed to the optimizer.
        """
        no_decay_markers = (".bias", "norm.weight", ".embeddings.")
        weight_decay_parameters, other_parameters = [], []
        for name, param in self.named_parameters():
            if any(marker in name for marker in no_decay_markers):
                other_parameters.append(param)
            else:
                weight_decay_parameters.append(param)

        optimizer = self.optimizer_builder(
            [
                {"params": weight_decay_parameters},
                {"params": other_parameters, "weight_decay": 0.0},
            ]
        )

        # Print the parameters and their weight decay
        for group in optimizer.param_groups:
            log.info(
                f"Set weight decay: {group['weight_decay']} for {len(group['params'])} parameters"
            )

        lr_scheduler = self.lr_scheduler_builder(optimizer)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
            },
        }

    # Adapted from https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
    def get_batch_logps(
        self,
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = False,
    ) -> torch.FloatTensor:
        """Compute the log probabilities of ``labels`` under ``logits``.

        Args:
            logits: Unnormalized logits of shape ``labels.shape + (vocab_size,)``.
                In ``_step`` these are the codebook logits, presumably
                (batch_size, sequence_length, num_codebooks, vocab_size).
            labels: Target ids. Positions equal to -100 are ignored.
            average_log_prob: If True, average the log probabilities over the
                non-masked positions of the last label dimension. Otherwise sum
                them.

        Returns:
            Log probabilities reduced over the LAST label dimension only, i.e.
            a tensor of shape ``labels.shape[:-1]``. This is not one value per
            sequence unless ``labels`` is 2-D. A row whose labels are all -100
            yields 0 (no NaN from dividing by zero).

        Raises:
            ValueError: If ``logits.shape[:-1]`` does not equal ``labels.shape``.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError(
                f"logits.shape[:-1] {tuple(logits.shape[:-1])} must match "
                f"labels.shape {tuple(labels.shape)}"
            )

        loss_mask = labels != -100
        # Ignored positions get a dummy index 0; their log-probs are masked out below.
        safe_labels = labels.masked_fill(~loss_mask, 0)

        per_token_logps = torch.gather(
            logits.log_softmax(-1), dim=-1, index=safe_labels.unsqueeze(-1)
        ).squeeze(-1)
        summed = (per_token_logps * loss_mask).sum(-1)

        if average_log_prob:
            return summed / loss_mask.sum(-1).clamp(min=1)
        return summed

    def _step(self, batch, batch_idx, stage: str):
        """Compute, log and return the loss for one batch.

        ``batch["labels"][:, 0]`` are the token labels and
        ``batch["labels"][:, 1 : 1 + num_codebooks]`` the codebook labels (their
        last two dims are transposed with ``.mT`` before use). All labels use -100
        as the ignore index. The loss is token cross-entropy plus codebook
        cross-entropy, plus a reference-free DPO loss when ``use_dpo`` is set.

        Raises:
            ValueError: If ``use_dpo`` is set and the batch size is odd.
        """
        num_codebooks = self.model.config.num_codebooks

        # Do positive and negative samples in the same batch to speed up training
        labels = batch["labels"]
        outputs = self.model(
            inp=batch["inputs"],
            key_padding_mask=batch["attention_masks"],
        )
        token_logits = outputs.token_logits
        codebook_logits = outputs.codebook_logits

        if self.use_dpo:
            # chunk(2) on an odd batch gives unequal halves, which would make
            # positive and negative log-probs silently broadcast against each other.
            if labels.shape[0] % 2 != 0:
                raise ValueError(
                    "DPO requires an even batch size (positives followed by "
                    f"negatives), got {labels.shape[0]}"
                )
            # First half is positive, second half is negative
            token_logits, _ = token_logits.chunk(2)
            codebook_logits, negative_codebook_logits = codebook_logits.chunk(2)
            labels, negative_labels = labels.chunk(2)

        base_loss = F.cross_entropy(
            token_logits.reshape(-1, token_logits.size(-1)),
            labels[:, 0].reshape(-1),
            ignore_index=-100,
        )

        codebook_labels = labels[:, 1 : 1 + num_codebooks].mT
        semantic_loss = F.cross_entropy(
            codebook_logits.reshape(-1, codebook_logits.size(-1)),
            codebook_labels.reshape(-1),
            ignore_index=-100,
        )

        loss = base_loss + semantic_loss

        if self.use_dpo:
            negative_codebook_labels = negative_labels[:, 1 : 1 + num_codebooks].mT

            positive_codebook_logps = self.get_batch_logps(
                codebook_logits, codebook_labels
            )
            negative_codebook_logps = self.get_batch_logps(
                negative_codebook_logits, negative_codebook_labels
            )

            # TODO: implement the reference model, avoid screwing up the gradients
            dpo_loss = -F.logsigmoid(
                (positive_codebook_logps - negative_codebook_logps) * self.dpo_beta
            ).mean()

            chosen_rewards = self.dpo_beta * positive_codebook_logps.detach()
            rejected_rewards = self.dpo_beta * negative_codebook_logps.detach()
            reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()

            loss = loss + dpo_loss

            self._log_metric(stage, "dpo_loss", dpo_loss)
            self._log_metric(stage, "chosen_rewards", chosen_rewards.mean())
            self._log_metric(stage, "rejected_rewards", rejected_rewards.mean())
            self._log_metric(stage, "reward_accuracy", reward_accuracy)

        self._log_metric(stage, "loss", loss, prog_bar=True)
        self._log_metric(stage, "base_loss", base_loss)
        self._log_metric(stage, "semantic_loss", semantic_loss)

        # Top-5 accuracy
        accuracy = self.get_accuracy(codebook_logits, codebook_labels)
        self._log_metric(stage, "top_5_accuracy", accuracy, prog_bar=True)

        num_in_codebooks = self.model.config.num_in_codebooks
        if num_codebooks != num_in_codebooks:
            # Accuracy restricted to the first num_in_codebooks codebooks.
            accuracy = self.get_accuracy(
                codebook_logits[:, :, :num_in_codebooks],
                codebook_labels[:, :, :num_in_codebooks],
            )
            self._log_metric(stage, "top_5_accuracy_in", accuracy, prog_bar=True)

        return loss

    def _log_metric(self, stage: str, name: str, value, prog_bar: bool = False):
        """Log ``value`` under ``{stage}/{name}``: per step when training, per epoch otherwise."""
        is_train = stage == "train"
        self.log(
            f"{stage}/{name}",
            value,
            on_step=is_train,
            on_epoch=not is_train,
            prog_bar=prog_bar,
            logger=True,
        )

    def get_accuracy(self, logits, labels, k: int = 5):
        """Return the top-``k`` accuracy of ``logits`` against ``labels``.

        A position is correct when its label is among the ``k`` highest
        entries of the last ``logits`` dimension. Labels equal to -100 are
        excluded from both numerator and denominator. If every label is -100
        the result is 0 instead of NaN.
        """
        _, indices = logits.topk(k, dim=-1)
        valid = labels != -100
        correct = indices.eq(labels.unsqueeze(-1))
        correct[~valid] = False
        return correct.sum() / valid.sum().clamp(min=1)

    def training_step(self, batch, batch_idx):
        """Lightning training step; delegates to ``_step`` with stage ``"train"``."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning validation step; delegates to ``_step`` with stage ``"val"``."""
        return self._step(batch, batch_idx, "val")
|