# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn


class ValleAR(nn.Module):
    """Autoregressive (first-stage) VALL-E model: a LLaMA decoder over a joint
    phone + acoustic-token vocabulary, trained with next-token prediction on the
    acoustic tokens only."""

    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1281,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        use_input_embeds=False,
        emb_dim=256,
        **kwargs,
    ):
        super(ValleAR, self).__init__()
        self.config = LlamaConfig(
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.model = LlamaForCausalLM(self.config)

        self.use_input_embeds = use_input_embeds

        # optional input embedding (disabled by default), e.g. to provide speaker information
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()

    def forward(
        self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
    ):
        if input_embeds is not None:
            input_embeds = self.emb_linear(input_embeds)
        phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        target_ids, target_mask, target_label = self.add_target_eos_bos_label(
            target_ids,
            target_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        # concatenate the phone prompt and the acoustic tokens into one AR sequence
        input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
        attention_mask = torch.cat([phone_mask, target_mask], dim=-1)
        # breakpoint()
        if input_embeds is not None:
            raise NotImplementedError
            attention_mask = torch.cat(
                [
                    torch.ones(
                        (input_embeds.shape[0], input_embeds.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    ),
                    attention_mask,
                ],
                dim=-1,
            )
        labels = torch.cat([phone_label, target_label], dim=-1)
        if input_embeds is not None:
            raise NotImplementedError
            labels = torch.cat(
                [
                    -100
                    * torch.ones(
                        (input_embeds.shape[0], input_embeds.shape[1]),
                        dtype=labels.dtype,
                        device=labels.device,
                    ),
                    labels,
                ],
                dim=-1,
            )

        if input_embeds is not None:
            raise NotImplementedError
            inputs_embeds = torch.cat(
                [input_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
            )
            out = self.model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels,
                return_dict=True,
            )
            return out

        out = self.model(
            input_token_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

        # calculate top-1, top-5, top-10 next-token accuracy on the target segment
        logits = out.logits
        logits = logits[:, -target_ids.shape[1] :]
        top1_acc = logits.argmax(-1)[..., :-1] == target_ids[:, 1:]
        top1_acc = (top1_acc * target_mask[..., :-1]).sum() / target_mask.sum()

        top5_acc = torch.topk(logits[..., :-1, :], 5, dim=-1)[1]
        top5_acc = top5_acc == target_ids[:, 1:].unsqueeze(-1)
        top5_acc = (
            top5_acc * target_mask[..., :-1].unsqueeze(-1)
        ).sum() / target_mask.sum()

        top10_acc = torch.topk(logits[..., :-1, :], 10, dim=-1)[1]
        top10_acc = top10_acc == target_ids[:, 1:].unsqueeze(-1)
        top10_acc = (
            top10_acc * target_mask[..., :-1].unsqueeze(-1)
        ).sum() / target_mask.sum()

        out.top1_acc = top1_acc
        out.top5_acc = top5_acc
        out.top10_acc = top10_acc

        return out

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        # phone_ids: [B, T]
        # phone_mask: [B, T]

        # shift phone ids by target_vocab_size so phone and acoustic tokens do not overlap
        phone_ids = phone_ids + self.target_vocab_size * phone_mask

        phone_ids = phone_ids * phone_mask
        phone_ids = F.pad(phone_ids, (0, 1), value=0) + phone_eos_id * F.pad(
            1 - phone_mask, (0, 1), value=1
        )  # make pad token eos token, add eos token at the end
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add eos mask
        phone_ids = phone_ids * phone_mask + pad_token_id * (
            1 - phone_mask
        )  # restore pad token ids
        phone_ids = F.pad(phone_ids, (1, 0), value=phone_bos_id)  # add bos token
        phone_mask = F.pad(phone_mask, (1, 0), value=1)  # add bos mask
        phone_label = -100 * torch.ones_like(
            phone_ids
        )  # loss for the entire phone prefix is not computed (ignored by llama)
        return phone_ids, phone_mask, phone_label

    def add_target_eos_bos_label(
        self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
    ):
        # target_ids: [B, T]
        # target_mask: [B, T]
        target_ids = target_ids * target_mask
        target_ids = F.pad(target_ids, (0, 1), value=0) + target_eos_id * F.pad(
            1 - target_mask, (0, 1), value=1
        )
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_ids = target_ids * target_mask + pad_token_id * (1 - target_mask)
        target_ids = F.pad(target_ids, (1, 0), value=target_bos_id)
        target_mask = F.pad(target_mask, (1, 0), value=1)
        target_label = target_ids * target_mask + (-100) * (
            1 - target_mask
        )  # loss for the target is computed only on unmasked tokens
        return target_ids, target_mask, target_label

    def sample_hf(
        self,
        phone_ids,  # the phones of prompt and target should be concatenated together
        prompt_ids,
        inputs_embeds=None,
        max_length=2000,
        temperature=1.0,
        top_k=100,
        top_p=0.9,
        repeat_penalty=1.0,
        num_beams=1,
    ):
        if inputs_embeds is not None:
            inputs_embeds = self.emb_linear(inputs_embeds)
        phone_mask = torch.ones_like(phone_ids)
        prompt_mask = torch.ones_like(prompt_ids)
        phone_ids, _, _ = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        prompt_ids, _, _ = self.add_target_eos_bos_label(
            prompt_ids,
            prompt_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        # drop the eos token so generation continues from the acoustic prompt
        prompt_ids = prompt_ids[:, :-1]

        input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)

        if inputs_embeds is not None:
            raise NotImplementedError
            inputs_embeds = torch.cat(
                [inputs_embeds, self.model.model.embed_tokens(input_token_ids)], dim=1
            )
            generated_ids = self.model.generate(
                inputs_embeds=inputs_embeds,
                do_sample=True,
                max_length=max_length,
                pad_token_id=self.pad_token_id,
                eos_token_id=self.eos_target_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repeat_penalty,
            )
            gen_tokens = generated_ids[:, :-1]
            return gen_tokens

        input_length = input_token_ids.shape[1]
        generated_ids = self.model.generate(
            input_token_ids,
            do_sample=True,
            max_length=max_length,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_target_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            num_beams=num_beams,
        )

        # strip the prompt and the trailing eos token
        gen_tokens = generated_ids[:, input_length:-1]

        return gen_tokens
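
# The sketch below is illustrative and not part of the original module: it exercises
# only the bos/eos label helpers to show the token layout they produce, assuming the
# default special-token ids (pad=1281, bos_target=1282, eos_target=1283,
# bos_phone=1284, eos_phone=1285). The function name `demo_token_layout` and the tiny
# transformer config are hypothetical; the helpers depend only on target_vocab_size
# and the special ids, not on the model size.
def demo_token_layout():
    model = ValleAR(
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=1,
        num_attention_heads=2,
    )

    phone_ids = torch.LongTensor([[1, 2, 3, 0]])
    phone_mask = torch.LongTensor([[1, 1, 1, 0]])
    phone_ids, phone_mask, phone_label = model.add_phone_eos_bos_label(
        phone_ids,
        phone_mask,
        model.eos_phone_id,
        model.bos_phone_id,
        model.pad_token_id,
    )
    # Phones are shifted by target_vocab_size (1024) so phone and acoustic tokens
    # never collide. Expected ids: [[1284, 1025, 1026, 1027, 1285, 1281]], i.e.
    # <bos_phone>, shifted phones, <eos_phone>, pad; labels are all -100, so no
    # loss is computed on the phone prefix.
    print(phone_ids, phone_label)

    target_ids = torch.LongTensor([[765, 234, 0]])
    target_mask = torch.LongTensor([[1, 1, 0]])
    target_ids, target_mask, target_label = model.add_target_eos_bos_label(
        target_ids,
        target_mask,
        model.eos_target_id,
        model.bos_target_id,
        model.pad_token_id,
    )
    # Expected ids: [[1282, 765, 234, 1283, 1281]]; labels equal the ids except for
    # -100 at the padded position, so the loss covers the real acoustic tokens and eos.
    print(target_ids, target_label)
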
def test():
    model = ValleAR()

    phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]])
    phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
    target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1)
    target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

    for i in range(15):
        optimizer.zero_grad()
        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss

        loss.backward()

        optimizer.step()

        print(f"iter={i}, {loss}.")

    phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1)
    target_ids = torch.LongTensor([765, 234]).reshape(1, -1)
    sampled = model.sample_hf(phone_ids, target_ids)

    breakpoint()


if __name__ == "__main__":
    test()