DestaModel
custom_code
DeSTA2-8B-beta / modeling_desta.py
kehanlu's picture
Upload folder using huggingface_hub
8ab1e14 verified
raw
history blame
9.24 kB
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, WhisperForConditionalGeneration, PretrainedConfig, PreTrainedModel, BertConfig, AutoProcessor
from transformers.models.bert.modeling_bert import BertEncoder
from torch import nn
import torch
import os
class Desta2Config(PretrainedConfig):
    """Configuration for the DeSTA2 speech-language model.

    Records which Llama and Whisper checkpoints the model is built on and how
    many learnable query vectors the connector uses, then resolves the
    configuration of both backbones with ``AutoConfig.from_pretrained``.

    Note that constructing this config calls ``AutoConfig.from_pretrained``
    twice, once per backbone id.

    Args:
        llama_model_id: Checkpoint id of the language model backbone.
        whisper_model_id: Checkpoint id of the Whisper backbone.
        prompt_size: Number of learnable query vectors per selected Whisper
            layer (read by ``QformerConnector``).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "DestaModel"

    def __init__(
        self,
        llama_model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        whisper_model_id="openai/whisper-small",
        prompt_size=64,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.llama_model_id = llama_model_id
        self.whisper_model_id = whisper_model_id
        self.prompt_size = prompt_size

        # Resolve the backbone configs (Whisper first, then Llama) so that
        # downstream modules can read d_model / hidden_size from this object.
        whisper_backbone_config = AutoConfig.from_pretrained(self.whisper_model_id)
        llama_backbone_config = AutoConfig.from_pretrained(self.llama_model_id)
        self.whisper_config = whisper_backbone_config
        self.llama_config = llama_backbone_config
class QformerConnector(PreTrainedModel):
    """Q-Former style connector from Whisper encoder states to the LLM space.

    For each selected Whisper encoder layer, a set of learnable query prompts
    cross-attends to that layer's hidden states through a shared 2-layer
    ``BertEncoder``. The per-layer results are combined with a learned,
    softmax-normalised weight per (prompt, layer) and projected to the Llama
    hidden size.

    Args:
        cfg: Config exposing ``whisper_model_id``, ``prompt_size``,
            ``whisper_config`` and ``llama_config``.

    Raises:
        NotImplementedError: If ``cfg.whisper_model_id`` is not one of the
            checkpoints with a known layer selection.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg

        # Which entries of the encoder hidden-state sequence are tapped,
        # per supported Whisper checkpoint.
        layers_by_checkpoint = {
            "openai/whisper-medium": [5, 11, 17, 23],
            "openai/whisper-small": [2, 5, 8, 11],
            "openai/whisper-tiny": [0, 1, 2, 3],
            "openai/whisper-large-v3": [3, 7, 11, 15, 19, 23, 27, 31],
        }
        if self.cfg.whisper_model_id not in layers_by_checkpoint:
            raise NotImplementedError(f"model_id {self.cfg.whisper_model_id} not implemented")
        self.target_layer_ids = layers_by_checkpoint[self.cfg.whisper_model_id]

        whisper_dim = self.cfg.whisper_config.d_model
        num_targets = len(self.target_layer_ids)

        # One (1, prompt_size, d_model) block of learnable queries per tapped layer.
        self.layer_prompts = nn.ParameterList(
            [
                nn.Parameter(torch.randn(1, self.cfg.prompt_size, whisper_dim))
                for _ in range(num_targets)
            ]
        )

        # (prompt_size, target_layers); zeros give a uniform mix after softmax.
        self.layer_weights = nn.Parameter(
            torch.zeros(self.cfg.prompt_size, num_targets, dtype=torch.float)
        )

        # Small BERT encoder used as a decoder with cross-attention, sized to
        # match the Whisper encoder width and head count.
        qformer_config = BertConfig()
        qformer_config.num_hidden_layers = 2
        qformer_config.num_attention_heads = self.cfg.whisper_config.encoder_attention_heads
        qformer_config.hidden_size = whisper_dim
        qformer_config.add_cross_attention = True
        qformer_config.is_decoder = True
        self.qformer = BertEncoder(qformer_config)

        # Normalise, then project to llama hidden size.
        self.proj = nn.Sequential(
            nn.LayerNorm(whisper_dim),
            nn.Linear(whisper_dim, self.cfg.llama_config.hidden_size),
        )

    def forward(self, encoder_hidden_states):
        """Fuse the selected encoder hidden states into LLM-sized features.

        Args:
            encoder_hidden_states: Sequence of per-layer encoder outputs; only
                positions listed in ``self.target_layer_ids`` are used.
                NOTE(review): assumes each entry is (batch, frames, d_model)
                -- confirm against the caller.

        Returns:
            Tensor of shape (batch, prompt_size, llama hidden size).

        Side effects:
            Stores the softmax-normalised layer weights on
            ``self.norm_weights``.
        """
        # Position of each tapped layer id inside ``self.layer_prompts``.
        slot_of_layer = {
            layer_id: slot for slot, layer_id in enumerate(self.target_layer_ids)
        }

        per_layer_outputs = []
        for layer_idx, layer_hidden in enumerate(encoder_hidden_states):
            if layer_idx not in slot_of_layer:
                continue
            batch_size = layer_hidden.size(0)
            queries = self.layer_prompts[slot_of_layer[layer_idx]].expand(batch_size, -1, -1)
            attended = self.qformer(
                hidden_states=queries,
                encoder_hidden_states=layer_hidden,
            )
            per_layer_outputs.append(attended.last_hidden_state)

        # (layers, b, prompt_size, d_model) -> (b, prompt_size, layers, d_model)
        stacked = torch.stack(per_layer_outputs, dim=0).permute(1, 2, 0, 3)

        self.norm_weights = torch.nn.functional.softmax(self.layer_weights, dim=-1).unsqueeze(-1)
        fused = (stacked * self.norm_weights).sum(dim=2)  # (b, prompt_size, d_model)

        return self.proj(fused)
class SpeechPerception(PreTrainedModel):
    """Whisper front-end plus connector producing a transcript and features.

    Holds a ``WhisperForConditionalGeneration`` model, its processor, and a
    ``QformerConnector`` built from the same config.

    Args:
        cfg: Config exposing ``whisper_model_id`` (and whatever
            ``QformerConnector`` reads).
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg

        whisper_id = cfg.whisper_model_id
        self.whisper = WhisperForConditionalGeneration.from_pretrained(whisper_id)
        self.processor = AutoProcessor.from_pretrained(whisper_id)
        self.connector = QformerConnector(cfg)

    def generate(self, input_features):
        """Transcribe the audio and compute connector features.

        Args:
            input_features: Whisper input features; moved to the Whisper
                model's device before use.

        Returns:
            Tuple ``(transcription, speech_features)`` where ``transcription``
            is the decoded text of the first sequence only and
            ``speech_features`` is the connector output for the encoder
            hidden states.
        """
        # Whisper's default generation config is used; hidden states are
        # requested so the connector can read the encoder layers.
        whisper_out = self.whisper.generate(
            inputs=input_features.to(self.whisper.device),
            return_dict_in_generate=True,
            output_hidden_states=True,
        )

        decoded_batch = self.processor.batch_decode(
            whisper_out.sequences, skip_special_tokens=True
        )
        first_transcription = decoded_batch[0]

        speech_features = self.connector(whisper_out.encoder_hidden_states)
        return first_transcription, speech_features
class DestaModel(PreTrainedModel):
    """Speech-language model: Whisper + Q-Former connector feeding a Llama LLM.

    The audio referenced in a chat is transcribed by Whisper, summarised into
    a fixed number of feature vectors by the connector, and spliced into the
    LLM input embeddings right before the transcription text.

    Only a single audio file and a single conversation (batch size 1) are
    supported by the helper methods below.
    """

    config_class = Desta2Config

    def __init__(self, config):
        super().__init__(config)

        self.speech_perception = SpeechPerception(config)
        self.llama = AutoModelForCausalLM.from_pretrained(config.llama_model_id, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llama_model_id)

    def chat(self, messages, max_new_tokens=128, do_sample=True, temperature=0.6, top_p=0.9):
        """Generate a reply for a conversation that contains one audio file.

        messages: list of dicts with keys "role" and "content"
        ```
        [
            {"role": "system", "content": "You are a helpful voice assistant."},
            {"role": "audio", "content": "<path_to_audio_file>"},
            {"role": "user", "content": "Describe the audio."}
        ]
        ```

        Args:
            messages: Conversation as shown above; exactly one "audio" entry.
            max_new_tokens: Forwarded to ``self.llama.generate``.
            do_sample: Forwarded to ``self.llama.generate``.
            temperature: Forwarded to ``self.llama.generate``.
            top_p: Forwarded to ``self.llama.generate``.

        Returns:
            The raw output of ``self.llama.generate``; it is not decoded to
            text here.

        Raises:
            ValueError: If the messages contain no audio entry, more than one,
                or the audio path cannot be located in the rendered prompt.
        """
        audio_path, input_features = self.load_audio(messages)
        transcription, audio_features = self.speech_perception.generate(input_features)
        inputs, audio_position = self.process_text(messages, audio_path, transcription)

        inputs_embeds, attention_mask = self.prepare_llm_input(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            audio_position=audio_position,
            audio_features=audio_features,
        )

        outputs = self.llama.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
        )
        return outputs

    def process_text(self, messages, audio_path, transcription):
        """Render the chat prompt and replace the audio path by its transcript.

        Args:
            messages: Conversation passed to the tokenizer's chat template.
            audio_path: The audio path exactly as it appears in ``messages``.
            transcription: Text that replaces the path in the prompt.

        Returns:
            Tuple ``(inputs, audio_position)``: the tokenized prompt and the
            token index at which the audio features are to be inserted.

        Raises:
            ValueError: If ``audio_path`` does not occur exactly once in the
                rendered prompt (the split point would be ambiguous).
        """
        context = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # The path acts as the placeholder for the audio; it must be found
        # exactly once, otherwise the insertion point is undefined.
        pieces = context.split(audio_path)
        if len(pieces) != 2:
            raise ValueError(
                f"Expected the audio path to appear exactly once in the rendered prompt, "
                f"but found {len(pieces) - 1} occurrence(s): {audio_path!r}"
            )
        left_text, right_text = pieces
        right_text = transcription + right_text

        # NOTE(review): ``tokenize`` and ``__call__`` may differ in whether
        # they add special tokens (e.g. a BOS token), which would shift this
        # index by one relative to ``inputs.input_ids`` -- confirm against the
        # training setup before changing.
        audio_position = len(self.tokenizer.tokenize(left_text))

        context = left_text + right_text
        inputs = self.tokenizer(context, return_tensors="pt")
        return inputs, audio_position

    def prepare_llm_input(self, input_ids, attention_mask, audio_position, audio_features):
        """Splice the audio features into the token embeddings.

        Only the first item of the batch is used.

        Args:
            input_ids: Token ids of the prompt, indexed as ``[0, :]``.
            attention_mask: Attention mask matching ``input_ids``.
            audio_position: Token index where the audio features are inserted.
            audio_features: Connector output, indexed as ``[0, :]``; its
                size along dim 1 is the number of inserted positions.

        Returns:
            Tuple ``(inputs_embeds, attention_mask)``, each with a leading
            batch dimension of 1 and cast to the LLM dtype.
        """
        device = self.llama.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        audio_features = audio_features.to(device)
        audio_feature_length = audio_features.size(1)

        inputs_embeds = self.llama.model.embed_tokens(input_ids)  # [bs, seq_len, hidden_size]

        # text-before-audio | audio features | text-after-audio
        inputs_embeds = torch.cat(
            [
                inputs_embeds[0, :audio_position],
                audio_features[0, :],
                inputs_embeds[0, audio_position:],
            ],
            dim=0,
        )
        # Inserted audio positions are always attended to.
        attention_mask = torch.cat(
            [
                attention_mask[0, :audio_position],
                torch.ones([audio_feature_length], dtype=torch.long, device=device),
                attention_mask[0, audio_position:],
            ],
            dim=0,
        )

        inputs_embeds = inputs_embeds.to(self.llama.dtype)
        # Kept as in the released code: the mask is cast to the LLM dtype too.
        attention_mask = attention_mask.to(self.llama.dtype)
        return inputs_embeds.unsqueeze(0), attention_mask.unsqueeze(0)

    def load_audio(self, messages):
        """Find the single audio entry in ``messages`` and featurize it.

        Args:
            messages: Conversation; each item is indexed by "role" and
                "content".

        Returns:
            Tuple ``(audio_path, input_features)`` where ``input_features``
            comes from the Whisper processor at a 16 kHz sampling rate.

        Raises:
            ValueError: If there is no audio entry or more than one.
        """
        # Imported here because the module never imports librosa at top level;
        # without this the calls below fail with NameError.
        import librosa

        audio_path = None
        for message in messages:
            if message["role"] != "audio":
                continue
            if audio_path is not None:
                raise ValueError("Multiple audio file paths found in messages. We only support one audio file per message at this moment.")
            audio_path = message["content"]

        if audio_path is None:
            raise ValueError("No audio file path found in messages")

        # NOTE(review): the audio is loaded at librosa's default rate and then
        # resampled to 16 kHz; kept as-is so features match the released code.
        audio, ori_sr = librosa.load(audio_path)
        audio = librosa.resample(audio, orig_sr=ori_sr, target_sr=16000)
        input_features = self.speech_perception.processor(audio, sampling_rate=16000, return_tensors="pt").input_features

        return audio_path, input_features

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, config=None, cache_dir=None, **kwargs):
        """Build the model and load the connector weights.

        The Whisper and Llama backbones are loaded from their own checkpoints
        in ``__init__``; only ``qformer_connector.pth`` is read from
        ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id that
                contains the config and ``qformer_connector.pth``.
            *model_args: Accepted for signature compatibility; unused.
            config: Optional ready-made config instance. When given it is used
                as-is instead of being reloaded.
            cache_dir: Optional cache directory, forwarded to the config
                loader and to ``hf_hub_download``.
            **kwargs: Forwarded to ``config_class.from_pretrained`` when the
                config has to be loaded.

        Returns:
            The constructed ``DestaModel`` with connector weights loaded.
        """
        # Honour a caller-supplied config instead of silently reloading it.
        if not isinstance(config, PretrainedConfig):
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, cache_dir=cache_dir, **kwargs
            )

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            weights_path = os.path.join(pretrained_model_name_or_path, "qformer_connector.pth")
        else:
            from huggingface_hub import hf_hub_download

            weights_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="qformer_connector.pth",
                cache_dir=cache_dir,
            )

        # map_location="cpu" lets checkpoints saved on GPU load on CPU-only
        # hosts; load_state_dict copies into the parameters' own device.
        # SECURITY: torch.load unpickles the file -- only load checkpoints
        # from trusted sources.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.speech_perception.connector.load_state_dict(state_dict)

        return model