'''
Ke Chen | [email protected] & Nikita Srivatsan | [email protected]
Load the mp3-format audio data from the audiostock-full dataset.
'''
import json
import os
import pickle
import random
import sys
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torchaudio
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer

from lib import *    # expected to provide preproc, muscaps_tokenize, etc.
from utils import *

def int16_to_float32(x):
    # Scale int16 samples back to float32 in [-1, 1].
    return (x / 32767.0).type(torch.float)


def float32_to_int16(x):
    # Clip to [-1, 1] and quantize to the int16 range.
    x = torch.clip(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)
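
# Round-trip sketch: float32_to_int16 followed by int16_to_float32 clips to [-1, 1]
# and quantizes to 16-bit resolution (step ~3e-5), matching the normalization step
# in read_wav() below; e.g. tensor([0.5, -1.2]) -> tensor([0.5000, -1.0000]).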

def my_collate(batch):
    # Drop samples that failed to load (None), then fall back to the default collate.
    batch = [x for x in batch if x is not None]
    if len(batch) == 0:
        return batch
    else:
        return torch.utils.data.dataloader.default_collate(batch)
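
# Usage sketch: __getitem__ returns None for unreadable audio, and my_collate drops
# those samples instead of crashing the worker; callers must still skip batches in
# which every sample failed (my_collate then returns an empty list, not a dict):
#   loader = DataLoader(dataset, batch_size=32, collate_fn=my_collate)
#   for batch in loader:
#       if len(batch) == 0:
#           continue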

class AudiostockDataset(Dataset):
    '''
    Args:
        dataset_path (str): the dataset folder path
        tweet_prefix (bool): if True, prepend a retrieved neighbor caption to each
            caption (and double max_seq_len to make room for it)
        prefix_length (int): number of prefix embedding positions reserved at the
            start of the attention mask
        normalize (bool): if True, re-tokenize captions with muscaps_tokenize
        dupefile (str): pickle file of duplicate clusters; 'both' merges dupes.pkl
            and dupes_audio.pkl
        train (bool): if True, randomly return a 10-sec chunk from each audio file;
            if False, return the fixed middle 10-sec chunk
        split (str): a txt file listing the ids that belong to this split
            (training, validation, or testing)
        factor (float): how many times to loop over the whole dataset, used to
            increase the number of training batches per epoch
        whole_track (bool): if True, return the full length of the audio file;
            this forces batch_size = 1 and is usually only used for validation/testing
        dedup (bool): if True, keep only the representative file of each duplicate cluster
        file_list (list): optional pre-built file list that bypasses the directory scan

    A minimal usage sketch appears at the bottom of this file.
    '''
    def __init__(self, dataset_path, tweet_prefix=True, prefix_length=10, normalize=False, dupefile='dupes.pkl', train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True, file_list=[]):
        super().__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = prefix_length
        self.normalize = normalize
        self.id2neighbor = defaultdict(lambda: '')
        if dedup:
            if dupefile is not None and os.path.exists(dupefile):
                with open(dupefile, 'rb') as f:
                    self.is_rep = pickle.load(f).is_rep
            elif dupefile == 'both':
                # merge the text-based and audio-based duplicate tables
                with open('dupes.pkl', 'rb') as f:
                    dupes1 = pickle.load(f)
                with open('dupes_audio.pkl', 'rb') as f:
                    dupes2 = pickle.load(f)
                self.is_rep = defaultdict(lambda: True)
                for k, v in dupes1.is_rep.items():
                    self.is_rep[k] = v
                for k, v in dupes2.is_rep.items():
                    self.is_rep[k] = v
            else:
                sys.exit('Could not find duplicate file')
        subfolders = [f'audiostock-part-{i}' for i in range(1, 9)]
        self.label_path = os.path.join(dataset_path, 'audiostock-full-label')
        self.whole_track = whole_track
        self.file_list = file_list
        # select the files that belong to this split
        if self.file_list == []:
            temp_file_list = []
            for subfolder in subfolders:
                temp_file_list += [os.path.join(dataset_path, subfolder, f)
                                   for f in os.listdir(os.path.join(dataset_path, subfolder))
                                   if not dedup or self.is_rep[os.path.basename(f).split('.')[0]]]
            if split is not None:
                split = set(np.loadtxt(split, dtype=str))
                self.file_list = [f for f in temp_file_list if os.path.basename(f).split('.')[0] in split]
            else:
                self.file_list = temp_file_list
        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    def precompute_rand(self, candidate_set=None):
        # assign each id a random caption: drawn from this dataset itself when
        # training, or from the provided candidate set at test time
        self.id2neighbor = defaultdict(lambda: '')
        # if train
        if candidate_set is None:
            my_ids = []
            candidate_caps = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            for batch in temp_loader:
                my_ids += batch['id']
                candidate_caps += batch['short_text']
            for idx in my_ids:
                self.id2neighbor[idx] = random.choice(candidate_caps)
        # if test
        else:
            temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            candidate_caps = []
            for batch in temp_loader:
                candidate_caps += batch['short_text']
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            my_ids = []
            for batch in temp_loader:
                my_ids += batch['id']
            for idx in my_ids:
                self.id2neighbor[idx] = random.choice(candidate_caps)

    def precompute_gold(self):
        # use each item's own gold caption as its neighbor prefix
        self.id2neighbor = defaultdict(lambda: '')
        temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
        for batch in temp_loader:
            for idx, short_text in zip(batch['id'], batch['short_text']):
                self.id2neighbor[idx] = short_text

    def precompute_blank(self):
        # use a blank (newline) neighbor prefix for every id
        self.id2neighbor = defaultdict(lambda: '\n')

    def precompute_neighbors(self, model, candidate_set=None):
        print('Precomputing neighbors')
        self.id2neighbor = defaultdict(lambda: '')
        # if train and model given
        if candidate_set is None:
            # compute waveform embeddings for each song
            cand_features = None
            cand_ids = []
            cand_caps = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
            for batch in temp_loader:
                with torch.no_grad():
                    batch_features = model.embed_waveform(batch['waveform'].cuda())
                if cand_features is not None:
                    cand_features = torch.cat([cand_features, batch_features])
                else:
                    cand_features = batch_features
                cand_ids += batch['id']
                cand_caps += batch['short_text']
                progress.update()
            progress.close()
            my_features = cand_features
            my_ids = cand_ids
        # if test and model given
        else:
            # check if we already precomputed the embeddings
            pickle_filename = 'nn_features.pkl'
            if os.path.isfile(pickle_filename):
                with open(pickle_filename, 'rb') as f:
                    (cand_features, cand_ids, cand_caps) = pickle.load(f)
            else:
                # build the features from the provided set instead of self
                cand_features = None
                cand_ids = []
                cand_caps = []
                temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
                progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
                for batch in temp_loader:
                    with torch.no_grad():
                        batch_features = model.embed_waveform(batch['waveform'].cuda())
                    if cand_features is not None:
                        cand_features = torch.cat([cand_features, batch_features])
                    else:
                        cand_features = batch_features
                    cand_ids += batch['id']
                    #cand_caps += [' '.join(x.split()[:10]) for x in batch['short_text']]
                    cand_caps += batch['short_text']
                    progress.update()
                progress.close()
                # dump to pickle so we don't have to redo this each time
                with open(pickle_filename, 'wb') as f:
                    pickle.dump((cand_features, cand_ids, cand_caps), f)
            # load up my own ids and features
            my_features = None
            my_ids = []
            temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
            progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
            for batch in temp_loader:
                with torch.no_grad():
                    batch_features = model.embed_waveform(batch['waveform'].cuda())
                if my_features is not None:
                    my_features = torch.cat([my_features, batch_features])
                else:
                    my_features = batch_features
                my_ids += batch['id']
                progress.update()
            progress.close()
        is_self_sim = my_ids == cand_ids
        for idx, audio_id in tqdm(enumerate(my_ids), total=len(my_ids), dynamic_ncols=True):
            features = my_features[idx]
            similarities = features @ cand_features.T
            # remove identical matches
            if is_self_sim:
                similarities[idx] = float('-inf')
            best_idx = torch.argmax(similarities)
            most_similar_caption = cand_caps[best_idx]
            self.id2neighbor[my_ids[idx]] = most_similar_caption
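
    # Note: the nearest-neighbor search above scores candidates with a raw dot
    # product (features @ cand_features.T), which equals cosine similarity only
    # if model.embed_waveform returns L2-normalized embeddings (assumed here,
    # not enforced).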

    def pad_tokens(self, tokens, tokens_tweet):
        tweet_text_len = 0
        if self.tweet_prefix:
            # reserve up to half of the sequence for the retrieved neighbor caption
            tweet_text_len = tokens_tweet[:self.max_seq_len // 2].shape[0]
            tokens = torch.cat((tokens_tweet[:tweet_text_len], tokens))
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we are out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # prepend the prefix mask
        return tokens, mask, tweet_text_len
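
    # Shape sketch (assuming prefix_length=10, max_seq_len=300, a 12-token
    # neighbor caption, and a 40-token gold caption):
    #   tokens: neighbor (12) + caption (40) + zero padding (248) -> (300,)
    #   mask:   prefix ones (10) + sequence ones (52) + zeros (248) -> (310,)
    #   tweet_text_len = 12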

    def read_wav(self, filename):
        # pickling functionality removed since it shouldn't be necessary
        # chunk: 441000 frames = 10 sec at 44.1 kHz
        try:
            num_frames = torchaudio.info(filename).num_frames
        except Exception:
            return None
        # make sure it wasn't empty; if so, skip this file
        if num_frames == 0:
            return None
        sta = 0
        if not self.whole_track:
            if self.train:
                # random 10-sec chunk (guard against files shorter than 10 sec)
                sta = random.randint(0, max(0, num_frames - 441001))
            else:
                # fixed middle 10-sec chunk
                sta = max(0, (num_frames - 441001) // 2)
            num_frames = 441000
        y, sr = torchaudio.load(filename, frame_offset=sta, num_frames=num_frames)
        # resample
        y = torchaudio.functional.resample(y, sr, 48000)
        y = y[:, :441000]
        # downmix to mono
        y = y.mean(dim=0)
        # normalize: quantize to 16-bit resolution and back to float
        y = int16_to_float32(float32_to_int16(y))
        return y

    def __getitem__(self, index):
        idx = index % len(self.file_list)
        data_dict = {}
        f = self.file_list[idx]
        lf = os.path.join(self.label_path, os.path.basename(f).split('.')[0] + '.json')
        data_dict['waveform'] = self.read_wav(f)
        if os.path.isfile(lf):
            with open(lf, 'r') as label_file:
                label_data = json.load(label_file)
            data_dict['id'] = label_data['id']
            data_dict['short_text'] = label_data['short_text']
            if self.normalize:
                data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
            if 'long_text' in label_data and label_data['long_text'] is not None:
                data_dict['long_text'] = label_data['long_text']
            else:
                data_dict['long_text'] = ''
            # unused label fields:
            # data_dict['tag'] = label_data['tag']
            # data_dict['impression'] = label_data['impression']
            # data_dict['purpose'] = label_data['purpose']
        else:
            data_dict['id'] = os.path.basename(f).split('.')[0]
            data_dict['short_text'] = ''
            data_dict['long_text'] = ''
        # tokenize the caption
        caption_proc = preproc(data_dict['short_text'], self.tokenizer)
        tokens = torch.tensor(caption_proc, dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
        tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text
        # return None (filtered by my_collate) if any field failed to load
        if (data_dict['id'] is None or
                data_dict['short_text'] is None or
                data_dict['long_text'] is None or
                data_dict['tokens'] is None or
                data_dict['mask'] is None or
                data_dict['tweet_text_len'] is None or
                data_dict['tweet_text'] is None or
                data_dict['waveform'] is None):
            return None
        else:
            return data_dict

    def __len__(self):
        return self.total_len
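
    # Note: with factor > 1.0, __len__ exceeds the number of files and __getitem__
    # wraps around via index % len(self.file_list), so one epoch can visit each file
    # several times (drawing a fresh random 10-sec chunk per visit when train=True).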

class MusicCapsDataset(AudiostockDataset):
    def __init__(self, dataset_path, args, train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True):
        # deliberately skip AudiostockDataset.__init__ and call Dataset.__init__
        # directly: the file list is built from the MusicCaps csv, not directory scans
        super(AudiostockDataset, self).__init__()
        # set up parameters
        self.max_seq_len = 150
        self.tweet_prefix = args.tweet_prefix
        if self.tweet_prefix:
            self.max_seq_len *= 2
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = args.prefix_length
        self.normalize = args.normalize
        self.id2neighbor = defaultdict(lambda: '')  # filled in by the precompute_* methods
        self.whole_track = whole_track
        self.label_path = os.path.join(dataset_path, 'audio')
        self.file_list = []
        self.label_data = []
        label_reader = pd.read_csv(f'{dataset_path}/musiccaps-resplit.csv')
        for idx, row in label_reader.iterrows():
            if (row['is_audioset_eval'] == 1 and split == 'musiccaps_eval') \
                    or (row['is_audioset_eval'] == 0 and split == 'musiccaps_train') \
                    or (row['is_audioset_eval'] == 2 and split == 'musiccaps_dev'):
                data_dict = {}
                data_dict['id'] = row['ytid']
                self.file_list.append(f"{dataset_path}/audio/{data_dict['id']}.wav")
                data_dict['short_text'] = row['caption']
                if self.normalize:
                    data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
                data_dict['long_text'] = ''
                data_dict['tag'] = row['aspect_list']
                self.label_data.append(data_dict)
        self.train = train
        self.total_len = int(len(self.file_list) * factor)
        if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Batches per epoch: {self.total_len}')

    def __getitem__(self, index):
        idx = index % len(self.file_list)
        data_dict = {}
        f = self.file_list[idx]
        data_dict['waveform'] = self.read_wav(f)
        for k, v in self.label_data[idx].items():
            data_dict[k] = v
        # tokenize the caption
        caption_proc = preproc(data_dict['short_text'], self.tokenizer)
        tokens = torch.tensor(caption_proc, dtype=torch.int64)
        tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
        tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
        tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
        tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
        data_dict['tokens'] = tokens
        data_dict['mask'] = mask
        data_dict['tweet_text_len'] = tweet_text_len
        data_dict['tweet_text'] = tweet_text
        # return None (filtered by my_collate) if any field failed to load
        if (data_dict['id'] is None or
                data_dict['short_text'] is None or
                data_dict['long_text'] is None or
                data_dict['tokens'] is None or
                data_dict['mask'] is None or
                data_dict['tweet_text_len'] is None or
                data_dict['tweet_text'] is None or
                data_dict['waveform'] is None):
            return None
        else:
            return data_dict
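

# Minimal smoke-test sketch. The dataset path below is a hypothetical placeholder,
# and dedup=False sidesteps the dupes.pkl requirement; adjust to your local setup
# (including a cached GPT-2 tokenizer) before running.
if __name__ == '__main__':
    dataset = AudiostockDataset(
        dataset_path='/data/audiostock-full',  # hypothetical path
        split=None,                            # or a txt file of ids for one split
        train=True,
        dedup=False,
    )
    dataset.precompute_blank()  # neighbor prefix defaults to '\n'
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=my_collate)
    batch = next(iter(loader))
    if len(batch) > 0:
        print(batch['waveform'].shape, batch['tokens'].shape, batch['mask'].shape)

    # MusicCapsDataset expects an args namespace instead of keyword flags, e.g.:
    #   from argparse import Namespace
    #   mc_args = Namespace(tweet_prefix=True, prefix_length=10, normalize=False)
    #   mc_set = MusicCapsDataset('/data/musiccaps', mc_args, split='musiccaps_dev')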