# Coarse Transformer

### Libraries:

In [1]:
import torch
from audiolm_pytorch import HubertWithKmeans
from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer
from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer
from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream
import gc
from musiclm_pytorch import MuLaNEmbedQuantizer
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

In [2]:
# --- HuBERT checkpoint and k-means files (consumed by HubertWithKmeans below) ---
kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'
checkpoint_path = './models/hubert/hubert_base_ls960.pt'

# --- Trainer settings ---
# Folder handed to the trainer as `folder=` (the training audio location).
audio_output_dir = './audio'
num_train_steps = 1000
batch_size = 1
# 10240 samples per clip.
# NOTE(review): 320 presumably matches the wav2vec/HuBERT downsampling factor — confirm.
data_max_length = 32 * 320

In [3]:
# Hyperparameters that the audio and text encoders have in common.
shared_encoder_kwargs = dict(dim=512, depth=6, heads=8, dim_head=64)

# Audio tower: the shared settings plus spectrogram-specific options.
audio_transformer = AudioSpectrogramTransformer(
    spec_n_fft=128,
    spec_win_length=24,
    spec_aug_stretch_factor=0.8,
    **shared_encoder_kwargs,
)

# Text tower: only the shared settings.
text_transformer = TextTransformer(**shared_encoder_kwargs)

# NOTE(review): no weights are loaded into `mulan` in this cell, so it is used
# as constructed here — confirm that is intended before relying on its embeddings.
mulan = MuLaN(audio_transformer=audio_transformer, text_transformer=text_transformer)

# Three conditioning dims and three namespaces (presumably paired one-to-one,
# in order — verify against the MuLaNEmbedQuantizer documentation).
quantizer = MuLaNEmbedQuantizer(
    mulan=mulan,
    conditioning_dims=(1024, 1024, 1024),
    namespaces=('semantic', 'coarse', 'fine'),
)

# Smoke test: push a random (2, 1024) tensor through the quantizer under the
# 'semantic' namespace. `conds` is not used again in the visible cells.
wavs = torch.randn(2, 1024)
conds = quantizer(wavs=wavs, namespace='semantic')

spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer


In [4]:
def train_coarse_transformer(save_path='coarse_transformer.pth'):
    """Train the coarse-stage transformer and save its weights.

    Builds a ``HubertWithKmeans`` tokenizer, a ``MusicLMSoundStream`` codec and a
    ``CoarseTransformer``, runs ``CoarseTransformerTrainer.train()``, then writes
    the transformer's ``state_dict`` to ``save_path``.

    Reads these notebook-level globals: ``checkpoint_path``, ``kmeans_path``,
    ``quantizer``, ``audio_output_dir``, ``batch_size``, ``data_max_length``
    and ``num_train_steps``.

    Args:
        save_path: Destination file for the trained transformer's state dict.
            Defaults to ``'coarse_transformer.pth'`` (the original hard-coded name).

    Returns:
        None. The weights are saved only if training finishes without raising.
    """
    # Single source of truth: the codec and the transformer must agree on this.
    codebook_size = 1024

    wav2vec = HubertWithKmeans(
        checkpoint_path=checkpoint_path,
        kmeans_path=kmeans_path
    )

    # NOTE(review): the codec is constructed here and no pretrained weights are
    # loaded in this function — confirm that is intended.
    soundstream = MusicLMSoundStream(
        codebook_size=codebook_size,
        strides=(3, 4, 5, 8),
        target_sample_hz=24000,
        rq_num_quantizers=8
    )

    # Build the transformer once and move it to whichever device is available,
    # instead of duplicating the whole constructor in a CUDA and a CPU branch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    coarse_transformer = CoarseTransformer(
        num_semantic_tokens=wav2vec.codebook_size,
        codebook_size=codebook_size,
        num_coarse_quantizers=4,
        dim=1024,
        depth=6,
        audio_text_condition=True
    ).to(device)

    trainer = CoarseTransformerTrainer(
        transformer=coarse_transformer,
        codec=soundstream,
        wav2vec=wav2vec,
        audio_conditioner=quantizer,
        folder=audio_output_dir,
        batch_size=batch_size,
        data_max_length=data_max_length,
        num_train_steps=num_train_steps
    )

    try:
        trainer.train()
        torch.save(coarse_transformer.state_dict(), save_path)
        print(f"save {save_path}")
    finally:
        # Drop this frame's references and collect even when training is
        # interrupted (e.g. KeyboardInterrupt); previously the cleanup was
        # skipped on any exception. All four names are bound before the `try`,
        # so the `del` cannot itself fail.
        del coarse_transformer, trainer, wav2vec, soundstream
        gc.collect()

# Run coarse-stage training now; on normal completion the weights are saved to 'coarse_transformer.pth'.
train_coarse_transformer()

ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
ANTLR runtime and generated code versions disagree: 4.9.3!=4.8
training with dataset of 4806 samples and validating with randomly splitted 253 samples
0: loss: 90.55248260498047
0: valid loss 28.765926361083984
0: saving model to results
1: loss: 39.71841812133789
2: loss: 89.22168731689453
3: loss: 64.72769927978516
4: loss: 46.61131286621094
5: loss: 71.61656951904297
6: loss: 51.03081130981445
7: loss: 41.790443420410156
8: loss: 53.92983627319336
9: loss: 34.468536376953125
10: loss: 33.230533599853516
11: loss: 39.82740020751953
12: loss: 25.284324645996094
13: loss: 28.97213363647461
14: loss: 30.330350875854492
15: loss: 29.048341751098633
16: loss: 22.92132568359375
17: loss: 19.784038543701172
18: loss: 24.917173385620117
19: loss: 21.861900329589844
20: loss: 21.64893913269043
21: loss: 19.426795959472656
22: loss: 16.47875213623047
23: loss: 14.150989532470703
24: loss: 16.4312686920166
25: loss: 10.7322006225585

KeyboardInterrupt: 