|
import sys |
|
import os |
|
import pandas as pd |
|
import argparse |
|
|
|
|
|
default_cuda_devices = "0"


def select_cuda_devices(argv, default):
    """Return the CUDA_VISIBLE_DEVICES value requested on the command line.

    A leading positional ``argv[1]`` (e.g. ``"1"`` or ``"0,1"``) selects the
    devices. *default* is used when there is no such argument, when ``argv[1]``
    looks like an option flag (starts with ``-``, e.g. ``--input_csv``), or when
    it is the literal ``"4"`` (special-cased by the original script; reason not
    documented here).
    """
    if len(argv) > 1 and not argv[1].startswith("-") and argv[1] != "4":
        return argv[1]
    return default


# Set before the torch imports that follow so CUDA only sees the chosen devices.
argument = select_cuda_devices(sys.argv, default_cuda_devices)
os.environ["CUDA_VISIBLE_DEVICES"] = argument
|
import numpy as np |
|
import os |
|
import torchaudio |
|
import fire |
|
import json |
|
import torch |
|
from tqdm import tqdm |
|
import time |
|
import torchvision |
|
from peft import ( |
|
LoraConfig, |
|
get_peft_model, |
|
get_peft_model_state_dict, |
|
prepare_model_for_int8_training, |
|
set_peft_model_state_dict, |
|
) |
|
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer, LlamaConfig |
|
|
|
from utils.prompter import Prompter |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"

parser = argparse.ArgumentParser()
# The CUDA device selection at the top of this script reads sys.argv[1] as a
# device list (e.g. `script.py 0,1 --input_csv ...`). Declare it as an optional
# leading positional so parse_args() accepts that invocation instead of exiting
# with "unrecognized arguments".
parser.add_argument('cuda_devices', nargs='?', default=None,
                    help='Optional CUDA_VISIBLE_DEVICES value; must be the first argument')
parser.add_argument('--input_csv', type=str, required=True, help='Path to the input file')
parser.add_argument('--output_csv', type=str, required=True, help='Path to the output file')
args = parser.parse_args()
|
|
|
def int16_to_float32_torch(x):
    """Scale integer PCM samples by 1/32767 and return them as a float32 tensor."""
    return torch.div(x, 32767.0).to(dtype=torch.float32)
|
|
|
def float32_to_int16_torch(x):
    """Clip float samples to [-1, 1], scale by 32767 and cast the result to int16."""
    clipped = x.clamp(min=-1.0, max=1.0)
    return torch.mul(clipped, 32767.0).to(dtype=torch.int16)
|
|
|
def get_mel(audio_data):
    """Return the dB-scaled 64-bin mel spectrogram of *audio_data*, transposed.

    The mel transform uses fixed settings (48 kHz sample rate, 1024-point FFT,
    hop 480, 50-14000 Hz) and is moved to ``audio_data``'s device. For a 1-D
    waveform the result is laid out as (frames, 64).
    """
    # NOTE(review): newer torchaudio releases deprecate/remove `onesided` on
    # MelSpectrogram — confirm against the installed version.
    spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=48000, n_fft=1024, win_length=1024, hop_length=480,
        center=True, pad_mode="reflect", power=2.0, norm=None,
        onesided=True, n_mels=64, f_min=50, f_max=14000,
    ).to(audio_data.device)
    to_decibels = torchaudio.transforms.AmplitudeToDB(top_db=None)
    return to_decibels(spectrogram(audio_data)).T
|
|
|
|
|
def _fusion_mel_for_long_audio(audio_data, max_len):
    """Build the 4-view "fusion" mel stack for a clip longer than *max_len* samples.

    Returns ``(mel_fusion, longer)``. Normally the views are a resized copy of the
    whole mel plus randomly placed front/middle/back crops of ``chunk_frames``
    frames, and ``longer`` is ``tensor([True])``. If the full mel is already
    exactly ``chunk_frames`` long, the mel is repeated four times and ``longer``
    is ``tensor([False])``.
    """
    mel = get_mel(audio_data)
    # get_mel uses hop_length=480 (center=True), so max_len samples span
    # max_len // 480 + 1 frames.
    chunk_frames = max_len // 480 + 1
    total_frames = mel.shape[0]
    if chunk_frames == total_frames:
        return torch.stack([mel, mel, mel, mel], dim=0), torch.tensor([False])

    # Split the valid crop start positions into three ranges (front/middle/back).
    ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
    # With fewer than three start positions the later ranges are empty; use 0.
    if len(ranges[1]) == 0:
        ranges[1] = [0]
    if len(ranges[2]) == 0:
        ranges[2] = [0]
    idx_front = np.random.choice(ranges[0])
    idx_middle = np.random.choice(ranges[1])
    idx_back = np.random.choice(ranges[2])

    mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
    mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
    mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
    # Global view: the whole mel resized down to chunk_frames x 64.
    mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])(mel[None])[0]

    mel_fusion = torch.stack([mel_shrink, mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
    return mel_fusion, torch.tensor([True])


def _fill_short_audio(audio_data, max_len, data_filling):
    """Extend a clip shorter than *max_len* samples to exactly *max_len* samples.

    ``"pad"`` zero-pads on the right; ``"repeat"`` tiles the clip and truncates;
    ``"repeatpad"`` tiles the clip a whole number of times, then zero-pads the rest.
    An empty clip cannot be tiled and becomes *max_len* zeros in every mode.

    Raises:
        NotImplementedError: for any other *data_filling* value.
    """
    import torch.nn.functional as F

    if data_filling not in ("repeatpad", "pad", "repeat"):
        raise NotImplementedError(
            f"data_filling {data_filling} not implemented"
        )
    n_samples = len(audio_data)
    if n_samples == 0:
        # Avoids the division by zero the repeat modes would otherwise hit.
        return audio_data.new_zeros(max_len)
    if data_filling == "repeat":
        return audio_data.repeat(max_len // n_samples + 1)[:max_len]
    if data_filling == "repeatpad":
        audio_data = audio_data.repeat(max_len // n_samples)
    return F.pad(audio_data, (0, max_len - len(audio_data)), mode="constant", value=0)


def get_audio_features(sample, audio_data, max_len, data_truncating, data_filling, require_grad=False):
    """Crop or fill *audio_data* to *max_len* samples and store the results in *sample*.

    *audio_data* is assumed to be a 1-D waveform tensor (``len()`` is used as its
    sample count) — TODO confirm for multi-channel callers.

    Args:
        sample: dict updated in place and returned.
        audio_data: waveform tensor.
        max_len: target length in samples.
        data_truncating: ``"rand_trunc"`` (random crop) or ``"fusion"`` (random
            crop plus a 4-view mel stack) for clips longer than *max_len*.
        data_filling: ``"repeatpad"``, ``"pad"`` or ``"repeat"`` for clips shorter
            than *max_len*.
        require_grad: when False, the work runs under ``torch.no_grad()``.

    Returns:
        *sample* with ``"waveform"`` (exactly *max_len* samples), ``"longer"``
        (bool tensor of shape (1,)) and, in fusion mode only, ``"mel_fusion"``
        (the 4-view mel stack with a leading batch dimension added).

    Raises:
        NotImplementedError: unknown *data_truncating* for a long clip, or
            unknown *data_filling* for a short clip.
    """
    import contextlib

    # nullcontext leaves autograd enabled; torch.no_grad disables it.
    grad_ctx = contextlib.nullcontext if require_grad else torch.no_grad
    mel_fusion = None
    with grad_ctx():
        if len(audio_data) > max_len:
            if data_truncating == "rand_trunc":
                longer = torch.tensor([True])
            elif data_truncating == "fusion":
                # The mel views come from the full clip, before the crop below.
                mel_fusion, longer = _fusion_mel_for_long_audio(audio_data, max_len)
            else:
                raise NotImplementedError(
                    f"data_truncating {data_truncating} not implemented"
                )
            # Random crop of max_len contiguous samples.
            overflow = len(audio_data) - max_len
            idx = np.random.randint(0, overflow + 1)
            audio_data = audio_data[idx: idx + max_len]
        else:
            if len(audio_data) < max_len:
                audio_data = _fill_short_audio(audio_data, max_len, data_filling)
            if data_truncating == "fusion":
                mel = get_mel(audio_data)
                mel_fusion = torch.stack([mel, mel, mel, mel], dim=0)
            # Set for every mode, so "longer" is always defined for short clips.
            longer = torch.tensor([False])

    sample["longer"] = longer
    sample["waveform"] = audio_data
    # Only fusion mode produces a mel stack; other modes must not touch the key.
    if mel_fusion is not None:
        sample["mel_fusion"] = mel_fusion.unsqueeze(0)
    return sample
|
|
|
|
|
|
|
def load_audio(filename):
    """Load *filename* and return a normalized 128-bin Kaldi fbank of 1024 frames.

    Steps:
    - Resample to 16 kHz when needed.
    - Subtract the mean over all samples.
    - Compute a Kaldi-compatible fbank (hanning window, 10 ms shift, no dither).
    - Zero-pad or truncate along the frame axis to 1024 frames.
    - Apply ``(x + 5.081) / 4.4849``. These constants are presumably dataset
      mean/std statistics — confirm.
    """
    target_rate = 16000
    num_frames = 1024

    signal, rate = torchaudio.load(filename)
    if rate != target_rate:
        signal = torchaudio.functional.resample(waveform=signal, orig_freq=rate, new_freq=target_rate)
        rate = target_rate
    signal = signal - signal.mean()

    features = torchaudio.compliance.kaldi.fbank(
        signal,
        htk_compat=True,
        sample_frequency=rate,
        use_energy=False,
        window_type='hanning',
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    shortfall = num_frames - features.shape[0]
    if shortfall > 0:
        # Append zero frames at the end of the time axis.
        features = torch.nn.functional.pad(features, (0, 0, 0, shortfall))
    else:
        features = features[:num_frames, :]

    return (features + 5.081) / 4.4849
|
|
|
|
|
root_dir = '/fs/nexus-projects'

# Fine-tuned weights loaded on top of the LoRA-wrapped base model.
_DEFAULT_CHECKPOINT_PATH = '/fs/gamma-projects/audio/gama/new_data_no_aggr/stage4_all_mix_new/checkpoint-46800//pytorch_model.bin'


def _load_model(base_model, checkpoint_path):
    """Load *base_model*, wrap it with the LoRA config and load *checkpoint_path* into it."""
    model = LlamaForCausalLM.from_pretrained(base_model, device_map="auto")
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)

    # torch.load unpickles the file: only point this at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    # strict=False: parameters missing from the checkpoint keep their current
    # values and unexpected keys are ignored instead of raising.
    model.load_state_dict(state_dict, strict=False)
    del state_dict  # release the CPU copy of the weights before inference

    model.is_parallelizable = True
    model.model_parallel = True
    return model


def main(
    base_model: str = "/fs/nexus-projects/brain_project/Llama-2-7b-chat-hf-qformer",
    prompt_template: str = "alpaca_short",
    checkpoint_path: str = _DEFAULT_CHECKPOINT_PATH,
    max_rows=None,
):
    """Caption the audio clips listed in ``args.input_csv`` and write ``args.output_csv``.

    Paths come from the module-level ``args`` parsed by argparse. The input CSV
    must have ``path``, ``dataset`` and ``split_name`` columns. A ``path`` of
    ``'empty'`` calls generate with ``audio_input=None``. The output CSV has
    ``path``, ``caption``, ``dataset`` and ``split_name`` columns, one row per
    processed input row.

    Args:
        base_model: model id or directory of the Llama base model; falls back to
            the ``BASE_MODEL`` environment variable when empty.
        prompt_template: template name passed to ``Prompter``.
        checkpoint_path: fine-tuned state dict loaded (non-strictly) on top of
            the LoRA-wrapped base model.
        max_rows: caption only the first ``max_rows`` rows. ``None`` (default)
            processes the whole file; the script previously hard-coded
            ``head()``, i.e. 5 rows.

    Raises:
        ValueError: if no base model is given.
    """
    base_model = base_model or os.environ.get("BASE_MODEL", "")
    if not base_model:
        raise ValueError(
            "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"
        )

    prompter = Prompter(prompt_template)
    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    model = _load_model(base_model, checkpoint_path)

    model.config.pad_token_id = tokenizer.pad_token_id = 0
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2
    model.eval()

    # Loop-invariant: one sampling configuration shared by every clip.
    temp, top_p, top_k = 0.1, 0.95, 500
    generation_config = GenerationConfig(
        do_sample=True,
        temperature=temp,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=1.1,
        max_new_tokens=400,
        bos_token_id=model.config.bos_token_id,
        eos_token_id=model.config.eos_token_id,
        pad_token_id=model.config.pad_token_id,
        num_return_sequences=1,
    )

    # The instruction is identical for every clip, so build and tokenize it once.
    instruction = "Write a caption for the audio in AudioCaps style."
    prompt = prompter.generate_prompt(instruction, None)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

    table = pd.read_csv(args.input_csv)
    if max_rows is not None:
        table = table.head(max_rows)

    paths, captions, datasets, split_names = [], [], [], []
    rows = zip(table['path'], table['dataset'], table['split_name'])
    for audio_path, dataset, split_name in tqdm(rows, total=len(table)):
        if audio_path != 'empty':
            # .to(device) is a no-op when device is "cpu".
            audio_input = load_audio(audio_path).unsqueeze(0).to(device)
        else:
            audio_input = None

        with torch.no_grad():
            generation_output = model.generate(
                input_ids=input_ids,
                audio_input=audio_input,
                generation_config=generation_config,
                return_dict_in_generate=True,
            )
        sequence = generation_output.sequences[0]
        # NOTE(review): these fixed slices assume three things about the decoded text:
        #   - it starts with a 6-character BOS prefix;
        #   - it echoes the prompt verbatim;
        #   - it ends with the 4-character "</s>".
        # A generation cut off at max_new_tokens loses its last 4 characters.
        # Decoding sequence[input_ids.shape[1]:] with skip_special_tokens=True may
        # be more robust — confirm against the custom model's generate() output.
        output = tokenizer.decode(sequence)[6:-4]
        output = output[len(prompt):]

        print(output)
        paths.append(audio_path)
        captions.append(output)
        datasets.append(dataset)
        split_names.append(split_name)

    pd.DataFrame({
        'path': paths,
        'caption': captions,
        'dataset': datasets,
        'split_name': split_names,
    }).to_csv(args.output_csv, index=False)
|
|
|
if __name__ == "__main__":
    # Command-line options are already parsed into the module-level `args`, which
    # main() reads itself. Do not pass `args` positionally: main()'s first
    # parameter is base_model.
    main()
|
|