# Fine Transformer

### Libraries:

In [1]:
import torch
from audiolm_pytorch import HubertWithKmeans
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer
from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream
from musiclm_pytorch import MuLaNEmbedQuantizer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer
import gc
from nltk.tokenize import word_tokenize
import nltk
import librosa
import numpy as np
import pickle

In [2]:
# Fetch the NLTK 'punkt' tokenizer data; NLTK logs "already up-to-date"
# when the package has been downloaded before.
nltk.download('punkt')

# --- Training configuration (read by train_fine_transformer) ---------------
audio_output_dir = './audio'      # passed to the trainer as its `folder`
batch_size = 1
num_train_steps = 1000
data_max_length = 320 * 32        # = 10240; presumably a length in samples — TODO confirm

# --- HuBERT assets ---------------------------------------------------------
# NOTE(review): not referenced by any of the visible cells; presumably used
# by a semantic-stage notebook or a later cell — confirm before removing.
checkpoint_path = './models/hubert/hubert_base_ls960.pt'
kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'

[nltk_data] Downloading package punkt to
[nltk_data] C:\Users\hp\AppData\Roaming\nltk_data...
[nltk_data] Package punkt is already up-to-date!


In [3]:
# Audio encoder handed to MuLaN as its `audio_transformer`.
audio_transformer = AudioSpectrogramTransformer(
    dim=512,
    depth=6,
    heads=8,
    dim_head=64,
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
)

# Text encoder handed to MuLaN; same width / depth / head settings as the
# audio encoder above.
text_transformer = TextTransformer(dim=512, depth=6, heads=8, dim_head=64)

# Joint audio-text model built from the two encoders.
# NOTE(review): no pretrained weights are loaded in the visible cells, so this
# MuLaN appears to be freshly initialised — confirm that is intended.
mulan = MuLaN(
    audio_transformer=audio_transformer,
    text_transformer=text_transformer,
)

# Conditioner later passed to the trainer as `audio_conditioner`.
# One conditioning dim is listed per namespace; the fine transformer trained
# below uses dim=1024, which matches the 'fine' entry.
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    conditioning_dims=(1024, 1024, 1024),
    namespaces=('semantic', 'coarse', 'fine'),
)


def train_fine_transformer(save_path='fine_transformer.pth'):
    """Train the fine-stage transformer and save its weights to disk.

    Builds a ``MusicLMSoundStream`` codec and a ``FineTransformer``, trains
    them with ``FineTransformerTrainer`` and writes the transformer's
    ``state_dict`` to ``save_path``.

    Reads these module-level names: ``audio_output_dir``, ``batch_size``,
    ``data_max_length``, ``num_train_steps`` and ``quantizer``.

    Args:
        save_path: Destination file for the trained weights. Defaults to
            ``'fine_transformer.pth'`` in the current working directory.

    Returns:
        None. The objects created here are released before returning.
    """
    soundstream = MusicLMSoundStream(
        codebook_size=1024,
        strides=(3, 4, 5, 8),
        target_sample_hz=24000,
        rq_num_quantizers=8,
    )

    # Build the model once, then move it to the GPU only when one exists
    # (previously the whole constructor call was duplicated in both branches).
    # 4 coarse + 4 fine quantizers add up to the codec's rq_num_quantizers=8;
    # presumably they must — TODO confirm against the library.
    fine_transformer = FineTransformer(
        num_coarse_quantizers=4,
        num_fine_quantizers=4,
        codebook_size=1024,
        dim=1024,
        depth=6,
        audio_text_condition=True,
    )
    if torch.cuda.is_available():
        fine_transformer = fine_transformer.cuda()

    trainer = FineTransformerTrainer(
        transformer=fine_transformer,
        codec=soundstream,
        folder=audio_output_dir,
        batch_size=batch_size,
        data_max_length=data_max_length,
        num_train_steps=num_train_steps,
        audio_conditioner=quantizer,
    )

    trainer.train()
    torch.save(fine_transformer.state_dict(), save_path)
    print(f"save {save_path}")

    # Drop the local references and collect, then hand cached GPU blocks back
    # so later cells in the same kernel can allocate them.
    del fine_transformer, trainer, soundstream
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


train_fine_transformer()

training with dataset of 4806 samples and validating with randomly splitted 253 samples
spectrogram yielded shape of (65, 841), but had to be cropped to (64, 832) to be patchified for transformer
0: loss: 103.04938507080078
0: valid loss 11.681041717529297
0: saving model to results
training complete
save fine_transformer.pth
