'''
Ke Chen | [email protected] & Nikita Srivatsan | [email protected]
Load the mp3-format audio data from the audiostock-full dataset
'''
from collections import defaultdict
import json
import numpy as np
import os
import pandas as pd
import pickle
import random
import sys
import torch
import torch.utils.data
from torch.utils.data import DataLoader, Dataset
import torchaudio
from tqdm import tqdm
from transformers import GPT2Tokenizer
# lib and utils are expected to provide preproc and muscaps_tokenize
from lib import *
from utils import *
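
# Quantization helpers: a round trip through int16 discards precision beyond
# 16 bits, matching the precision of the original audio files.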
def int16_to_float32(x):
return (x / 32767.0).type(torch.float)
def float32_to_int16(x):
x = torch.clip(x, min=-1., max=1.)
return (x * 32767.).type(torch.int16)
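
# Collate function that drops samples which failed to load (returned None)
# before delegating to PyTorch's default collate.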
def my_collate(batch):
batch = [x for x in batch if x is not None]
if len(batch) == 0:
return batch
else:
return torch.utils.data.dataloader.default_collate(batch)
class AudiostockDataset(Dataset):
    '''
    Args:
        dataset_path (str): the dataset folder path
        tweet_prefix (bool): if True, prepend the neighbor caption's tokens to each caption
        prefix_length (int): number of prefix embeddings prepended by the model (reflected in the attention mask)
        normalize (bool): if True, normalize captions with muscaps_tokenize
        dupefile (str): pickle file with duplicate flags, or 'both' to merge dupes.pkl and dupes_audio.pkl
        train (bool): if True, randomly return a 10-sec chunk from each audio file; if False, return the fixed middle 10-sec chunk
        split (str): a txt file listing the ids belonging to this dataset split (for training, validation, and testing)
        factor (float): how many times to loop over the whole dataset; this increases the number of samples in each epoch
        whole_track (bool): if True, the dataset returns the full length of each audio file; this forces batch_size = 1 and is usually used for test/validation
    '''
    def __init__(self, dataset_path, tweet_prefix=True, prefix_length=10, normalize=False, dupefile='dupes.pkl', train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True, file_list=None):
super().__init__()
# set up parameters
self.max_seq_len = 150
self.tweet_prefix = tweet_prefix
if self.tweet_prefix:
self.max_seq_len *= 2
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
self.prefix_length = prefix_length
self.normalize = normalize
self.id2neighbor = defaultdict(lambda: '')
if dedup:
            if dupefile == 'both':
                # merge duplicate flags computed from metadata and from audio
                with open('dupes.pkl', 'rb') as f:
                    dupes1 = pickle.load(f)
                with open('dupes_audio.pkl', 'rb') as f:
                    dupes2 = pickle.load(f)
                self.is_rep = defaultdict(lambda: True)
                for k, v in dupes1.is_rep.items():
                    self.is_rep[k] = v
                for k, v in dupes2.is_rep.items():
                    self.is_rep[k] = v
            elif dupefile is not None and os.path.exists(dupefile):
                with open(dupefile, 'rb') as f:
                    self.is_rep = pickle.load(f).is_rep
            else:
                sys.exit('Could not find duplicate file')
subfolders = [f'audiostock-part-{i}' for i in range(1,9)]
self.label_path = os.path.join(dataset_path, 'audiostock-full-label')
self.whole_track = whole_track
        self.file_list = file_list if file_list is not None else []
        # select out the elements for this split
        if not self.file_list:
temp_file_list = []
            for subfolder in subfolders:
                folder = os.path.join(dataset_path, subfolder)
                temp_file_list += [
                    os.path.join(folder, f) for f in os.listdir(folder)
                    if not dedup or self.is_rep[f.split('.')[0]]
                ]
if split is not None:
                split = set(np.loadtxt(split, dtype=str))
self.file_list = [f for f in temp_file_list if os.path.basename(f).split('.')[0] in split]
else:
self.file_list = temp_file_list
self.train = train
self.total_len = int(len(self.file_list) * factor)
if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Samples per epoch: {self.total_len}')
def precompute_rand(self, candidate_set=None):
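        '''Assign each item a caption sampled uniformly at random as its neighbor:
        from this dataset itself when training, or from candidate_set when testing.'''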
self.id2neighbor = defaultdict(lambda: '')
# if train
if candidate_set is None:
my_ids = []
candidate_caps = []
temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
for batch in temp_loader:
my_ids += batch['id']
candidate_caps += batch['short_text']
for idx in my_ids:
self.id2neighbor[idx] = random.choice(candidate_caps)
# if test
else:
temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
candidate_caps = []
for batch in temp_loader:
candidate_caps += batch['short_text']
temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
my_ids = []
for batch in temp_loader:
my_ids += batch['id']
for idx in my_ids:
self.id2neighbor[idx] = random.choice(candidate_caps)
def precompute_gold(self):
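        '''Use each item's own gold caption as its neighbor.'''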
self.id2neighbor = defaultdict(lambda: '')
temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
for batch in temp_loader:
for idx,short_text in zip(batch['id'], batch['short_text']):
self.id2neighbor[idx] = short_text
def precompute_blank(self):
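        '''Use a blank line as every item's neighbor.'''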
self.id2neighbor = defaultdict(lambda: '\n')
def precompute_neighbors(self, model, candidate_set=None):
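        '''For each item, embed its waveform with the given model, find the most
        similar candidate by dot-product similarity, and store that candidate's
        caption as the item's neighbor.'''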
print('Precomputing neighbors')
self.id2neighbor = defaultdict(lambda: '')
# if train and model given
if candidate_set is None:
# compute waveform embeddings for each song
cand_features = None
cand_ids = []
cand_caps = []
temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
for batch in temp_loader:
with torch.no_grad():
batch_features = model.embed_waveform(batch['waveform'].cuda())
if cand_features is not None:
cand_features = torch.cat([cand_features, batch_features])
else:
cand_features = batch_features
cand_ids += batch['id']
cand_caps += batch['short_text']
progress.update()
progress.close()
my_features = cand_features
my_ids = cand_ids
# if test and model given
else:
# check if we already precomputed the embeddings
pickle_filename = 'nn_features.pkl'
if os.path.isfile(pickle_filename):
with open(pickle_filename, 'rb') as f:
(cand_features, cand_ids, cand_caps) = pickle.load(f)
else:
# build the features from the provided set instead of self
cand_features = None
cand_ids = []
cand_caps = []
temp_loader = DataLoader(candidate_set, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
for batch in temp_loader:
with torch.no_grad():
batch_features = model.embed_waveform(batch['waveform'].cuda())
if cand_features is not None:
cand_features = torch.cat([cand_features, batch_features])
else:
cand_features = batch_features
cand_ids += batch['id']
#cand_caps += [' '.join(x.split()[:10]) for x in batch['short_text']]
cand_caps += batch['short_text']
progress.update()
progress.close()
# dump to pickle so we don't have to redo this each time
with open(pickle_filename, 'wb') as f:
pickle.dump((cand_features, cand_ids, cand_caps), f)
# load up my own ids and features
my_features = None
my_ids = []
temp_loader = DataLoader(self, batch_size=32, shuffle=False, num_workers=32, drop_last=False, collate_fn=my_collate)
progress = tqdm(total=len(temp_loader), dynamic_ncols=True)
for batch in temp_loader:
with torch.no_grad():
batch_features = model.embed_waveform(batch['waveform'].cuda())
if my_features is not None:
my_features = torch.cat([my_features, batch_features])
else:
my_features = batch_features
my_ids += batch['id']
progress.update()
progress.close()
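        # if the candidate pool is this dataset itself, mask out self-matches below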
is_self_sim = my_ids == cand_ids
        for idx, audio_id in tqdm(enumerate(my_ids), total=len(my_ids), dynamic_ncols=True):
            features = my_features[idx]
            similarities = features @ cand_features.T
            # remove identical matches
            if is_self_sim:
                similarities[idx] = float('-inf')
            best_idx = torch.argmax(similarities)
            most_similar_caption = cand_caps[best_idx]
            self.id2neighbor[audio_id] = most_similar_caption
def pad_tokens(self, tokens, tokens_tweet):
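        '''Prepend the (truncated) neighbor tokens, pad or truncate to max_seq_len,
        and build the attention mask, including the all-ones audio prefix.'''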
tweet_text_len = 0
if self.tweet_prefix:
tweet_text_len = tokens_tweet[:self.max_seq_len // 2].shape[0]
tokens = torch.cat((tokens_tweet[:tweet_text_len], tokens))
padding = self.max_seq_len - tokens.shape[0]
if padding > 0:
tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
elif padding < 0:
tokens = tokens[:self.max_seq_len]
        mask = tokens.ge(0)  # mask is zero where we are out of sequence
tokens[~mask] = 0
mask = mask.float()
mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask
return tokens, mask, tweet_text_len
def read_wav(self, filename):
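        '''Load a 10-sec chunk (or the whole track if self.whole_track) from
        filename, resample to 48 kHz, downmix to mono, and quantize to 16-bit
        precision. Returns None if the file is empty or unreadable.'''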
        # pickle caching of decoded audio was removed since it shouldn't be necessary
        # pick the chunk of the file to read
        try:
            num_frames = torchaudio.info(filename).num_frames
        except Exception:
            return None
        # make sure it wasn't empty; if so, skip this file
        if num_frames == 0:
            return None
        sta = 0
        if not self.whole_track:
            # 441000 frames = 10 sec at 44.1 kHz; clamp so shorter files still work
            if self.train:
                sta = random.randint(0, max(0, num_frames - 441001))
            else:
                sta = max(0, (num_frames - 441001) // 2)
            num_frames = 441000
y, sr = torchaudio.load(filename, frame_offset=sta, num_frames=num_frames)
# resample
y = torchaudio.functional.resample(y, sr, 48000)
y = y[:, :441000]
# mono
y = y.mean(dim=0)
        # emulate 16-bit precision via a round trip through int16
        y = int16_to_float32(float32_to_int16(y))
return y
def __getitem__(self, index):
idx = index % len(self.file_list)
data_dict = {}
f = self.file_list[idx]
lf = os.path.join(self.label_path, os.path.basename(f).split('.')[0] + '.json')
data_dict['waveform'] = self.read_wav(f)
if os.path.isfile(lf):
with open(lf,'r') as label_file:
label_data = json.load(label_file)
data_dict['id'] = label_data['id']
data_dict['short_text'] = label_data['short_text']
if self.normalize:
data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
if 'long_text' in label_data and label_data['long_text'] is not None:
data_dict['long_text'] = label_data['long_text']
else:
data_dict['long_text'] = ''
'''
data_dict['tag'] = label_data['tag']
data_dict['impression'] = label_data['impression']
data_dict['purpose'] = label_data['purpose']
'''
else:
data_dict['id'] = os.path.basename(f).split('.')[0]
data_dict['short_text'] = ''
data_dict['long_text'] = ''
# tokenize the caption
caption_proc = preproc(data_dict['short_text'], self.tokenizer)
tokens = torch.tensor(caption_proc, dtype=torch.int64)
tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
data_dict['tokens'] = tokens
data_dict['mask'] = mask
data_dict['tweet_text_len'] = tweet_text_len
data_dict['tweet_text'] = tweet_text
if (data_dict['id'] is None or
data_dict['short_text'] is None or
data_dict['long_text'] is None or
data_dict['tokens'] is None or
data_dict['mask'] is None or
data_dict['tweet_text_len'] is None or
data_dict['tweet_text'] is None or
data_dict['waveform'] is None
):
return None
else:
return data_dict
def __len__(self):
return self.total_len
class MusicCapsDataset(AudiostockDataset):
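    '''Dataset for MusicCaps; reuses AudiostockDataset's audio loading, token
    padding, and neighbor machinery, but reads labels from a CSV resplit.'''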
    def __init__(self, dataset_path, args, train=True, split=None, factor=1.0, whole_track=False, verbose=True, dedup=True):
        # deliberately skip AudiostockDataset.__init__ (no audiostock folder scan)
        # and call Dataset.__init__ directly
        super(AudiostockDataset, self).__init__()
# set up parameters
self.max_seq_len = 150
self.tweet_prefix = args.tweet_prefix
if self.tweet_prefix:
self.max_seq_len *= 2
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2', local_files_only=True)
        self.prefix_length = args.prefix_length
        self.normalize = args.normalize
        # initialize the neighbor map so __getitem__ works even before any precompute_* call
        self.id2neighbor = defaultdict(lambda: '')
self.whole_track = whole_track
self.label_path = os.path.join(dataset_path, 'audio')
self.file_list = []
self.label_data = []
label_reader = pd.read_csv(f'{dataset_path}/musiccaps-resplit.csv')
for idx,row in label_reader.iterrows():
if (row['is_audioset_eval'] == 1 and split == 'musiccaps_eval') \
or (row['is_audioset_eval'] == 0 and split == 'musiccaps_train') \
or (row['is_audioset_eval'] == 2 and split == 'musiccaps_dev'):
data_dict = {}
data_dict['id'] = row['ytid']
self.file_list.append(f"{dataset_path}/audio/{data_dict['id']}.wav")
data_dict['short_text'] = row['caption']
if self.normalize:
data_dict['short_text'] = ' '.join(muscaps_tokenize(data_dict['short_text']))
data_dict['long_text'] = ''
data_dict['tag'] = row['aspect_list']
self.label_data.append(data_dict)
self.train = train
self.total_len = int(len(self.file_list) * factor)
if verbose:
            print(f'Dataset Loaded | File Num.: {len(self.file_list)} | Samples per epoch: {self.total_len}')
def __getitem__(self, index):
idx = index % len(self.file_list)
data_dict = {}
f = self.file_list[idx]
data_dict['waveform'] = self.read_wav(f)
for k,v in self.label_data[idx].items():
data_dict[k] = v
# tokenize the caption
caption_proc = preproc(data_dict['short_text'], self.tokenizer)
tokens = torch.tensor(caption_proc, dtype=torch.int64)
tweet_text = self.id2neighbor[data_dict['id']] if self.tweet_prefix else ''
tweet_proc = preproc(tweet_text, self.tokenizer, stop=False)
tokens_tweet = torch.tensor(tweet_proc, dtype=torch.int64)
tokens, mask, tweet_text_len = self.pad_tokens(tokens, tokens_tweet)
data_dict['tokens'] = tokens
data_dict['mask'] = mask
data_dict['tweet_text_len'] = tweet_text_len
data_dict['tweet_text'] = tweet_text
if (data_dict['id'] is None or
data_dict['short_text'] is None or
data_dict['long_text'] is None or
data_dict['tokens'] is None or
data_dict['mask'] is None or
data_dict['tweet_text_len'] is None or
data_dict['tweet_text'] is None or
data_dict['waveform'] is None
):
return None
else:
return data_dict
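
if __name__ == '__main__':
    # Minimal smoke-test sketch. The dataset path and flags below are assumptions
    # for illustration, not part of the original training pipeline.
    dataset = AudiostockDataset(
        dataset_path='/path/to/audiostock',  # hypothetical location
        train=True,
        dedup=False,  # skip the dupes.pkl lookup for a quick check
    )
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=4, shuffle=True, collate_fn=my_collate)
    for batch in loader:
        # my_collate returns an empty list if every sample in the batch failed to load
        if len(batch) > 0:
            print(batch['id'], batch['waveform'].shape, batch['tokens'].shape)
        break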