|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, WhisperForConditionalGeneration, PretrainedConfig, PreTrainedModel, BertConfig, AutoProcessor |
|
from transformers.models.bert.modeling_bert import BertEncoder |
|
from torch import nn |
|
import torch |
|
import os |
|
|
|
|
|
class Desta2Config(PretrainedConfig):
    """Configuration object for ``DestaModel``.

    Stores the identifiers of the two backbone checkpoints (a Llama model and
    a Whisper model) and the number of learnable prompt vectors used by the
    connector.  The configs of both backbones are resolved through
    ``AutoConfig.from_pretrained`` when the object is constructed and kept as
    ``whisper_config`` / ``llama_config``.
    """

    model_type = "DestaModel"

    def __init__(
        self,
        llama_model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        whisper_model_id="openai/whisper-small",
        prompt_size=64,
        **kwargs
    ):
        """Build the configuration.

        Args:
            llama_model_id: Checkpoint identifier of the language model.
            whisper_model_id: Checkpoint identifier of the Whisper model.
            prompt_size: Number of prompt vectors per selected Whisper layer.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        # Plain hyper-parameters.
        self.llama_model_id = llama_model_id
        self.whisper_model_id = whisper_model_id
        self.prompt_size = prompt_size

        # Backbone configs are looked up eagerly (Whisper first, then Llama)
        # so that dependent modules can read their dimensions.
        whisper_config = AutoConfig.from_pretrained(self.whisper_model_id)
        self.whisper_config = whisper_config

        llama_config = AutoConfig.from_pretrained(self.llama_model_id)
        self.llama_config = llama_config
|
|
|
class QformerConnector(PreTrainedModel):
    """Map selected Whisper encoder hidden states into the LLM embedding space.

    For each selected encoder layer, a set of learnable prompt vectors is fed
    through a small cross-attending ``BertEncoder`` together with that layer's
    hidden states.  The per-layer results are mixed with softmax-normalised
    learnable weights and projected to ``cfg.llama_config.hidden_size``.
    """

    def __init__(self, cfg):
        """Create the prompts, mixing weights, Q-Former and projection.

        Args:
            cfg: Configuration exposing ``whisper_model_id``, ``prompt_size``,
                ``whisper_config`` and ``llama_config``.

        Raises:
            NotImplementedError: If ``cfg.whisper_model_id`` is not one of the
                supported Whisper checkpoints.
        """
        super().__init__(cfg)
        self.cfg = cfg

        # Encoder hidden-state indices tapped for each supported checkpoint.
        layer_ids_by_model = {
            "openai/whisper-medium": [5, 11, 17, 23],
            "openai/whisper-small": [2, 5, 8, 11],
            "openai/whisper-tiny": [0, 1, 2, 3],
            "openai/whisper-large-v3": [3, 7, 11, 15, 19, 23, 27, 31],
        }
        if self.cfg.whisper_model_id not in layer_ids_by_model:
            raise NotImplementedError(f"model_id {self.cfg.whisper_model_id} not implemented")
        self.target_layer_ids = layer_ids_by_model[self.cfg.whisper_model_id]

        d_model = self.cfg.whisper_config.d_model
        num_targets = len(self.target_layer_ids)

        # One (1, prompt_size, d_model) prompt tensor per tapped layer.
        prompts = []
        for _ in range(num_targets):
            prompts.append(nn.Parameter(torch.randn(1, self.cfg.prompt_size, d_model)))
        self.layer_prompts = nn.ParameterList(prompts)

        # Mixing logits, one row per prompt position and one column per layer.
        self.layer_weights = nn.Parameter(
            torch.zeros(self.cfg.prompt_size, num_targets, dtype=torch.float)
        )

        # Two-layer BERT encoder configured as a cross-attending decoder whose
        # width and head count follow the Whisper encoder.
        qformer_config = BertConfig()
        overrides = {
            "num_hidden_layers": 2,
            "num_attention_heads": self.cfg.whisper_config.encoder_attention_heads,
            "hidden_size": d_model,
            "add_cross_attention": True,
            "is_decoder": True,
        }
        for field_name, field_value in overrides.items():
            setattr(qformer_config, field_name, field_value)

        self.qformer = BertEncoder(qformer_config)
        self.proj = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, self.cfg.llama_config.hidden_size),
        )

    def forward(self, encoder_hidden_states):
        """Fuse the tapped encoder layers into one projected feature tensor.

        Args:
            encoder_hidden_states: Iterable of per-layer hidden-state tensors;
                only the entries whose position is in ``target_layer_ids``
                are used.

        Returns:
            The output of ``self.proj`` applied to the weighted sum over the
            tapped layers.

        Side effects:
            Stores the softmax-normalised mixing weights on
            ``self.norm_weights``.
        """
        per_layer = []
        for layer_idx, hidden in enumerate(encoder_hidden_states):
            if layer_idx not in self.target_layer_ids:
                continue

            slot = self.target_layer_ids.index(layer_idx)
            # Broadcast the shared prompt over the batch dimension.
            queries = self.layer_prompts[slot].expand(hidden.size(0), -1, -1)
            attended = self.qformer(
                hidden_states=queries,
                encoder_hidden_states=hidden,
            )
            per_layer.append(attended.last_hidden_state)

        # Stack along a new leading "layer" axis, then move that axis to
        # position 2 so it lines up with the columns of ``layer_weights``.
        stacked = torch.stack(per_layer, dim=0).permute(1, 2, 0, 3)

        self.norm_weights = torch.nn.functional.softmax(self.layer_weights, dim=-1).unsqueeze(-1)

        fused = (stacked * self.norm_weights).sum(dim=2)
        return self.proj(fused)
|
|
|
class SpeechPerception(PreTrainedModel):
    """Whisper model plus connector producing a transcript and speech features."""

    def __init__(self, cfg):
        """Load the Whisper model and processor and build the connector.

        Args:
            cfg: Configuration exposing ``whisper_model_id`` (also handed to
                ``QformerConnector``).
        """
        super().__init__(cfg)
        self.cfg = cfg

        whisper_id = cfg.whisper_model_id
        self.whisper = WhisperForConditionalGeneration.from_pretrained(whisper_id)
        self.processor = AutoProcessor.from_pretrained(whisper_id)

        self.connector = QformerConnector(cfg)

    def generate(self, input_features):
        """Transcribe the audio and compute connector features.

        Args:
            input_features: Whisper input features; moved to the device of
                ``self.whisper`` before use.

        Returns:
            A ``(transcription, speech_features)`` tuple.  The transcription
            is the first decoded sequence only; ``speech_features`` is the
            connector output for the encoder hidden states.
        """
        whisper_out = self.whisper.generate(
            inputs=input_features.to(self.whisper.device),
            return_dict_in_generate=True,
            output_hidden_states=True,
        )

        decoded = self.processor.batch_decode(whisper_out.sequences, skip_special_tokens=True)
        first_transcription = decoded[0]

        connector_features = self.connector(whisper_out.encoder_hidden_states)

        return first_transcription, connector_features
|
|
|
|
|
class DestaModel(PreTrainedModel):
    """Speech-aware chat model: Whisper + connector features fed into a Llama LM.

    The audio referenced in a chat is transcribed by ``SpeechPerception``; the
    transcript replaces the audio path in the templated prompt and the
    connector features are spliced into the LLM input embeddings at that
    position.
    """

    config_class = Desta2Config

    def __init__(self, config):
        """Instantiate the speech front-end, the language model and its tokenizer.

        Args:
            config: Configuration exposing ``llama_model_id`` plus everything
                ``SpeechPerception`` needs.
        """
        super().__init__(config)

        self.speech_perception = SpeechPerception(config)
        self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id)

    # Inference only: disabling autograd avoids building an unused graph
    # through the connector.
    @torch.no_grad()
    def chat(self, messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9):
        """
        Generate a reply for a conversation that contains exactly one audio message.

        messages: list of dicts with keys "role" and "content"
        ```
        [
            {"role": "system", "content": "You are a helpful voice assistant."},
            {"role": "audio", "content": "<path_to_audio_file>"},
            {"role": "user", "content": "Describe the audio."}
        ]
        ```

        max_new_tokens, do_sample, temperature, top_p: forwarded unchanged to
        ``self.llama.generate``.

        Returns the raw return value of ``self.llama.generate`` (nothing is
        decoded here).

        Raises ValueError if the messages contain no audio entry or more than
        one, or if the audio path cannot be located in the templated prompt.
        """
        audio_path, input_features = self.load_audio(messages)
        transcription, audio_features = self.speech_perception.generate(input_features)
        inputs, audio_position = self.process_text(messages, audio_path, transcription)

        inputs_embeds, attention_mask = self.prepare_llm_input(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            audio_position=audio_position,
            audio_features=audio_features,
        )

        outputs = self.llama.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        return outputs

    def process_text(self, messages, audio_path, transcription):
        """Render the chat template and swap the audio path for its transcript.

        Args:
            messages: Chat messages passed to the tokenizer's chat template.
            audio_path: The audio path string that appears in ``messages``.
            transcription: Text inserted in place of ``audio_path``.

        Returns:
            ``(inputs, audio_position)`` where ``inputs`` is the tokenizer
            output for the rewritten prompt and ``audio_position`` is the
            number of tokens ``tokenizer.tokenize`` yields for the text that
            precedes the audio path.

        Raises:
            ValueError: If ``audio_path`` does not occur in the templated
                prompt.
        """
        context = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Split on the FIRST occurrence only: the same path may legitimately
        # show up again later (e.g. quoted in the user message), which made a
        # two-name unpacking of ``str.split`` blow up.
        left_text, separator, right_text = context.partition(audio_path)
        if not separator:
            raise ValueError(f"Audio path {audio_path!r} was not found in the templated prompt.")
        right_text = transcription + right_text

        audio_position = len(self.tokenizer.tokenize(left_text))
        context = left_text + right_text

        inputs = self.tokenizer(context, return_tensors="pt")

        return inputs, audio_position

    def prepare_llm_input(self, input_ids, attention_mask, audio_position, audio_features):
        """Splice the audio features into the token embeddings.

        Only the first item of the batch (index 0) is used for every input.

        Args:
            input_ids: Token ids of the prompt.
            attention_mask: Attention mask matching ``input_ids``.
            audio_position: Index along the sequence at which the audio
                features are inserted.
            audio_features: Connector output; its size along dimension 1 is
                the number of inserted positions.

        Returns:
            ``(inputs_embeds, attention_mask)``, each with a leading batch
            dimension of 1 and cast to the dtype of ``self.llama``.
        """
        device = self.llama.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        audio_features = audio_features.to(device)
        audio_feature_length = audio_features.size(1)

        inputs_embeds = self.llama.model.embed_tokens(input_ids)

        # [text before audio] + [audio features] + [text after audio]
        inputs_embeds = torch.cat(
            [inputs_embeds[0, :audio_position], audio_features[0, :], inputs_embeds[0, audio_position:]],
            dim=0,
        )
        # Every inserted audio position is attended to (mask value 1).
        attention_mask = torch.cat(
            [
                attention_mask[0, :audio_position],
                torch.ones([audio_feature_length], dtype=torch.long, device=device),
                attention_mask[0, audio_position:],
            ],
            dim=0,
        )

        inputs_embeds = inputs_embeds.to(self.llama.dtype)
        attention_mask = attention_mask.to(self.llama.dtype)
        return inputs_embeds.unsqueeze(0), attention_mask.unsqueeze(0)

    def load_audio(self, messages):
        """Find the single audio message and turn its file into Whisper features.

        Args:
            messages: Chat messages; exactly one must have role ``"audio"``
                with the file path as its content.

        Returns:
            ``(audio_path, input_features)``.

        Raises:
            ValueError: If no audio message, or more than one, is present.
        """
        audio_path = None
        for message in messages:
            if message["role"] != "audio":
                continue
            if audio_path is not None:
                raise ValueError("Multiple audio file paths found in messages. We only support one audio file per message at this moment.")
            audio_path = message["content"]
        if audio_path is None:
            raise ValueError("No audio file path found in messages")

        # librosa was used here without ever being imported (NameError at call
        # time). Importing locally keeps the module importable without it and
        # lets the validation errors above surface first.
        import librosa

        audio, ori_sr = librosa.load(audio_path)
        audio = librosa.resample(audio, orig_sr=ori_sr, target_sr=16000)
        input_features = self.speech_perception.processor(audio, sampling_rate=16000, return_tensors="pt").input_features

        return audio_path, input_features

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None, **kwargs):
        """Build the model and load the connector weights.

        The backbones are loaded from their own checkpoints by ``__init__``;
        only ``qformer_connector.pth`` is read from
        ``pretrained_model_name_or_path`` (a local directory or a Hub repo id).

        Args:
            pretrained_model_name_or_path: Local directory or Hub repository.
            *model_args: Accepted for signature compatibility; unused.
            config: Optional ready-made config object, or a location to load
                the config from.  When omitted, the config is loaded from
                ``pretrained_model_name_or_path``.
            cache_dir: Optional cache directory used for the config and the
                connector-weights download.
            **kwargs: Forwarded to ``config_class.from_pretrained`` when the
                config has to be loaded.

        Returns:
            The constructed model with connector weights loaded.
        """
        # Honour a caller-supplied config instead of silently reloading it.
        if not isinstance(config, PretrainedConfig):
            config_source = pretrained_model_name_or_path if config is None else config
            config = cls.config_class.from_pretrained(config_source, cache_dir=cache_dir, **kwargs)
        model = cls(config)

        weights_name = "qformer_connector.pth"
        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(pretrained_model_name_or_path, weights_name)
        else:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=weights_name,
                cache_dir=cache_dir,
            )

        # NOTE(security): torch.load unpickles the file -- only load connector
        # checkpoints from sources you trust.
        # map_location="cpu" lets a checkpoint saved from GPU load on a
        # CPU-only host; load_state_dict then copies into the existing params.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.speech_perception.connector.load_state_dict(state_dict)

        return model
|
|