Feature Extraction
Transformers
Safetensors
diva
custom_code
DiVA-llama-3-v0-8b / modeling_diva.py
WillHeld's picture
Update modeling_diva.py
72ed899 verified
raw
history blame
14.4 kB
import copy
import json
import os
from typing import Optional, Union
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from datasets import Audio
from safetensors.torch import load, load_model
from torch import nn
from .configuring_diva import DiVAConfig
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedModel,
WhisperModel,
)
class WhisperConnector(nn.Module):
def __init__(
self,
):
super().__init__()
self.decoder = None
self.projection = nn.Linear(1280, 4096)
self.query_tokens = nn.Parameter(torch.randn(448, 1280))
def forward(self, x, output_device="cuda:1"):
bsz = x.shape[0]
query_tokens = self.query_tokens[None, :, :].expand(bsz, -1, -1)
virt_whisper_tokens = self.decoder(
inputs_embeds=query_tokens, encoder_hidden_states=x
)
if self.projection.weight.shape[-1] == 5120:
virtual_tokens = self.projection(virt_whisper_tokens[0].reshape(112, 5120))
else:
virtual_tokens = self.projection(virt_whisper_tokens[0])
return virtual_tokens.to(output_device)
class DiVAModel(PreTrainedModel):
config_class = DiVAConfig
def __init__(
self, via_path=None, config_dict={}, device_map=None, speech_encoder_device=None
):
super().__init__(DiVAConfig.from_dict(config_dict))
if speech_encoder_device is None:
speech_encoder_device = "cuda:0"
whisper = WhisperModel.from_pretrained(config_dict["reference_encoder"])
connector = WhisperConnector()
connector.decoder = copy.deepcopy(whisper.decoder)
if via_path is not None:
with open(via_path, "rb") as f:
sd = load(f.read())
with torch.no_grad():
connector.query_tokens = nn.Parameter(sd["query_tokens"])
connector.projection.weight = nn.Parameter(sd["projection.weight"].T)
connector.projection.bias = nn.Parameter(sd["projection.bias"])
wsd = {
key.replace("connector.", ""): sd[key]
for key in sd
if key.startswith("connector.")
}
connector.decoder.load_state_dict(wsd)
if device_map == None:
num_layers = 32
num_gpus = 2
device_map = dict(
**{"model.embed_tokens": 1, "model.norm": 1, "lm_head": 2},
**{
"model.layers." + str(i): 1 + (i // (num_layers // num_gpus))
for i in range(num_layers)
},
)
self.connector = connector.to(speech_encoder_device)
self.whisper_encoder = whisper.encoder.to(speech_encoder_device)
self.llm_decoder = AutoModelForCausalLM.from_pretrained(
config_dict["reference_decoder"],
device_map=device_map,
torch_dtype=torch.float16,
)
self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
self.tokenizer = AutoTokenizer.from_pretrained(config_dict["reference_decoder"])
if self.tokenizer.pad_token_id == None:
override_token = list(self.tokenizer.added_tokens_decoder.items())[-1]
self.tokenizer.pad_token_id = override_token[0]
self.tokenizer.pad_tokn = str(override_token[1])
prefix, suffix = self.tokenizer.apply_chat_template(
[{"role": "user", "content": "PLACEHOLDER"}],
tokenize=False,
add_generation_prompt=True,
).split("PLACEHOLDER")
non_null = [line for line in prefix.split("\n") if line.strip()]
prefix_tok = self.tokenizer.encode(prefix, add_special_tokens=False)
suffix_tok = self.tokenizer.encode(suffix, add_special_tokens=False)
self.prefix = torch.tensor(prefix_tok).to(
self.llm_decoder.model.embed_tokens.weight.device
)
self.pre_system = torch.tensor(
self.tokenizer.encode(non_null[0] + "\n", add_special_tokens=False)
).to(self.llm_decoder.model.embed_tokens.weight.device)
self.post_system = torch.tensor(
self.tokenizer.encode("\n" + non_null[-1] + "\n", add_special_tokens=False)
).to(self.llm_decoder.model.embed_tokens.weight.device)
self.final_header = torch.tensor(suffix_tok).to(
self.llm_decoder.model.embed_tokens.weight.device
)
self.speech_encoder_device = speech_encoder_device
def can_generate(cls):
return False
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config=None,
cache_dir=None,
**kwargs,
):
if os.path.isdir(pretrained_model_name_or_path):
via_path = pretrained_model_name_or_path + "/model.safetensors"
config_path = pretrained_model_name_or_path + "/config.json"
else:
# Loading from huggingface repo
from huggingface_hub import hf_hub_download
via_path = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="model.safetensors",
token=kwargs.get("token", None),
)
config_path = hf_hub_download(
repo_id=pretrained_model_name_or_path,
filename="config.json",
token=kwargs.get("token", None),
)
with open(config_path, "r") as f:
config_dict = json.loads(f.read())
return cls(
via_path,
config_dict,
kwargs["device_map"] if "device_map" in kwargs else "auto",
(
kwargs["speech_encoder_device"]
if "speech_encoder_device" in kwargs
else None
),
)
def forward(self, audio, prefix_text_tokens, suffix_text_tokens):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.speech_encoder_device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llm_decoder.model.embed_tokens.weight.device,
).squeeze()
prefix_embed = self.llm_decoder.model.embed_tokens(prefix_text_tokens)
suffix_embed = self.llm_decoder.model.embed_tokens(suffix_text_tokens)
inputs_embeds = torch.cat(
[prefix_embed, virt_tokens, suffix_embed], axis=0
).unsqueeze(0)
outputs = self.llm_decoder(
inputs_embeds=inputs_embeds.to(
self.llm_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
return outputs
@torch.no_grad()
def generate(
self,
audio,
text_prompt=None,
do_sample=False,
logits_processor=None,
max_new_tokens=128,
):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.speech_encoder_device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llm_decoder.model.embed_tokens.weight.device,
)
bsz = virt_tokens.shape[0]
if text_prompt != None and text_prompt != "":
user_prompt_text = torch.tensor(
self.tokenizer(
text_prompt,
add_special_tokens=False,
padding=True,
padding_side="right",
)["input_ids"],
device=self.pre_system.device,
)
prefix = torch.cat(
[
self.pre_system.expand(
bsz,
-1,
),
user_prompt_text,
self.post_system.expand(
bsz,
-1,
),
],
axis=1,
)
else:
prefix = self.prefix
prefix_embed = self.llm_decoder.model.embed_tokens(prefix).expand(bsz, -1, -1)
suffix = self.final_header
suffix_embed = self.llm_decoder.model.embed_tokens(suffix).expand(bsz, -1, -1)
inputs_embeds = torch.cat([prefix_embed, virt_tokens, suffix_embed], axis=1)
outs = [[] for i in range(bsz)]
complete = [False] * bsz
outputs = None
greedy = 1
i = 0
while not all(complete) and len(outs[0]) < max_new_tokens:
past_key_values = outputs.past_key_values if outputs else None
outputs = self.llm_decoder(
inputs_embeds=inputs_embeds.to(
self.llm_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
next_token_logits = outputs.logits[:, -1, :]
if logits_processor:
local_outs = torch.tensor(outs) if outs != [] else suffix
local_outs = local_outs.reshape(1, -1)
next_token_logits = logits_processor(
local_outs,
next_token_logits.reshape(1, -1),
)
next_token_logits = next_token_logits.flatten()
if do_sample:
logits = next_token_logits / temperature
probs = F.softmax(logits, dim=-1)
greedy = torch.multinomial(probs, num_samples=1)[0]
else:
greedy = next_token_logits.argmax(dim=-1)
for token_index, out in enumerate(greedy.flatten().tolist()):
if not complete[token_index]:
outs[token_index].append(out)
if out == 128009:
complete[token_index] = True
next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(-1, 1))
inputs_embeds = next_embed
return self.tokenizer.batch_decode(outs, skip_special_tokens=True)
def generate_stream(
self,
audio,
text_prompt,
do_sample=False,
logits_processor=None,
max_new_tokens=128,
return_outputs=False,
init_outputs=None,
):
inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
input_features = inputs.input_features.to(self.whisper_encoder.device)
hidden_states = self.whisper_encoder(input_features=input_features)[
"last_hidden_state"
]
virt_tokens = self.connector(
hidden_states,
output_device=self.llm_decoder.model.embed_tokens.weight.device,
).squeeze()
if text_prompt != None and text_prompt != "":
user_prompt_text = torch.tensor(
self.tokenizer(
text_prompt,
add_special_tokens=False,
padding=True,
padding_side="right",
)["input_ids"],
device=self.pre_system.device,
)
prefix = torch.cat(
[self.pre_system, user_prompt_text, self.post_system],
axis=0,
)
else:
prefix = self.prefix
prefix_embed = self.llm_decoder.model.embed_tokens(prefix)
suffix = self.final_header
suffix_embed = self.llm_decoder.model.embed_tokens(suffix)
inputs_embeds = torch.cat(
[prefix_embed, virt_tokens, suffix_embed], axis=0
).unsqueeze(0)
outs = []
outputs = init_outputs
greedy = 1
i = 0
while greedy != 128009 and len(outs) < max_new_tokens:
past_key_values = outputs.past_key_values if outputs else None
outputs = self.llm_decoder(
inputs_embeds=inputs_embeds.to(
self.llm_decoder.model.embed_tokens.weight.device
).half(),
return_dict=True,
output_hidden_states=True,
past_key_values=past_key_values,
)
next_token_logits = outputs.logits[-1, -1, :]
if logits_processor:
local_outs = torch.tensor(outs) if outs != [] else suffix
local_outs = local_outs.reshape(1, -1)
next_token_logits = logits_processor(
local_outs,
next_token_logits.reshape(1, -1),
)
next_token_logits = next_token_logits.flatten()
if do_sample:
logits = next_token_logits / temperature
probs = F.softmax(logits, dim=-1)
greedy = torch.multinomial(probs, num_samples=1)[0]
else:
greedy = next_token_logits.argmax()
outs.append(greedy)
next_embed = self.llm_decoder.model.embed_tokens(greedy.reshape(1, 1))
inputs_embeds = next_embed
if not return_outputs:
yield self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
)
else:
yield (
self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
),
outputs,
)
if not return_outputs:
return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
)
else:
return (
self.tokenizer.decode(outs, skip_special_tokens=True).replace(
"<|eot_id|>", ""
),
outputs,
)