File size: 5,248 Bytes
1c6b840 c4ec89c 1c6b840 053eb5b 1c6b840 c4ec89c 1c6b840 053eb5b 1c6b840 053eb5b c4ec89c 2c3d5a3 3f65edd 053eb5b 2c3d5a3 c4ec89c 1c6b840 8e999ad 1c6b840 8e999ad c4ec89c 1c6b840 8e999ad 1c6b840 8e999ad 053eb5b 1c6b840 8e999ad 1c6b840 8e999ad 1c6b840 187d4d3 1c6b840 053eb5b c4ec89c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
from torch import nn
import torchaudio
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer, HubertModel, AutoProcessor, AutoConfig, AutoModel
from .config import SpeechLLMModelConfig
from peft import LoraConfig, get_peft_model
class HubertXCNNEnoder(nn.Module):
    """Audio backbone followed by a 1-D convolutional adapter.

    The backbone is built with ``AutoModel.from_config`` from the
    configuration of ``encoder_name``; the adapter is a stack of three
    ``Conv1d`` layers, each preceded by a ``ReLU``, that maps
    ``audio_enc_dim`` channels to ``llm_dim`` channels.
    """

    def __init__(self, audio_enc_dim, llm_dim, encoder_name):
        """Build the backbone and the convolutional adapter.

        Args:
            audio_enc_dim: Channel count the first convolution expects,
                i.e. the feature width produced by the backbone.
            llm_dim: Channel count produced by the last convolution.
            encoder_name: Model identifier whose configuration is used to
                instantiate the backbone.
        """
        super().__init__()
        self.encoder = AutoModel.from_config(AutoConfig.from_pretrained(encoder_name))

        half_dim = llm_dim // 2
        # Layer order is significant: it fixes the indices used as
        # state-dict keys of ``self.cnn`` (cnn.1, cnn.3, cnn.5).
        adapter_layers = [
            nn.ReLU(),
            nn.Conv1d(audio_enc_dim, half_dim, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            # The only strided layer: roughly halves the sequence length.
            nn.Conv1d(half_dim, llm_dim, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(llm_dim, llm_dim, kernel_size=3, stride=1, padding=0),
        ]
        self.cnn = nn.Sequential(*adapter_layers)

    def forward(self, x):
        """Encode ``x`` with the backbone, then run the adapter.

        Args:
            x: Input passed unchanged to the backbone.

        Returns:
            The adapter output with dims 1 and 2 swapped back, so the
            channel axis is last again.
        """
        hidden = self.encoder(x).last_hidden_state
        # Conv1d convolves over the last axis, so move the feature axis to
        # position 1 first (assumes hidden is (batch, time, features)).
        channels_first = hidden.transpose(1, 2)
        adapted = self.cnn(channels_first)
        return adapted.transpose(1, 2)

    def return_device(self):
        """Return the device of this module's first parameter."""
        first_param = next(self.parameters())
        return first_param.device
class SpeechLLMModel(PreTrainedModel):
    """Speech-conditioned causal language model.

    Audio is encoded by ``HubertXCNNEnoder`` into a sequence of vectors that
    is spliced between the embeddings of a textual prompt prefix and suffix;
    the resulting embedding sequence is fed to a causal LLM.
    """

    config_class = SpeechLLMModelConfig

    def __init__(self, config):
        """Instantiate processor, audio encoder, LLM and tokenizer.

        Args:
            config: A ``SpeechLLMModelConfig``; this constructor reads
                ``audio_processor_name``, ``audio_enc_dim``, ``llm_dim``,
                ``audio_encoder_name`` and ``llm_model_name`` from it.
        """
        super().__init__(config)
        self.audio_processor = AutoProcessor.from_pretrained(config.audio_processor_name)
        self.audio_encoder = HubertXCNNEnoder(config.audio_enc_dim, config.llm_dim, config.audio_encoder_name)

        # The LLM is created from its configuration only (``from_config``).
        llm_config = AutoConfig.from_pretrained(config.llm_model_name)
        self.llm_model = AutoModelForCausalLM.from_config(llm_config)
        self.llm_tokenizer = AutoTokenizer.from_pretrained(config.llm_model_name)
        # Reuse EOS as the padding token.
        self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token

        peft_config = LoraConfig(
            r=4,
            lora_alpha=8,
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'up_proj', 'down_proj', 'gate_proj'],
            lora_dropout=0.05,
            task_type="CAUSAL_LM",
        )
        # LoRA adapters are attached and then immediately folded back into
        # the base weights, leaving a plain (non-PEFT) model.
        # NOTE(review): presumably this mirrors how the saved checkpoint was
        # produced -- confirm before removing.
        self.llm_model = get_peft_model(self.llm_model, peft_config)
        self.llm_model = self.llm_model.merge_and_unload()

    def encode(self, speech, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids):
        """Build the LLM input embeddings, attention mask and labels.

        Args:
            speech: Audio input passed to the audio encoder; its first
                dimension is taken as the batch size.
            pre_tokenized_ids: Token ids of the prompt part before the speech.
            post_tokenized_ids: Token ids of the prompt part after the speech.
            output_tokenized_ids: Token ids of the target/output part.

        Returns:
            A tuple ``(combined_embeds, atts, label_ids)``:
            ``combined_embeds`` is the concatenation along dim 1 of
            pre-prompt, speech, post-prompt and output embeddings;
            ``atts`` is an all-ones ``long`` mask over its leading dims;
            ``label_ids`` is ``int64`` with ``-100`` on every pre-prompt,
            speech and post-prompt position, followed by
            ``output_tokenized_ids``.
        """
        batch_size = speech.shape[0]

        # NOTE(review): the extent of this no_grad scope was reconstructed
        # (the encoder and the embedding lookups run without gradients) --
        # confirm against the original file.
        with torch.no_grad():
            speech_embeds = self.audio_encoder(speech)
            embedder = self.llm_model.model.embed_tokens
            pre_prompt_embeds = embedder(pre_tokenized_ids)
            post_prompt_embeds = embedder(post_tokenized_ids)
            output_prompt_embeds = embedder(output_tokenized_ids)

        combined_embeds = torch.cat(
            [pre_prompt_embeds, speech_embeds, post_prompt_embeds, output_prompt_embeds],
            dim=1,
        )
        target_device = combined_embeds.device
        atts = torch.ones(combined_embeds.size()[:-1], dtype=torch.long, device=target_device)

        # Everything that is not part of the output is labelled -100.
        input_token_length = pre_tokenized_ids.shape[1] + speech_embeds.shape[1] + post_tokenized_ids.shape[1]
        ignored_prefix = torch.full(
            (batch_size, input_token_length), -100, dtype=torch.int64, device=target_device
        )
        label_ids = torch.cat(
            [ignored_prefix, output_tokenized_ids.to(device=target_device, dtype=torch.int64)],
            dim=1,
        )
        return combined_embeds, atts, label_ids

    def forward(self, audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids, attention_mask=None):
        """Run the LLM on the combined prompt/speech/output embeddings.

        Args:
            audio_tensor: Audio input forwarded to :meth:`encode`.
            pre_tokenized_ids: Token ids of the prompt part before the speech.
            post_tokenized_ids: Token ids of the prompt part after the speech.
            output_tokenized_ids: Token ids of the output part.
            attention_mask: Optional mask handed to the LLM unchanged.

        Returns:
            Whatever the wrapped LLM returns for ``inputs_embeds``.
        """
        # The mask and labels produced by encode() are intentionally not
        # forwarded here: only the caller-supplied attention_mask is used and
        # no labels are passed to the LLM.
        combined_embeds, _, _ = self.encode(audio_tensor, pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)
        outputs = self.llm_model(inputs_embeds=combined_embeds, attention_mask=attention_mask)
        return outputs

    def generate_meta(self, audio_path=None, audio_tensor=None, instruction="Give me the following information about the audio [Transcript]", max_new_tokens=2000):
        """Generate text for one audio clip and an instruction.

        Args:
            audio_path: Path of an audio file to load with ``torchaudio``;
                used only when ``audio_tensor`` is ``None``.
            audio_tensor: Already-loaded waveform. It is passed to the audio
                processor with a declared sampling rate of 16000.
            instruction: Instruction text inserted into the prompt.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded generation of the first batch element, with special
            tokens skipped.

        Raises:
            ValueError: If neither ``audio_path`` nor ``audio_tensor`` is
                given.
        """
        target_sample_rate = 16000
        device = self.audio_encoder.return_device()

        # NOTE(review): prompt text is reproduced verbatim; confirm whether
        # blank lines between the sections are expected by the checkpoint.
        pre_speech_prompt = f'''Instruction:
{instruction}
Input:
<speech>'''
        post_speech_prompt = f'''</speech>
Output:'''
        output_prompt = '\n<s>'

        def _tokenize(text):
            # No padding, truncation or special tokens: the three prompt
            # pieces are concatenated at the embedding level.
            encoded = self.llm_tokenizer(text, padding="do_not_pad", return_tensors='pt', truncation=False, add_special_tokens=False)
            return encoded["input_ids"].to(device)

        with torch.no_grad():
            # Identity checks: ``== None`` on a tensor goes through
            # Tensor.__eq__ instead of testing for absence.
            if audio_tensor is None:
                if audio_path is None:
                    raise ValueError("Either audio_path or audio_tensor must be provided.")
                audio_tensor, sr = torchaudio.load(audio_path)
                # The processor below is told the audio is 16 kHz, so make
                # that true for files stored at another rate.
                if sr != target_sample_rate:
                    audio_tensor = torchaudio.functional.resample(audio_tensor, sr, target_sample_rate)
            audio_tensor = self.audio_processor(audio_tensor.squeeze(), return_tensors="pt", sampling_rate=target_sample_rate).input_values

            pre_tokenized_ids = _tokenize(pre_speech_prompt)
            post_tokenized_ids = _tokenize(post_speech_prompt)
            output_tokenized_ids = _tokenize(output_prompt)

            combined_embeds, _, _ = self.encode(audio_tensor.to(device), pre_tokenized_ids, post_tokenized_ids, output_tokenized_ids)

            out = self.llm_model.generate(
                inputs_embeds=combined_embeds,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.llm_tokenizer.pad_token_id,
            ).cpu().tolist()[0]

            output_text = self.llm_tokenizer.decode(out, skip_special_tokens=True)
            return output_text