Spaces:
Running
Running
SunilGopal
commited on
Commit
•
1e37f65
1
Parent(s):
9657edf
Upload 4 files
Browse files- coarse_transformer.ipynb +632 -0
- fine_transformer.ipynb +183 -0
- musiclm.ipynb +304 -0
- semantic_transformer.ipynb +851 -0
coarse_transformer.ipynb
ADDED
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Coarse Transformer"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"### Libraries:"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 1,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"import torch\n",
|
24 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
25 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
26 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
27 |
+
"from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
|
28 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
|
29 |
+
"import gc\n",
|
30 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
31 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 2,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
41 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
42 |
+
"\n",
|
43 |
+
"audio_output_dir = './audio'\n",
|
44 |
+
"batch_size = 1\n",
|
45 |
+
"data_max_length = 320 * 32\n",
|
46 |
+
"num_train_steps = 1000"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 3,
|
52 |
+
"metadata": {},
|
53 |
+
"outputs": [
|
54 |
+
{
|
55 |
+
"name": "stdout",
|
56 |
+
"output_type": "stream",
|
57 |
+
"text": [
|
58 |
+
"spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n"
|
59 |
+
]
|
60 |
+
}
|
61 |
+
],
|
62 |
+
"source": [
|
63 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
64 |
+
" dim = 512,\n",
|
65 |
+
" depth = 6,\n",
|
66 |
+
" heads = 8,\n",
|
67 |
+
" dim_head = 64,\n",
|
68 |
+
" spec_n_fft = 128,\n",
|
69 |
+
" spec_win_length = 24,\n",
|
70 |
+
" spec_aug_stretch_factor = 0.8\n",
|
71 |
+
")\n",
|
72 |
+
"\n",
|
73 |
+
"text_transformer = TextTransformer(\n",
|
74 |
+
" dim = 512,\n",
|
75 |
+
" depth = 6,\n",
|
76 |
+
" heads = 8,\n",
|
77 |
+
" dim_head = 64\n",
|
78 |
+
")\n",
|
79 |
+
"\n",
|
80 |
+
"mulan = MuLaN(\n",
|
81 |
+
" audio_transformer = audio_transformer,\n",
|
82 |
+
" text_transformer = text_transformer\n",
|
83 |
+
")\n",
|
84 |
+
"\n",
|
85 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
86 |
+
" mulan = mulan, \n",
|
87 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
88 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
89 |
+
")\n",
|
90 |
+
"wavs = torch.randn(2, 1024)\n",
|
91 |
+
"conds = quantizer(wavs = wavs, namespace = 'semantic')"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": 4,
|
97 |
+
"metadata": {},
|
98 |
+
"outputs": [
|
99 |
+
{
|
100 |
+
"name": "stdout",
|
101 |
+
"output_type": "stream",
|
102 |
+
"text": [
|
103 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
104 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
105 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
106 |
+
"0: loss: 90.55248260498047\n",
|
107 |
+
"0: valid loss 28.765926361083984\n",
|
108 |
+
"0: saving model to results\n",
|
109 |
+
"1: loss: 39.71841812133789\n",
|
110 |
+
"2: loss: 89.22168731689453\n",
|
111 |
+
"3: loss: 64.72769927978516\n",
|
112 |
+
"4: loss: 46.61131286621094\n",
|
113 |
+
"5: loss: 71.61656951904297\n",
|
114 |
+
"6: loss: 51.03081130981445\n",
|
115 |
+
"7: loss: 41.790443420410156\n",
|
116 |
+
"8: loss: 53.92983627319336\n",
|
117 |
+
"9: loss: 34.468536376953125\n",
|
118 |
+
"10: loss: 33.230533599853516\n",
|
119 |
+
"11: loss: 39.82740020751953\n",
|
120 |
+
"12: loss: 25.284324645996094\n",
|
121 |
+
"13: loss: 28.97213363647461\n",
|
122 |
+
"14: loss: 30.330350875854492\n",
|
123 |
+
"15: loss: 29.048341751098633\n",
|
124 |
+
"16: loss: 22.92132568359375\n",
|
125 |
+
"17: loss: 19.784038543701172\n",
|
126 |
+
"18: loss: 24.917173385620117\n",
|
127 |
+
"19: loss: 21.861900329589844\n",
|
128 |
+
"20: loss: 21.64893913269043\n",
|
129 |
+
"21: loss: 19.426795959472656\n",
|
130 |
+
"22: loss: 16.47875213623047\n",
|
131 |
+
"23: loss: 14.150989532470703\n",
|
132 |
+
"24: loss: 16.4312686920166\n",
|
133 |
+
"25: loss: 10.732200622558594\n",
|
134 |
+
"26: loss: 9.64625358581543\n",
|
135 |
+
"27: loss: 13.40906047821045\n",
|
136 |
+
"28: loss: 8.942117691040039\n",
|
137 |
+
"29: loss: 14.944022178649902\n",
|
138 |
+
"30: loss: 17.149667739868164\n",
|
139 |
+
"31: loss: 8.965814590454102\n",
|
140 |
+
"32: loss: 10.492903709411621\n",
|
141 |
+
"33: loss: 11.236382484436035\n",
|
142 |
+
"34: loss: 10.356119155883789\n",
|
143 |
+
"35: loss: 9.816141128540039\n",
|
144 |
+
"36: loss: 11.789191246032715\n",
|
145 |
+
"37: loss: 10.450325012207031\n",
|
146 |
+
"38: loss: 18.911396026611328\n",
|
147 |
+
"39: loss: 8.278931617736816\n",
|
148 |
+
"40: loss: 10.884782791137695\n",
|
149 |
+
"41: loss: 8.885784149169922\n",
|
150 |
+
"42: loss: 9.226049423217773\n",
|
151 |
+
"43: loss: 10.362125396728516\n",
|
152 |
+
"44: loss: 4.0845770835876465\n",
|
153 |
+
"45: loss: 9.664544105529785\n",
|
154 |
+
"46: loss: 9.46312427520752\n",
|
155 |
+
"47: loss: 9.138323783874512\n",
|
156 |
+
"48: loss: 7.396448135375977\n",
|
157 |
+
"49: loss: 7.293612480163574\n",
|
158 |
+
"50: loss: 10.331693649291992\n",
|
159 |
+
"51: loss: 7.775559425354004\n",
|
160 |
+
"52: loss: 7.011277198791504\n",
|
161 |
+
"53: loss: 6.324047565460205\n",
|
162 |
+
"54: loss: 5.501199245452881\n",
|
163 |
+
"55: loss: 4.69442081451416\n",
|
164 |
+
"56: loss: 4.073971748352051\n",
|
165 |
+
"57: loss: 4.142904758453369\n",
|
166 |
+
"58: loss: 4.585968017578125\n",
|
167 |
+
"59: loss: 4.700481414794922\n",
|
168 |
+
"60: loss: 5.152374267578125\n",
|
169 |
+
"61: loss: 8.181085586547852\n",
|
170 |
+
"62: loss: 6.7371416091918945\n",
|
171 |
+
"63: loss: 10.67423152923584\n",
|
172 |
+
"64: loss: 5.926950454711914\n",
|
173 |
+
"65: loss: 5.470860004425049\n",
|
174 |
+
"66: loss: 4.630016803741455\n",
|
175 |
+
"67: loss: 5.366561412811279\n",
|
176 |
+
"68: loss: 11.271105766296387\n",
|
177 |
+
"69: loss: 6.516841411590576\n",
|
178 |
+
"70: loss: 7.9438066482543945\n",
|
179 |
+
"71: loss: 5.358776092529297\n",
|
180 |
+
"72: loss: 5.713461875915527\n",
|
181 |
+
"73: loss: 7.075550556182861\n",
|
182 |
+
"74: loss: 5.229584217071533\n",
|
183 |
+
"75: loss: 5.103419303894043\n",
|
184 |
+
"76: loss: 4.516308307647705\n",
|
185 |
+
"77: loss: 7.4682488441467285\n",
|
186 |
+
"78: loss: 7.275866508483887\n",
|
187 |
+
"79: loss: 5.846785545349121\n",
|
188 |
+
"80: loss: 5.688624382019043\n",
|
189 |
+
"81: loss: 5.150119781494141\n",
|
190 |
+
"82: loss: 4.671944618225098\n",
|
191 |
+
"83: loss: 8.293455123901367\n",
|
192 |
+
"84: loss: 7.202897071838379\n",
|
193 |
+
"85: loss: 4.38778018951416\n",
|
194 |
+
"86: loss: 4.410329818725586\n",
|
195 |
+
"87: loss: 4.341781139373779\n",
|
196 |
+
"88: loss: 4.000961780548096\n",
|
197 |
+
"89: loss: 4.009156703948975\n",
|
198 |
+
"90: loss: 3.562082052230835\n",
|
199 |
+
"91: loss: 3.641108989715576\n",
|
200 |
+
"92: loss: 5.916473388671875\n",
|
201 |
+
"93: loss: 4.046755790710449\n",
|
202 |
+
"94: loss: 6.699942111968994\n",
|
203 |
+
"95: loss: 6.139719009399414\n",
|
204 |
+
"96: loss: 10.71791934967041\n",
|
205 |
+
"97: loss: 4.094853401184082\n",
|
206 |
+
"98: loss: 6.08973503112793\n",
|
207 |
+
"99: loss: 9.11803150177002\n",
|
208 |
+
"100: loss: 8.486052513122559\n",
|
209 |
+
"100: valid loss 4.0021281242370605\n",
|
210 |
+
"101: loss: 4.0021281242370605\n",
|
211 |
+
"102: loss: 3.346961736679077\n",
|
212 |
+
"103: loss: 3.15854549407959\n",
|
213 |
+
"104: loss: 2.5357956886291504\n",
|
214 |
+
"105: loss: 5.492861270904541\n",
|
215 |
+
"106: loss: 2.7623958587646484\n",
|
216 |
+
"107: loss: 2.9482226371765137\n",
|
217 |
+
"108: loss: 6.3801493644714355\n",
|
218 |
+
"109: loss: 4.1293463706970215\n",
|
219 |
+
"110: loss: 3.566096067428589\n",
|
220 |
+
"111: loss: 3.569946527481079\n",
|
221 |
+
"112: loss: 3.762925624847412\n",
|
222 |
+
"113: loss: 6.147146701812744\n",
|
223 |
+
"114: loss: 5.933719635009766\n",
|
224 |
+
"115: loss: 6.800720691680908\n",
|
225 |
+
"116: loss: 2.86614990234375\n",
|
226 |
+
"117: loss: 3.0812878608703613\n",
|
227 |
+
"118: loss: 3.110222101211548\n",
|
228 |
+
"119: loss: 4.000320911407471\n",
|
229 |
+
"120: loss: 3.2422871589660645\n",
|
230 |
+
"121: loss: 3.7775020599365234\n",
|
231 |
+
"122: loss: 3.595900774002075\n",
|
232 |
+
"123: loss: 2.73819637298584\n",
|
233 |
+
"124: loss: 3.4981672763824463\n",
|
234 |
+
"125: loss: 5.3726325035095215\n",
|
235 |
+
"126: loss: 3.0014798641204834\n",
|
236 |
+
"127: loss: 3.5963802337646484\n",
|
237 |
+
"128: loss: 2.8306686878204346\n",
|
238 |
+
"129: loss: 2.5162878036499023\n",
|
239 |
+
"130: loss: 2.685560941696167\n",
|
240 |
+
"131: loss: 6.374442100524902\n",
|
241 |
+
"132: loss: 7.788975715637207\n",
|
242 |
+
"133: loss: 2.897576332092285\n",
|
243 |
+
"134: loss: 3.333127737045288\n",
|
244 |
+
"135: loss: 3.436774253845215\n",
|
245 |
+
"136: loss: 4.979071617126465\n",
|
246 |
+
"137: loss: 4.120012283325195\n",
|
247 |
+
"138: loss: 3.7855355739593506\n",
|
248 |
+
"139: loss: 4.324587345123291\n",
|
249 |
+
"140: loss: 3.4336843490600586\n",
|
250 |
+
"141: loss: 2.6801435947418213\n",
|
251 |
+
"142: loss: 3.359581470489502\n",
|
252 |
+
"143: loss: 5.4692182540893555\n",
|
253 |
+
"144: loss: 5.773078918457031\n",
|
254 |
+
"145: loss: 4.27987813949585\n",
|
255 |
+
"146: loss: 7.247451305389404\n",
|
256 |
+
"147: loss: 6.170166492462158\n",
|
257 |
+
"148: loss: 4.961609840393066\n",
|
258 |
+
"149: loss: 4.028770923614502\n",
|
259 |
+
"150: loss: 2.90120005607605\n",
|
260 |
+
"151: loss: 1.9893661737442017\n",
|
261 |
+
"152: loss: 1.652574062347412\n",
|
262 |
+
"153: loss: 2.374600887298584\n",
|
263 |
+
"154: loss: 2.1045265197753906\n",
|
264 |
+
"155: loss: 6.417508125305176\n",
|
265 |
+
"156: loss: 5.273669719696045\n",
|
266 |
+
"157: loss: 6.238985538482666\n",
|
267 |
+
"158: loss: 3.8025736808776855\n",
|
268 |
+
"159: loss: 6.6854705810546875\n",
|
269 |
+
"160: loss: 2.5476467609405518\n",
|
270 |
+
"161: loss: 6.810393810272217\n",
|
271 |
+
"162: loss: 2.2033159732818604\n",
|
272 |
+
"163: loss: 1.9863100051879883\n",
|
273 |
+
"164: loss: 4.976431369781494\n",
|
274 |
+
"165: loss: 3.899188756942749\n",
|
275 |
+
"166: loss: 4.68454647064209\n",
|
276 |
+
"167: loss: 2.4539690017700195\n",
|
277 |
+
"168: loss: 6.830282688140869\n",
|
278 |
+
"169: loss: 1.7942843437194824\n",
|
279 |
+
"170: loss: 1.242318868637085\n",
|
280 |
+
"171: loss: 5.012855052947998\n",
|
281 |
+
"172: loss: 1.6154134273529053\n",
|
282 |
+
"173: loss: 1.5895756483078003\n",
|
283 |
+
"174: loss: 5.240614891052246\n",
|
284 |
+
"175: loss: 1.8958660364151\n",
|
285 |
+
"176: loss: 2.1411402225494385\n",
|
286 |
+
"177: loss: 5.932228088378906\n",
|
287 |
+
"178: loss: 2.7539122104644775\n",
|
288 |
+
"179: loss: 6.218499660491943\n",
|
289 |
+
"180: loss: 2.991704225540161\n",
|
290 |
+
"181: loss: 3.378645896911621\n",
|
291 |
+
"182: loss: 2.719741106033325\n",
|
292 |
+
"183: loss: 2.5844321250915527\n",
|
293 |
+
"184: loss: 5.851257801055908\n",
|
294 |
+
"185: loss: 2.239989995956421\n",
|
295 |
+
"186: loss: 5.5589141845703125\n",
|
296 |
+
"187: loss: 3.11521053314209\n",
|
297 |
+
"188: loss: 2.5269265174865723\n",
|
298 |
+
"189: loss: 2.181260824203491\n",
|
299 |
+
"190: loss: 1.8941911458969116\n",
|
300 |
+
"191: loss: 5.106175422668457\n",
|
301 |
+
"192: loss: 3.5514838695526123\n",
|
302 |
+
"193: loss: 3.233003854751587\n",
|
303 |
+
"194: loss: 2.55694317817688\n",
|
304 |
+
"195: loss: 6.5134053230285645\n",
|
305 |
+
"196: loss: 6.311967372894287\n",
|
306 |
+
"197: loss: 2.3541362285614014\n",
|
307 |
+
"198: loss: 6.195401668548584\n",
|
308 |
+
"199: loss: 3.013007879257202\n",
|
309 |
+
"200: loss: 2.53104567527771\n",
|
310 |
+
"200: valid loss 1.895339846611023\n",
|
311 |
+
"201: loss: 7.572109699249268\n",
|
312 |
+
"202: loss: 1.946860909461975\n",
|
313 |
+
"203: loss: 1.6077873706817627\n",
|
314 |
+
"204: loss: 1.5050052404403687\n",
|
315 |
+
"205: loss: 1.1216596364974976\n",
|
316 |
+
"206: loss: 1.017206072807312\n",
|
317 |
+
"207: loss: 7.081823825836182\n",
|
318 |
+
"208: loss: 1.1608872413635254\n",
|
319 |
+
"209: loss: 0.728882908821106\n",
|
320 |
+
"210: loss: 0.514722466468811\n",
|
321 |
+
"211: loss: 0.6075964570045471\n",
|
322 |
+
"212: loss: 0.7593868970870972\n",
|
323 |
+
"213: loss: 0.6465023159980774\n",
|
324 |
+
"214: loss: 8.1160888671875\n",
|
325 |
+
"215: loss: 0.8256340622901917\n",
|
326 |
+
"216: loss: 0.5982277393341064\n",
|
327 |
+
"217: loss: 7.202335834503174\n",
|
328 |
+
"218: loss: 4.8967790603637695\n",
|
329 |
+
"219: loss: 2.037604331970215\n",
|
330 |
+
"220: loss: 1.7443571090698242\n",
|
331 |
+
"221: loss: 0.8838777542114258\n",
|
332 |
+
"222: loss: 0.7871264219284058\n",
|
333 |
+
"223: loss: 5.985363483428955\n",
|
334 |
+
"224: loss: 3.6808922290802\n",
|
335 |
+
"225: loss: 4.453125476837158\n",
|
336 |
+
"226: loss: 4.137350559234619\n",
|
337 |
+
"227: loss: 1.5606231689453125\n",
|
338 |
+
"228: loss: 5.764791488647461\n",
|
339 |
+
"229: loss: 1.2394036054611206\n",
|
340 |
+
"230: loss: 1.1438194513320923\n",
|
341 |
+
"231: loss: 0.5560073852539062\n",
|
342 |
+
"232: loss: 5.746810436248779\n",
|
343 |
+
"233: loss: 4.34252405166626\n",
|
344 |
+
"234: loss: 6.079676628112793\n",
|
345 |
+
"235: loss: 4.213600158691406\n",
|
346 |
+
"236: loss: 1.1661522388458252\n",
|
347 |
+
"237: loss: 7.770791053771973\n",
|
348 |
+
"238: loss: 3.6331183910369873\n",
|
349 |
+
"239: loss: 6.657710552215576\n",
|
350 |
+
"240: loss: 4.314018249511719\n",
|
351 |
+
"241: loss: 3.964081048965454\n",
|
352 |
+
"242: loss: 3.4643802642822266\n",
|
353 |
+
"243: loss: 3.2389814853668213\n",
|
354 |
+
"244: loss: 5.009263515472412\n",
|
355 |
+
"245: loss: 5.4173903465271\n",
|
356 |
+
"246: loss: 3.464853048324585\n",
|
357 |
+
"247: loss: 2.690930128097534\n",
|
358 |
+
"248: loss: 5.482550621032715\n",
|
359 |
+
"249: loss: 1.500435709953308\n",
|
360 |
+
"250: loss: 1.207865834236145\n",
|
361 |
+
"251: loss: 6.162202835083008\n",
|
362 |
+
"252: loss: 0.5159206986427307\n",
|
363 |
+
"253: loss: 0.352285772562027\n",
|
364 |
+
"254: loss: 0.28347644209861755\n",
|
365 |
+
"255: loss: 0.2998739182949066\n",
|
366 |
+
"256: loss: 7.412589073181152\n",
|
367 |
+
"257: loss: 1.0271281003952026\n",
|
368 |
+
"258: loss: 0.5622831583023071\n",
|
369 |
+
"259: loss: 6.975170135498047\n",
|
370 |
+
"260: loss: 0.050237879157066345\n",
|
371 |
+
"261: loss: 9.500787734985352\n",
|
372 |
+
"262: loss: 1.1100494861602783\n",
|
373 |
+
"263: loss: 10.5401029586792\n",
|
374 |
+
"264: loss: 7.637964725494385\n",
|
375 |
+
"265: loss: 1.5384433269500732\n",
|
376 |
+
"266: loss: 0.6748937368392944\n",
|
377 |
+
"267: loss: 0.38336750864982605\n",
|
378 |
+
"268: loss: 0.1832476705312729\n",
|
379 |
+
"269: loss: 7.080984115600586\n",
|
380 |
+
"270: loss: 6.806582927703857\n",
|
381 |
+
"271: loss: 6.216980457305908\n",
|
382 |
+
"272: loss: 8.122699737548828\n",
|
383 |
+
"273: loss: 2.344430685043335\n",
|
384 |
+
"274: loss: 5.185897350311279\n",
|
385 |
+
"275: loss: 5.136538982391357\n",
|
386 |
+
"276: loss: 4.847122669219971\n",
|
387 |
+
"277: loss: 3.447641372680664\n",
|
388 |
+
"278: loss: 1.9696052074432373\n",
|
389 |
+
"279: loss: 6.129249095916748\n",
|
390 |
+
"280: loss: 1.4744977951049805\n",
|
391 |
+
"281: loss: 4.836997032165527\n",
|
392 |
+
"282: loss: 4.361396789550781\n",
|
393 |
+
"283: loss: 4.975046157836914\n",
|
394 |
+
"284: loss: 5.6431074142456055\n",
|
395 |
+
"285: loss: 8.127538681030273\n",
|
396 |
+
"286: loss: 7.203218460083008\n",
|
397 |
+
"287: loss: 2.408040761947632\n",
|
398 |
+
"288: loss: 1.7607803344726562\n",
|
399 |
+
"289: loss: 1.1752283573150635\n",
|
400 |
+
"290: loss: 5.39897346496582\n",
|
401 |
+
"291: loss: 0.8753417134284973\n",
|
402 |
+
"292: loss: 6.104700088500977\n",
|
403 |
+
"293: loss: 0.8714774250984192\n",
|
404 |
+
"294: loss: 5.633414268493652\n",
|
405 |
+
"295: loss: 1.0734435319900513\n",
|
406 |
+
"296: loss: 0.5978174209594727\n",
|
407 |
+
"297: loss: 0.6240620613098145\n",
|
408 |
+
"298: loss: 0.3799970746040344\n",
|
409 |
+
"299: loss: 5.793654441833496\n",
|
410 |
+
"300: loss: 4.920631408691406\n",
|
411 |
+
"300: valid loss 0.5733768343925476\n",
|
412 |
+
"301: loss: 0.5733768343925476\n",
|
413 |
+
"302: loss: 0.35356906056404114\n",
|
414 |
+
"303: loss: 6.0288190841674805\n",
|
415 |
+
"304: loss: 0.17994554340839386\n",
|
416 |
+
"305: loss: 6.07096004486084\n",
|
417 |
+
"306: loss: 0.798763632774353\n",
|
418 |
+
"307: loss: 0.30721110105514526\n",
|
419 |
+
"308: loss: 0.35866862535476685\n",
|
420 |
+
"309: loss: 6.664376258850098\n",
|
421 |
+
"310: loss: 10.371112823486328\n",
|
422 |
+
"311: loss: 1.5442111492156982\n",
|
423 |
+
"312: loss: 0.5046924948692322\n",
|
424 |
+
"313: loss: 0.02138896845281124\n",
|
425 |
+
"314: loss: 11.088417053222656\n",
|
426 |
+
"315: loss: 0.2801823616027832\n",
|
427 |
+
"316: loss: 1.6325680017471313\n",
|
428 |
+
"317: loss: 1.042490005493164\n",
|
429 |
+
"318: loss: 0.19980621337890625\n",
|
430 |
+
"319: loss: 6.208798408508301\n",
|
431 |
+
"320: loss: 2.2923152446746826\n",
|
432 |
+
"321: loss: 1.5293265581130981\n",
|
433 |
+
"322: loss: 5.384918212890625\n",
|
434 |
+
"323: loss: 0.5806372165679932\n",
|
435 |
+
"324: loss: 0.11083264648914337\n",
|
436 |
+
"325: loss: 6.474861145019531\n",
|
437 |
+
"326: loss: 6.7361063957214355\n",
|
438 |
+
"327: loss: 6.07684850692749\n",
|
439 |
+
"328: loss: 0.1449495404958725\n",
|
440 |
+
"329: loss: 0.24492450058460236\n",
|
441 |
+
"330: loss: 0.0179277490824461\n",
|
442 |
+
"331: loss: 5.866001605987549\n",
|
443 |
+
"332: loss: 0.14012691378593445\n",
|
444 |
+
"333: loss: 0.14467062056064606\n",
|
445 |
+
"334: loss: 0.01395170483738184\n",
|
446 |
+
"335: loss: 0.04150881618261337\n",
|
447 |
+
"336: loss: 0.07648518681526184\n",
|
448 |
+
"337: loss: 9.367613792419434\n",
|
449 |
+
"338: loss: 8.372873306274414\n",
|
450 |
+
"339: loss: 0.6273093223571777\n",
|
451 |
+
"340: loss: 0.11360179632902145\n",
|
452 |
+
"341: loss: 0.02351052314043045\n",
|
453 |
+
"342: loss: 0.06904540210962296\n",
|
454 |
+
"343: loss: 0.02174321562051773\n",
|
455 |
+
"344: loss: 0.11702124029397964\n",
|
456 |
+
"345: loss: 0.061455100774765015\n",
|
457 |
+
"346: loss: 0.03193430230021477\n",
|
458 |
+
"347: loss: 0.33268794417381287\n",
|
459 |
+
"348: loss: 0.053275030106306076\n",
|
460 |
+
"349: loss: 0.009291582740843296\n",
|
461 |
+
"350: loss: 0.18401774764060974\n",
|
462 |
+
"351: loss: 0.30571281909942627\n",
|
463 |
+
"352: loss: 17.913070678710938\n",
|
464 |
+
"353: loss: 0.2126859426498413\n",
|
465 |
+
"354: loss: 0.6229326128959656\n",
|
466 |
+
"355: loss: 11.214807510375977\n",
|
467 |
+
"356: loss: 0.15888328850269318\n",
|
468 |
+
"357: loss: 0.662460446357727\n",
|
469 |
+
"358: loss: 7.345875263214111\n",
|
470 |
+
"359: loss: 7.803595066070557\n",
|
471 |
+
"360: loss: 1.2322083711624146\n",
|
472 |
+
"361: loss: 0.7014895081520081\n",
|
473 |
+
"362: loss: 0.10298460721969604\n",
|
474 |
+
"363: loss: 8.574231147766113\n",
|
475 |
+
"364: loss: 0.03108447603881359\n",
|
476 |
+
"365: loss: 0.6616091728210449\n",
|
477 |
+
"366: loss: 4.938299655914307\n",
|
478 |
+
"367: loss: 5.479018688201904\n",
|
479 |
+
"368: loss: 6.740688800811768\n",
|
480 |
+
"369: loss: 3.110865831375122\n",
|
481 |
+
"370: loss: 4.795236587524414\n",
|
482 |
+
"371: loss: 1.8502461910247803\n",
|
483 |
+
"372: loss: 3.737464427947998\n",
|
484 |
+
"373: loss: 1.9333598613739014\n",
|
485 |
+
"374: loss: 7.145735740661621\n",
|
486 |
+
"375: loss: 1.3372946977615356\n",
|
487 |
+
"376: loss: 5.683573246002197\n",
|
488 |
+
"377: loss: 1.204305648803711\n",
|
489 |
+
"378: loss: 0.9289284348487854\n",
|
490 |
+
"379: loss: 5.174688339233398\n",
|
491 |
+
"380: loss: 1.458616852760315\n",
|
492 |
+
"381: loss: 0.9457168579101562\n",
|
493 |
+
"382: loss: 0.4627819359302521\n",
|
494 |
+
"383: loss: 0.2658665180206299\n",
|
495 |
+
"384: loss: 4.429558753967285\n",
|
496 |
+
"385: loss: 1.2449607849121094\n",
|
497 |
+
"386: loss: 1.3288488388061523\n",
|
498 |
+
"387: loss: 6.628821849822998\n",
|
499 |
+
"388: loss: 0.4825551211833954\n",
|
500 |
+
"389: loss: 0.6510865688323975\n",
|
501 |
+
"390: loss: 0.36395493149757385\n",
|
502 |
+
"391: loss: 0.18036174774169922\n",
|
503 |
+
"392: loss: 0.3237663209438324\n",
|
504 |
+
"393: loss: 6.840792655944824\n",
|
505 |
+
"394: loss: 1.6587960720062256\n",
|
506 |
+
"395: loss: 7.458000659942627\n",
|
507 |
+
"396: loss: 0.8729283809661865\n",
|
508 |
+
"397: loss: 0.6731876134872437\n",
|
509 |
+
"398: loss: 0.1747300773859024\n",
|
510 |
+
"399: loss: 0.5882076621055603\n",
|
511 |
+
"400: loss: 0.6982569098472595\n",
|
512 |
+
"400: valid loss 0.4763210713863373\n",
|
513 |
+
"401: loss: 0.4763210713863373\n",
|
514 |
+
"402: loss: 0.46096739172935486\n",
|
515 |
+
"403: loss: 4.166454792022705\n",
|
516 |
+
"404: loss: 0.44991931319236755\n",
|
517 |
+
"405: loss: 4.830379009246826\n",
|
518 |
+
"406: loss: 0.5408239364624023\n",
|
519 |
+
"407: loss: 0.2607786953449249\n",
|
520 |
+
"408: loss: 0.13067474961280823\n",
|
521 |
+
"409: loss: 4.062631130218506\n",
|
522 |
+
"410: loss: 5.5028300285339355\n",
|
523 |
+
"411: loss: 1.2942296266555786\n",
|
524 |
+
"412: loss: 1.4390389919281006\n",
|
525 |
+
"413: loss: 5.374651908874512\n",
|
526 |
+
"414: loss: 1.2929461002349854\n",
|
527 |
+
"415: loss: 0.643798291683197\n",
|
528 |
+
"416: loss: 0.6353816986083984\n",
|
529 |
+
"417: loss: 5.8032636642456055\n",
|
530 |
+
"418: loss: 3.3737053871154785\n",
|
531 |
+
"419: loss: 1.8712362051010132\n",
|
532 |
+
"420: loss: 1.0622261762619019\n",
|
533 |
+
"421: loss: 0.8681365847587585\n",
|
534 |
+
"422: loss: 0.6761938333511353\n",
|
535 |
+
"423: loss: 4.074782371520996\n",
|
536 |
+
"424: loss: 0.4106965661048889\n"
|
537 |
+
]
|
538 |
+
},
|
539 |
+
{
|
540 |
+
"ename": "KeyboardInterrupt",
|
541 |
+
"evalue": "",
|
542 |
+
"output_type": "error",
|
543 |
+
"traceback": [
|
544 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
545 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
546 |
+
"Cell \u001b[1;32mIn[4], line 49\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m coarse_transformer, trainer, wav2vec, soundstream\n\u001b[0;32m 47\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 49\u001b[0m \u001b[43mtrain_coarse_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
547 |
+
"Cell \u001b[1;32mIn[4], line 43\u001b[0m, in \u001b[0;36mtrain_coarse_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 23\u001b[0m coarse_transformer \u001b[38;5;241m=\u001b[39m CoarseTransformer(\n\u001b[0;32m 24\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 25\u001b[0m codebook_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 29\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 30\u001b[0m )\n\u001b[0;32m 32\u001b[0m trainer \u001b[38;5;241m=\u001b[39m CoarseTransformerTrainer(\n\u001b[0;32m 33\u001b[0m transformer\u001b[38;5;241m=\u001b[39mcoarse_transformer,\n\u001b[0;32m 34\u001b[0m codec\u001b[38;5;241m=\u001b[39msoundstream,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 40\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 41\u001b[0m )\n\u001b[1;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 44\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(coarse_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcoarse_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 45\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave coarse_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
548 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1302\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 1299\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 1301\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1302\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1303\u001b[0m log_fn(logs)\n\u001b[0;32m 1305\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
549 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1244\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1238\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[0;32m 1239\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_wrapper(\n\u001b[0;32m 1240\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_kwargs,\n\u001b[0;32m 1241\u001b[0m return_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 1242\u001b[0m )\n\u001b[1;32m-> 1244\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad_accum_every\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1246\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n\u001b[0;32m 1248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_grad_norm):\n",
|
550 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\accelerate\\accelerator.py:2151\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[1;34m(self, loss, **kwargs)\u001b[0m\n\u001b[0;32m 2149\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[0;32m 2150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2151\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
551 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m 524\u001b[0m )\n\u001b[1;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
552 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
553 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[0;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
|
554 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
555 |
+
]
|
556 |
+
}
|
557 |
+
],
|
558 |
+
"source": [
|
559 |
+
"def train_coarse_transformer():\n",
|
560 |
+
" wav2vec = HubertWithKmeans(\n",
|
561 |
+
" checkpoint_path=checkpoint_path,\n",
|
562 |
+
" kmeans_path=kmeans_path\n",
|
563 |
+
" )\n",
|
564 |
+
" soundstream = MusicLMSoundStream(\n",
|
565 |
+
"        codebook_size=1024,  # same codebook_size as passed to CoarseTransformer below\n",
|
566 |
+
" strides=(3, 4, 5, 8),\n",
|
567 |
+
" target_sample_hz=24000,\n",
|
568 |
+
" rq_num_quantizers=8\n",
|
569 |
+
" )\n",
|
570 |
+
"\n",
|
571 |
+
" if torch.cuda.is_available():\n",
|
572 |
+
" coarse_transformer = CoarseTransformer(\n",
|
573 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
574 |
+
" codebook_size=1024,\n",
|
575 |
+
" num_coarse_quantizers=4,\n",
|
576 |
+
" dim=1024,\n",
|
577 |
+
" depth=6,\n",
|
578 |
+
" audio_text_condition=True\n",
|
579 |
+
" ).cuda()\n",
|
580 |
+
" else:\n",
|
581 |
+
" coarse_transformer = CoarseTransformer(\n",
|
582 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
583 |
+
" codebook_size=1024,\n",
|
584 |
+
" num_coarse_quantizers=4,\n",
|
585 |
+
" dim=1024,\n",
|
586 |
+
" depth=6,\n",
|
587 |
+
" audio_text_condition=True\n",
|
588 |
+
" )\n",
|
589 |
+
"\n",
|
590 |
+
" trainer = CoarseTransformerTrainer(\n",
|
591 |
+
" transformer=coarse_transformer,\n",
|
592 |
+
" codec=soundstream,\n",
|
593 |
+
" wav2vec=wav2vec,\n",
|
594 |
+
" audio_conditioner=quantizer,\n",
|
595 |
+
" folder=audio_output_dir,\n",
|
596 |
+
" batch_size=batch_size,\n",
|
597 |
+
" data_max_length=data_max_length,\n",
|
598 |
+
" num_train_steps=num_train_steps\n",
|
599 |
+
" )\n",
|
600 |
+
"\n",
|
601 |
+
" trainer.train()\n",
|
602 |
+
" torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth')\n",
|
603 |
+
" print(\"save coarse_transformer.pth\")\n",
|
604 |
+
" del coarse_transformer, trainer, wav2vec, soundstream\n",
|
605 |
+
" gc.collect()\n",
|
606 |
+
"\n",
|
607 |
+
"train_coarse_transformer()"
|
608 |
+
]
|
609 |
+
}
|
610 |
+
],
|
611 |
+
"metadata": {
|
612 |
+
"kernelspec": {
|
613 |
+
"display_name": "myenv",
|
614 |
+
"language": "python",
|
615 |
+
"name": "python3"
|
616 |
+
},
|
617 |
+
"language_info": {
|
618 |
+
"codemirror_mode": {
|
619 |
+
"name": "ipython",
|
620 |
+
"version": 3
|
621 |
+
},
|
622 |
+
"file_extension": ".py",
|
623 |
+
"mimetype": "text/x-python",
|
624 |
+
"name": "python",
|
625 |
+
"nbconvert_exporter": "python",
|
626 |
+
"pygments_lexer": "ipython3",
|
627 |
+
"version": "3.11.2"
|
628 |
+
}
|
629 |
+
},
|
630 |
+
"nbformat": 4,
|
631 |
+
"nbformat_minor": 2
|
632 |
+
}
|
fine_transformer.ipynb
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Fine Transformer"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"### Libraries:"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 1,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"import torch\n",
|
24 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
25 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
26 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
27 |
+
"from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n",
|
28 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n",
|
29 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
30 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
|
31 |
+
"import gc\n",
|
32 |
+
"from nltk.tokenize import word_tokenize\n",
|
33 |
+
"import nltk\n",
|
34 |
+
"import librosa\n",
|
35 |
+
"import numpy as np\n",
|
36 |
+
"import pickle"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 2,
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"name": "stderr",
|
46 |
+
"output_type": "stream",
|
47 |
+
"text": [
|
48 |
+
"[nltk_data] Downloading package punkt to\n",
|
49 |
+
"[nltk_data] C:\\Users\\hp\\AppData\\Roaming\\nltk_data...\n",
|
50 |
+
"[nltk_data] Package punkt is already up-to-date!\n"
|
51 |
+
]
|
52 |
+
}
|
53 |
+
],
|
54 |
+
"source": [
|
55 |
+
"nltk.download('punkt')\n",
|
56 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
57 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
58 |
+
"\n",
|
59 |
+
"audio_output_dir = './audio'\n",
|
60 |
+
"batch_size = 1\n",
|
61 |
+
"data_max_length = 320 * 32\n",
|
62 |
+
"num_train_steps = 1000"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": 3,
|
68 |
+
"metadata": {},
|
69 |
+
"outputs": [
|
70 |
+
{
|
71 |
+
"name": "stdout",
|
72 |
+
"output_type": "stream",
|
73 |
+
"text": [
|
74 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
75 |
+
"spectrogram yielded shape of (65, 841), but had to be cropped to (64, 832) to be patchified for transformer\n",
|
76 |
+
"0: loss: 103.04938507080078\n",
|
77 |
+
"0: valid loss 11.681041717529297\n",
|
78 |
+
"0: saving model to results\n",
|
79 |
+
"training complete\n",
|
80 |
+
"save fine_transformer.pth\n"
|
81 |
+
]
|
82 |
+
}
|
83 |
+
],
|
84 |
+
"source": [
|
85 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
86 |
+
" dim = 512,\n",
|
87 |
+
" depth = 6,\n",
|
88 |
+
" heads = 8,\n",
|
89 |
+
" dim_head = 64,\n",
|
90 |
+
" spec_n_fft = 128,\n",
|
91 |
+
" spec_win_length = 24,\n",
|
92 |
+
" spec_aug_stretch_factor = 0.8\n",
|
93 |
+
")\n",
|
94 |
+
"\n",
|
95 |
+
"text_transformer = TextTransformer(\n",
|
96 |
+
" dim = 512,\n",
|
97 |
+
" depth = 6,\n",
|
98 |
+
" heads = 8,\n",
|
99 |
+
" dim_head = 64\n",
|
100 |
+
")\n",
|
101 |
+
"\n",
|
102 |
+
"mulan = MuLaN(\n",
|
103 |
+
" audio_transformer = audio_transformer,\n",
|
104 |
+
" text_transformer = text_transformer\n",
|
105 |
+
")\n",
|
106 |
+
"\n",
|
107 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
108 |
+
" mulan = mulan, \n",
|
109 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
110 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
111 |
+
")\n",
|
112 |
+
"\n",
|
113 |
+
"\n",
|
114 |
+
"def train_fine_transformer():\n",
|
115 |
+
" soundstream = MusicLMSoundStream(\n",
|
116 |
+
" codebook_size=1024, \n",
|
117 |
+
" strides=(3, 4, 5, 8),\n",
|
118 |
+
" target_sample_hz=24000,\n",
|
119 |
+
" rq_num_quantizers=8\n",
|
120 |
+
" )\n",
|
121 |
+
"\n",
|
122 |
+
" if torch.cuda.is_available():\n",
|
123 |
+
" fine_transformer = FineTransformer(\n",
|
124 |
+
" num_coarse_quantizers = 4,\n",
|
125 |
+
" num_fine_quantizers = 4,\n",
|
126 |
+
" codebook_size = 1024,\n",
|
127 |
+
" dim = 1024,\n",
|
128 |
+
" depth = 6,\n",
|
129 |
+
" audio_text_condition = True\n",
|
130 |
+
" ).cuda()\n",
|
131 |
+
" else:\n",
|
132 |
+
" fine_transformer = FineTransformer(\n",
|
133 |
+
" num_coarse_quantizers = 4,\n",
|
134 |
+
" num_fine_quantizers = 4,\n",
|
135 |
+
" codebook_size = 1024,\n",
|
136 |
+
" dim = 1024,\n",
|
137 |
+
" depth = 6,\n",
|
138 |
+
" audio_text_condition = True\n",
|
139 |
+
" )\n",
|
140 |
+
"\n",
|
141 |
+
" trainer = FineTransformerTrainer(\n",
|
142 |
+
" transformer=fine_transformer,\n",
|
143 |
+
" codec=soundstream,\n",
|
144 |
+
" folder=audio_output_dir,\n",
|
145 |
+
" batch_size=batch_size,\n",
|
146 |
+
" data_max_length=data_max_length,\n",
|
147 |
+
" num_train_steps=num_train_steps,\n",
|
148 |
+
" audio_conditioner = quantizer\n",
|
149 |
+
" )\n",
|
150 |
+
"\n",
|
151 |
+
" trainer.train()\n",
|
152 |
+
" torch.save(fine_transformer.state_dict(), 'fine_transformer.pth')\n",
|
153 |
+
" print(\"save fine_transformer.pth\")\n",
|
154 |
+
" del fine_transformer, trainer, soundstream\n",
|
155 |
+
" gc.collect()\n",
|
156 |
+
"\n",
|
157 |
+
"\n",
|
158 |
+
"train_fine_transformer()"
|
159 |
+
]
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"metadata": {
|
163 |
+
"kernelspec": {
|
164 |
+
"display_name": "myenv",
|
165 |
+
"language": "python",
|
166 |
+
"name": "python3"
|
167 |
+
},
|
168 |
+
"language_info": {
|
169 |
+
"codemirror_mode": {
|
170 |
+
"name": "ipython",
|
171 |
+
"version": 3
|
172 |
+
},
|
173 |
+
"file_extension": ".py",
|
174 |
+
"mimetype": "text/x-python",
|
175 |
+
"name": "python",
|
176 |
+
"nbconvert_exporter": "python",
|
177 |
+
"pygments_lexer": "ipython3",
|
178 |
+
"version": "3.11.2"
|
179 |
+
}
|
180 |
+
},
|
181 |
+
"nbformat": 4,
|
182 |
+
"nbformat_minor": 2
|
183 |
+
}
|
musiclm.ipynb
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# AudioLM"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"### Libraries:"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 2,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [
|
22 |
+
{
|
23 |
+
"name": "stderr",
|
24 |
+
"output_type": "stream",
|
25 |
+
"text": [
|
26 |
+
"2024-07-26 16:06:09 | WARNING | xformers | WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:\n",
|
27 |
+
" PyTorch 2.1.0+cu121 with CUDA 1201 (you have 2.1.0+cpu)\n",
|
28 |
+
" Python 3.11.6 (you have 3.11.2)\n",
|
29 |
+
" Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)\n",
|
30 |
+
" Memory-efficient attention, SwiGLU, sparse and more won't be available.\n",
|
31 |
+
" Set XFORMERS_MORE_DETAILS=1 for more details\n",
|
32 |
+
"2024-07-26 16:06:09 | WARNING | xformers | Triton is not available, some optimizations will not be enabled.\n",
|
33 |
+
"This is just a warning: triton is not available\n"
|
34 |
+
]
|
35 |
+
}
|
36 |
+
],
|
37 |
+
"source": [
|
38 |
+
"import torch\n",
|
39 |
+
"from audiolm_pytorch import HubertWithKmeans\n",
|
40 |
+
"from audiolm_pytorch import SemanticTransformer\n",
|
41 |
+
"from audiolm_pytorch import CoarseTransformer\n",
|
42 |
+
"from audiolm_pytorch import FineTransformer\n",
|
43 |
+
"from audiolm_pytorch import AudioLMSoundStream, AudioLM\n",
|
44 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
45 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 3,
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
55 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'"
|
56 |
+
]
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"execution_count": 5,
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"wav2vec = HubertWithKmeans(checkpoint_path=checkpoint_path, kmeans_path=kmeans_path)\n",
|
65 |
+
"\n",
|
66 |
+
"soundstream = AudioLMSoundStream(\n",
|
67 |
+
"    codebook_size=1024,  # same codebook_size as passed to the coarse and fine transformers below\n",
|
68 |
+
" strides=(2, 4, 5, 8),\n",
|
69 |
+
" target_sample_hz=16000,\n",
|
70 |
+
" rq_num_quantizers=8\n",
|
71 |
+
")\n",
|
72 |
+
"\n",
|
73 |
+
"\n",
|
74 |
+
"if torch.cuda.is_available():\n",
|
75 |
+
" semantic_transformer = SemanticTransformer(\n",
|
76 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
77 |
+
" dim=1024,\n",
|
78 |
+
" depth=6,\n",
|
79 |
+
" audio_text_condition=True\n",
|
80 |
+
" ).cuda()\n",
|
81 |
+
"\n",
|
82 |
+
" coarse_transformer = CoarseTransformer(\n",
|
83 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
84 |
+
" codebook_size=1024,\n",
|
85 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
86 |
+
" dim=1024,\n",
|
87 |
+
" depth=6,\n",
|
88 |
+
" audio_text_condition=True\n",
|
89 |
+
" ).cuda()\n",
|
90 |
+
"\n",
|
91 |
+
" fine_transformer = FineTransformer(\n",
|
92 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
93 |
+
" num_fine_quantizers=4,\n",
|
94 |
+
" codebook_size=1024,\n",
|
95 |
+
" dim=1024,\n",
|
96 |
+
" depth=6,\n",
|
97 |
+
" audio_text_condition=True\n",
|
98 |
+
" ).cuda()\n",
|
99 |
+
"else:\n",
|
100 |
+
" semantic_transformer = SemanticTransformer(\n",
|
101 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
102 |
+
" dim=1024,\n",
|
103 |
+
" depth=6,\n",
|
104 |
+
" audio_text_condition=True\n",
|
105 |
+
" )\n",
|
106 |
+
"\n",
|
107 |
+
" coarse_transformer = CoarseTransformer(\n",
|
108 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
109 |
+
" codebook_size=1024,\n",
|
110 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
111 |
+
" dim=1024,\n",
|
112 |
+
" depth=6,\n",
|
113 |
+
" audio_text_condition=True\n",
|
114 |
+
" )\n",
|
115 |
+
"\n",
|
116 |
+
" fine_transformer = FineTransformer(\n",
|
117 |
+
" num_coarse_quantizers=4, # Consistent with training\n",
|
118 |
+
" num_fine_quantizers=4,\n",
|
119 |
+
" codebook_size=1024,\n",
|
120 |
+
" dim=1024,\n",
|
121 |
+
" depth=6,\n",
|
122 |
+
" audio_text_condition=True\n",
|
123 |
+
" )\n",
|
124 |
+
"\n",
|
125 |
+
"semantic_transformer.load_state_dict(torch.load('semantic_transformer.pth'))\n",
|
126 |
+
"coarse_transformer.load_state_dict(torch.load('coarse_transformer.pth'))\n",
|
127 |
+
"fine_transformer.load_state_dict(torch.load('fine_transformer.pth'))\n",
|
128 |
+
"\n",
|
129 |
+
"audiolm = AudioLM(\n",
|
130 |
+
" wav2vec=wav2vec,\n",
|
131 |
+
" codec=soundstream,\n",
|
132 |
+
" semantic_transformer=semantic_transformer,\n",
|
133 |
+
" coarse_transformer=coarse_transformer,\n",
|
134 |
+
" fine_transformer=fine_transformer\n",
|
135 |
+
")\n"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "markdown",
|
140 |
+
"metadata": {},
|
141 |
+
"source": [
|
142 |
+
"# MuLaN"
|
143 |
+
]
|
144 |
+
},
|
145 |
+
{
|
146 |
+
"cell_type": "code",
|
147 |
+
"execution_count": 6,
|
148 |
+
"metadata": {},
|
149 |
+
"outputs": [],
|
150 |
+
"source": [
|
151 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
152 |
+
" dim = 512,\n",
|
153 |
+
" depth = 6,\n",
|
154 |
+
" heads = 8,\n",
|
155 |
+
" dim_head = 64,\n",
|
156 |
+
" spec_n_fft = 128,\n",
|
157 |
+
" spec_win_length = 24,\n",
|
158 |
+
" spec_aug_stretch_factor = 0.8\n",
|
159 |
+
")\n",
|
160 |
+
"\n",
|
161 |
+
"text_transformer = TextTransformer(\n",
|
162 |
+
" dim = 512,\n",
|
163 |
+
" depth = 6,\n",
|
164 |
+
" heads = 8,\n",
|
165 |
+
" dim_head = 64\n",
|
166 |
+
")\n",
|
167 |
+
"\n",
|
168 |
+
"mulan = MuLaN(\n",
|
169 |
+
" audio_transformer = audio_transformer,\n",
|
170 |
+
" text_transformer = text_transformer\n",
|
171 |
+
")\n",
|
172 |
+
"\n",
|
173 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
174 |
+
" mulan = mulan, \n",
|
175 |
+
" conditioning_dims = (1024, 1024, 1024), \n",
|
176 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
177 |
+
")\n"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "markdown",
|
182 |
+
"metadata": {},
|
183 |
+
"source": [
|
184 |
+
"# MusicLM"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": 7,
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [],
|
192 |
+
"source": [
|
193 |
+
"from musiclm_pytorch import MusicLM\n",
|
194 |
+
"\n",
|
195 |
+
"if torch.cuda.is_available():\n",
|
196 |
+
" musiclm = MusicLM(\n",
|
197 |
+
" audio_lm = audiolm,\n",
|
198 |
+
" mulan_embed_quantizer = quantizer\n",
|
199 |
+
" ).cuda()\n",
|
200 |
+
"else:\n",
|
201 |
+
" musiclm = MusicLM(\n",
|
202 |
+
" audio_lm = audiolm,\n",
|
203 |
+
" mulan_embed_quantizer = quantizer\n",
|
204 |
+
" )"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "markdown",
|
209 |
+
"metadata": {},
|
210 |
+
"source": [
|
211 |
+
"# Inference:"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 10,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [
|
219 |
+
{
|
220 |
+
"name": "stdout",
|
221 |
+
"output_type": "stream",
|
222 |
+
"text": [
|
223 |
+
" 31 / 403\r"
|
224 |
+
]
|
225 |
+
},
|
226 |
+
{
|
227 |
+
"ename": "KeyboardInterrupt",
|
228 |
+
"evalue": "",
|
229 |
+
"output_type": "error",
|
230 |
+
"traceback": [
|
231 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
232 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
233 |
+
"Cell \u001b[1;32mIn[10], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[43mMusiclm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\n\u001b[0;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcrazy EDM, heavy bang\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 3\u001b[0m \u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[0;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 5\u001b[0m display_audio(res, \u001b[38;5;241m32000\u001b[39m)\n",
|
234 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\genmodel.py:161\u001b[0m, in \u001b[0;36mBaseGenModel.generate\u001b[1;34m(self, descriptions, progress, return_tokens)\u001b[0m\n\u001b[0;32m 159\u001b[0m attributes, prompt_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_prepare_tokens_and_attributes(descriptions, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m 160\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m prompt_tokens \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m--> 161\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_generate_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprogress\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m return_tokens:\n\u001b[0;32m 163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerate_audio(tokens), tokens\n",
|
235 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\musicgen.py:256\u001b[0m, in \u001b[0;36mMusicGen._generate_tokens\u001b[1;34m(self, attributes, prompt_tokens, progress)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mduration \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_duration:\n\u001b[0;32m 254\u001b[0m \u001b[38;5;66;03m# generate by sampling from LM, simple case.\u001b[39;00m\n\u001b[0;32m 255\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mautocast:\n\u001b[1;32m--> 256\u001b[0m gen_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlm\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 257\u001b[0m \u001b[43m \u001b[49m\u001b[43mprompt_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattributes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_gen_len\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_gen_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgeneration_params\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 260\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 261\u001b[0m \u001b[38;5;66;03m# now this gets a bit messier, we need to handle prompts,\u001b[39;00m\n\u001b[0;32m 262\u001b[0m \u001b[38;5;66;03m# melody conditioning etc.\u001b[39;00m\n\u001b[0;32m 263\u001b[0m ref_wavs \u001b[38;5;241m=\u001b[39m 
[attr\u001b[38;5;241m.\u001b[39mwav[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mself_wav\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m attr \u001b[38;5;129;01min\u001b[39;00m attributes]\n",
|
236 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\utils\\_contextlib.py:115\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[1;34m(*args, **kwargs)\u001b[0m\n\u001b[0;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[0;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[1;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
237 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:510\u001b[0m, in \u001b[0;36mLMModel.generate\u001b[1;34m(self, prompt, conditions, num_samples, max_gen_len, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg, remove_prompts, check, callback, **kwargs)\u001b[0m\n\u001b[0;32m 508\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (curr_sequence \u001b[38;5;241m==\u001b[39m unknown_token)\u001b[38;5;241m.\u001b[39many()\n\u001b[0;32m 509\u001b[0m \u001b[38;5;66;03m# sample next token from the model, next token shape is [B, K, 1]\u001b[39;00m\n\u001b[1;32m--> 510\u001b[0m next_token \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sample_next_token\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 511\u001b[0m \u001b[43m \u001b[49m\u001b[43mcurr_sequence\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcfg_conditions\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43munconditional_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43muse_sampling\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtemp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtop_p\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 512\u001b[0m \u001b[43m \u001b[49m\u001b[43mcfg_coef\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcfg_coef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtwo_step_cfg\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 513\u001b[0m \u001b[38;5;66;03m# ensure the tokens that should be masked are properly set to special_token_id\u001b[39;00m\n\u001b[0;32m 514\u001b[0m \u001b[38;5;66;03m# as the model never output special_token_id\u001b[39;00m\n\u001b[0;32m 515\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m 
mask[\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m, offset:offset\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m]\u001b[38;5;241m.\u001b[39mexpand(B, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
|
238 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:369\u001b[0m, in \u001b[0;36mLMModel._sample_next_token\u001b[1;34m(self, sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, cfg_coef, two_step_cfg)\u001b[0m\n\u001b[0;32m 366\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 367\u001b[0m \u001b[38;5;66;03m# Preparing for CFG, predicting both conditional and unconditional logits.\u001b[39;00m\n\u001b[0;32m 368\u001b[0m sequence \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat([sequence, sequence], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m--> 369\u001b[0m all_logits \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 370\u001b[0m \u001b[43m \u001b[49m\u001b[43msequence\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 371\u001b[0m \u001b[43m \u001b[49m\u001b[43mconditions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcondition_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcondition_tensors\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 372\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m condition_tensors:\n\u001b[0;32m 373\u001b[0m cond_logits, uncond_logits \u001b[38;5;241m=\u001b[39m all_logits\u001b[38;5;241m.\u001b[39msplit(B, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m) \u001b[38;5;66;03m# [B, K, T, card]\u001b[39;00m\n",
|
239 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
240 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
241 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\models\\lm.py:257\u001b[0m, in \u001b[0;36mLMModel.forward\u001b[1;34m(self, sequence, conditions, condition_tensors, stage)\u001b[0m\n\u001b[0;32m 253\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m conditions, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mShouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt pass both conditions and condition_tensors.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m 255\u001b[0m input_, cross_attention_input \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfuser(input_, condition_tensors)\n\u001b[1;32m--> 257\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcross_attention_input\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 258\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattn_mask_per_stage\u001b[49m\u001b[43m[\u001b[49m\u001b[43mstage\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mstage\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 259\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm:\n\u001b[0;32m 260\u001b[0m out \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mout_norm(out)\n",
|
242 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
243 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
244 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:708\u001b[0m, in \u001b[0;36mStreamingTransformer.forward\u001b[1;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[0;32m 705\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpositional_scale \u001b[38;5;241m*\u001b[39m pos_emb\n\u001b[0;32m 707\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m--> 708\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_apply_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlayer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 710\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_is_streaming:\n\u001b[0;32m 711\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_streaming_state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moffsets\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m offsets \u001b[38;5;241m+\u001b[39m T\n",
|
245 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:665\u001b[0m, in \u001b[0;36mStreamingTransformer._apply_layer\u001b[1;34m(self, layer, *args, **kwargs)\u001b[0m\n\u001b[0;32m 663\u001b[0m method \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheckpointing\n\u001b[0;32m 664\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[1;32m--> 665\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlayer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 666\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m method \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtorch\u001b[39m\u001b[38;5;124m'\u001b[39m:\n\u001b[0;32m 667\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m torch_checkpoint(layer, \u001b[38;5;241m*\u001b[39margs, use_reentrant\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
|
246 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
247 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
248 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:563\u001b[0m, in \u001b[0;36mStreamingTransformerLayer.forward\u001b[1;34m(self, src, src_mask, src_key_padding_mask, cross_attention_src)\u001b[0m\n\u001b[0;32m 559\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_1(\n\u001b[0;32m 560\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sa_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm1(x), src_mask, src_key_padding_mask))\n\u001b[0;32m 561\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cross_attention_src \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 562\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_cross(\n\u001b[1;32m--> 563\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_cross_attention_block\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 564\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_cross\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[0;32m 565\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayer_scale_2(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ff_block(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnorm2(x)))\n\u001b[0;32m 566\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
249 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:546\u001b[0m, in \u001b[0;36mStreamingTransformerLayer._cross_attention_block\u001b[1;34m(self, src, cross_attention_src)\u001b[0m\n\u001b[0;32m 544\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcross_attention \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m 545\u001b[0m \u001b[38;5;66;03m# queries are from src, keys and values from cross_attention_src.\u001b[39;00m\n\u001b[1;32m--> 546\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcross_attention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 547\u001b[0m \u001b[43m \u001b[49m\u001b[43msrc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcross_attention_src\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mneed_weights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n\u001b[0;32m 548\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdropout_cross(x)\n",
|
250 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
251 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
252 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiocraft\\modules\\transformer.py:356\u001b[0m, in \u001b[0;36mStreamingMultiheadAttention.forward\u001b[1;34m(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)\u001b[0m\n\u001b[0;32m 354\u001b[0m q \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(query, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[:dim], bias_q)\n\u001b[0;32m 355\u001b[0m \u001b[38;5;66;03m# todo: when streaming, we could actually save k, v and check the shape actually match.\u001b[39;00m\n\u001b[1;32m--> 356\u001b[0m k \u001b[38;5;241m=\u001b[39m \u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfunctional\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43min_proj_weight\u001b[49m\u001b[43m[\u001b[49m\u001b[43mdim\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbias_k\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 357\u001b[0m v \u001b[38;5;241m=\u001b[39m nn\u001b[38;5;241m.\u001b[39mfunctional\u001b[38;5;241m.\u001b[39mlinear(value, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39min_proj_weight[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m dim:], bias_v)\n\u001b[0;32m 358\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mqk_layer_norm \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n",
|
253 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
254 |
+
]
|
255 |
+
}
|
256 |
+
],
|
257 |
+
"source": [
|
258 |
+
"music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4)"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"cell_type": "code",
|
263 |
+
"execution_count": null,
|
264 |
+
"metadata": {},
|
265 |
+
"outputs": [],
|
266 |
+
"source": [
|
267 |
+
"torch.save(music, 'generated_music.pt')"
|
268 |
+
]
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"execution_count": null,
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"import torchaudio\n",
|
277 |
+
"output_path = \"out.wav\"\n",
|
278 |
+
"sample_rate = 44100\n",
|
279 |
+
"torchaudio.save(output_path, music.cpu() , sample_rate)"
|
280 |
+
]
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"kernelspec": {
|
285 |
+
"display_name": "myenv",
|
286 |
+
"language": "python",
|
287 |
+
"name": "python3"
|
288 |
+
},
|
289 |
+
"language_info": {
|
290 |
+
"codemirror_mode": {
|
291 |
+
"name": "ipython",
|
292 |
+
"version": 3
|
293 |
+
},
|
294 |
+
"file_extension": ".py",
|
295 |
+
"mimetype": "text/x-python",
|
296 |
+
"name": "python",
|
297 |
+
"nbconvert_exporter": "python",
|
298 |
+
"pygments_lexer": "ipython3",
|
299 |
+
"version": "3.11.2"
|
300 |
+
}
|
301 |
+
},
|
302 |
+
"nbformat": 4,
|
303 |
+
"nbformat_minor": 2
|
304 |
+
}
|
semantic_transformer.ipynb
ADDED
@@ -0,0 +1,851 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Semantic Transformer"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"### Libraries"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 1,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"import torch\n",
|
24 |
+
"import multiprocessing\n",
|
25 |
+
"from audiolm_pytorch import HubertWithKmeans, MusicLMSoundStream\n",
|
26 |
+
"from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n",
|
27 |
+
"from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n",
|
28 |
+
"from audiolm_pytorch import FineTransformer, FineTransformerTrainer\n",
|
29 |
+
"from musiclm_pytorch import MuLaNEmbedQuantizer\n",
|
30 |
+
"from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer\n",
|
31 |
+
"import gc "
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 2,
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n",
|
41 |
+
"kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n",
|
42 |
+
"audio_output_dir = './audio'\n",
|
43 |
+
"batch_size = 1\n",
|
44 |
+
"data_max_length = 320 * 32\n",
|
45 |
+
"num_train_steps = 1000"
|
46 |
+
]
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 3,
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [
|
53 |
+
{
|
54 |
+
"name": "stdout",
|
55 |
+
"output_type": "stream",
|
56 |
+
"text": [
|
57 |
+
"spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n",
|
58 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
59 |
+
"ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n",
|
60 |
+
"training with dataset of 4806 samples and validating with randomly splitted 253 samples\n",
|
61 |
+
"0: loss: 6.5572309494018555\n",
|
62 |
+
"0: valid loss 6.723005294799805\n",
|
63 |
+
"0: saving model to results\n",
|
64 |
+
"1: loss: 6.5375285148620605\n",
|
65 |
+
"2: loss: 5.515031337738037\n",
|
66 |
+
"3: loss: 0.6989991664886475\n",
|
67 |
+
"4: loss: 0.016623886302113533\n",
|
68 |
+
"5: loss: 6.3969268798828125\n",
|
69 |
+
"6: loss: 0.8643577098846436\n",
|
70 |
+
"7: loss: 0.008508207276463509\n",
|
71 |
+
"8: loss: 0.00020680516900029033\n",
|
72 |
+
"9: loss: 8.900370597839355\n",
|
73 |
+
"10: loss: 0.00010900969209615141\n",
|
74 |
+
"11: loss: 0.0001591881300555542\n",
|
75 |
+
"12: loss: 8.055902481079102\n",
|
76 |
+
"13: loss: 0.0009496303973719478\n",
|
77 |
+
"14: loss: 0.0027423782739788294\n",
|
78 |
+
"15: loss: 0.0009589337860234082\n",
|
79 |
+
"16: loss: 7.296541690826416\n",
|
80 |
+
"17: loss: 0.0005210856324993074\n",
|
81 |
+
"18: loss: 0.0008424322586506605\n",
|
82 |
+
"19: loss: 5.571179389953613\n",
|
83 |
+
"20: loss: 0.003094581188634038\n",
|
84 |
+
"21: loss: 0.0019461463671177626\n",
|
85 |
+
"22: loss: 5.488490104675293\n",
|
86 |
+
"23: loss: 4.800296783447266\n",
|
87 |
+
"24: loss: 4.962136268615723\n",
|
88 |
+
"25: loss: 5.943732738494873\n",
|
89 |
+
"26: loss: 0.006312617566436529\n",
|
90 |
+
"27: loss: 4.396454334259033\n",
|
91 |
+
"28: loss: 0.012498963624238968\n",
|
92 |
+
"29: loss: 0.0049488842487335205\n",
|
93 |
+
"30: loss: 0.0011625693878158927\n",
|
94 |
+
"31: loss: 3.445856809616089\n",
|
95 |
+
"32: loss: 0.000534387887455523\n",
|
96 |
+
"33: loss: 0.000711498549208045\n",
|
97 |
+
"34: loss: 0.0009514373959973454\n",
|
98 |
+
"35: loss: 0.001239188713952899\n",
|
99 |
+
"36: loss: 8.732012748718262\n",
|
100 |
+
"37: loss: 0.0009216524777002633\n",
|
101 |
+
"38: loss: 0.0006809335318394005\n",
|
102 |
+
"39: loss: 0.000797786982730031\n",
|
103 |
+
"40: loss: 4.916833400726318\n",
|
104 |
+
"41: loss: 0.0010107718408107758\n",
|
105 |
+
"42: loss: 0.0008451942121610045\n",
|
106 |
+
"43: loss: 3.160980701446533\n",
|
107 |
+
"44: loss: 0.0008387335110455751\n",
|
108 |
+
"45: loss: 0.0010360947344452143\n",
|
109 |
+
"46: loss: 0.001215349417179823\n",
|
110 |
+
"47: loss: 5.990973949432373\n",
|
111 |
+
"48: loss: 0.0017369053093716502\n",
|
112 |
+
"49: loss: 6.410669803619385\n",
|
113 |
+
"50: loss: 0.003450337564572692\n",
|
114 |
+
"51: loss: 0.003860922297462821\n",
|
115 |
+
"52: loss: 0.002359303878620267\n",
|
116 |
+
"53: loss: 0.001058467198163271\n",
|
117 |
+
"54: loss: 0.00047752217506058514\n",
|
118 |
+
"55: loss: 0.00025489379186183214\n",
|
119 |
+
"56: loss: 0.00016276698443107307\n",
|
120 |
+
"57: loss: 7.828070163726807\n",
|
121 |
+
"58: loss: 0.00011652028479147702\n",
|
122 |
+
"59: loss: 4.505963325500488\n",
|
123 |
+
"60: loss: 0.00013153781765140593\n",
|
124 |
+
"61: loss: 0.00015024915046524256\n",
|
125 |
+
"62: loss: 0.00017777853645384312\n",
|
126 |
+
"63: loss: 8.09732437133789\n",
|
127 |
+
"64: loss: 0.00041875039460137486\n",
|
128 |
+
"65: loss: 0.0009824583539739251\n",
|
129 |
+
"66: loss: 0.001990197692066431\n",
|
130 |
+
"67: loss: 5.392111778259277\n",
|
131 |
+
"68: loss: 0.0017270153621211648\n",
|
132 |
+
"69: loss: 0.0010434042196720839\n",
|
133 |
+
"70: loss: 0.0005951145431026816\n",
|
134 |
+
"71: loss: 0.00037293724017217755\n",
|
135 |
+
"72: loss: 0.00025969729176722467\n",
|
136 |
+
"73: loss: 7.013213157653809\n",
|
137 |
+
"74: loss: 3.807203531265259\n",
|
138 |
+
"75: loss: 0.00026780215557664633\n",
|
139 |
+
"76: loss: 0.00031897667213343084\n",
|
140 |
+
"77: loss: 0.0003657388442661613\n",
|
141 |
+
"78: loss: 5.076975345611572\n",
|
142 |
+
"79: loss: 0.001055362867191434\n",
|
143 |
+
"80: loss: 0.0010116726625710726\n",
|
144 |
+
"81: loss: 0.0017484871204942465\n",
|
145 |
+
"82: loss: 0.0018696936313062906\n",
|
146 |
+
"83: loss: 5.30266809463501\n",
|
147 |
+
"84: loss: 5.457505226135254\n",
|
148 |
+
"85: loss: 0.0012204349040985107\n",
|
149 |
+
"86: loss: 3.2936503887176514\n",
|
150 |
+
"87: loss: 0.0020471797324717045\n",
|
151 |
+
"88: loss: 0.0026046710554510355\n",
|
152 |
+
"89: loss: 0.0026721167378127575\n",
|
153 |
+
"90: loss: 0.0024667021352797747\n",
|
154 |
+
"91: loss: 5.0201215744018555\n",
|
155 |
+
"92: loss: 4.591504096984863\n",
|
156 |
+
"93: loss: 0.0025711969938129187\n",
|
157 |
+
"94: loss: 0.002706416416913271\n",
|
158 |
+
"95: loss: 0.0024713831953704357\n",
|
159 |
+
"96: loss: 0.002004373585805297\n",
|
160 |
+
"97: loss: 0.001489203074015677\n",
|
161 |
+
"98: loss: 0.0010426173685118556\n",
|
162 |
+
"99: loss: 8.796974182128906\n",
|
163 |
+
"100: loss: 0.0005365900578908622\n",
|
164 |
+
"100: valid loss 5.255128860473633\n",
|
165 |
+
"101: loss: 0.0004417159070726484\n",
|
166 |
+
"102: loss: 4.595282554626465\n",
|
167 |
+
"103: loss: 0.000659952696878463\n",
|
168 |
+
"104: loss: 0.0008260122267529368\n",
|
169 |
+
"105: loss: 0.0009083786280825734\n",
|
170 |
+
"106: loss: 4.042155742645264\n",
|
171 |
+
"107: loss: 4.17121696472168\n",
|
172 |
+
"108: loss: 0.0007671767962165177\n",
|
173 |
+
"109: loss: 4.022541522979736\n",
|
174 |
+
"110: loss: 3.5455234050750732\n",
|
175 |
+
"111: loss: 0.001035561435855925\n",
|
176 |
+
"112: loss: 0.0012967187212780118\n",
|
177 |
+
"113: loss: 7.237168312072754\n",
|
178 |
+
"114: loss: 3.522667407989502\n",
|
179 |
+
"115: loss: 0.004003542009741068\n",
|
180 |
+
"116: loss: 0.0040553268045187\n",
|
181 |
+
"117: loss: 0.0029700316954404116\n",
|
182 |
+
"118: loss: 0.0019125432008877397\n",
|
183 |
+
"119: loss: 3.4947195053100586\n",
|
184 |
+
"120: loss: 0.001095975050702691\n",
|
185 |
+
"121: loss: 0.0009612821158953011\n",
|
186 |
+
"122: loss: 0.000824352668132633\n",
|
187 |
+
"123: loss: 3.3077425956726074\n",
|
188 |
+
"124: loss: 0.0007418203167617321\n",
|
189 |
+
"125: loss: 0.0007488489500246942\n",
|
190 |
+
"126: loss: 0.0007235489320009947\n",
|
191 |
+
"127: loss: 3.426555633544922\n",
|
192 |
+
"128: loss: 0.0006980476318858564\n",
|
193 |
+
"129: loss: 0.0006986281368881464\n",
|
194 |
+
"130: loss: 0.0006706370622850955\n",
|
195 |
+
"131: loss: 0.0006185953388921916\n",
|
196 |
+
"132: loss: 4.421964645385742\n",
|
197 |
+
"133: loss: 0.0006264401017688215\n",
|
198 |
+
"134: loss: 0.0006876335828565061\n",
|
199 |
+
"135: loss: 0.0007215599762275815\n",
|
200 |
+
"136: loss: 0.0007203654968179762\n",
|
201 |
+
"137: loss: 0.0006922150496393442\n",
|
202 |
+
"138: loss: 0.0006356032681651413\n",
|
203 |
+
"139: loss: 3.7695367336273193\n",
|
204 |
+
"140: loss: 0.0006305422284640372\n",
|
205 |
+
"141: loss: 0.0006744156708009541\n",
|
206 |
+
"142: loss: 0.0006895355763845146\n",
|
207 |
+
"143: loss: 3.770907402038574\n",
|
208 |
+
"144: loss: 0.000908360059838742\n",
|
209 |
+
"145: loss: 0.0011299465550109744\n",
|
210 |
+
"146: loss: 0.0012696339981630445\n",
|
211 |
+
"147: loss: 0.0012722468236461282\n",
|
212 |
+
"148: loss: 3.8808021545410156\n",
|
213 |
+
"149: loss: 3.783026695251465\n",
|
214 |
+
"150: loss: 0.002035590121522546\n",
|
215 |
+
"151: loss: 0.0026034933980554342\n",
|
216 |
+
"152: loss: 0.0024936539120972157\n",
|
217 |
+
"153: loss: 0.0018582777120172977\n",
|
218 |
+
"154: loss: 2.8572535514831543\n",
|
219 |
+
"155: loss: 0.001062657218426466\n",
|
220 |
+
"156: loss: 0.0008821044466458261\n",
|
221 |
+
"157: loss: 0.0007058316841721535\n",
|
222 |
+
"158: loss: 0.0005539683043025434\n",
|
223 |
+
"159: loss: 5.476413726806641\n",
|
224 |
+
"160: loss: 0.00043070572428405285\n",
|
225 |
+
"161: loss: 0.00042034301441162825\n",
|
226 |
+
"162: loss: 0.0004015824815724045\n",
|
227 |
+
"163: loss: 0.0003759717510547489\n",
|
228 |
+
"164: loss: 0.00034577338374219835\n",
|
229 |
+
"165: loss: 3.9209775924682617\n",
|
230 |
+
"166: loss: 0.0003425567992962897\n",
|
231 |
+
"167: loss: 0.00036322552477940917\n",
|
232 |
+
"168: loss: 0.00037287475424818695\n",
|
233 |
+
"169: loss: 7.9045209884643555\n",
|
234 |
+
"170: loss: 0.0004473228473216295\n",
|
235 |
+
"171: loss: 0.0005134259699843824\n",
|
236 |
+
"172: loss: 2.9501657485961914\n",
|
237 |
+
"173: loss: 0.0008285943185910583\n",
|
238 |
+
"174: loss: 0.00113466486800462\n",
|
239 |
+
"175: loss: 0.0013167448341846466\n",
|
240 |
+
"176: loss: 0.0014080735854804516\n",
|
241 |
+
"177: loss: 4.0473408699035645\n",
|
242 |
+
"178: loss: 0.0016744763124734163\n",
|
243 |
+
"179: loss: 0.0016492144204676151\n",
|
244 |
+
"180: loss: 4.165207386016846\n",
|
245 |
+
"181: loss: 0.0017677460564300418\n",
|
246 |
+
"182: loss: 0.0018474040552973747\n",
|
247 |
+
"183: loss: 0.0017496442887932062\n",
|
248 |
+
"184: loss: 3.3882288932800293\n",
|
249 |
+
"185: loss: 0.0018872346263378859\n",
|
250 |
+
"186: loss: 3.0333187580108643\n",
|
251 |
+
"187: loss: 0.0028638774529099464\n",
|
252 |
+
"188: loss: 3.709534168243408\n",
|
253 |
+
"189: loss: 6.904417991638184\n",
|
254 |
+
"190: loss: 0.006619434338063002\n",
|
255 |
+
"191: loss: 0.00595641415566206\n",
|
256 |
+
"192: loss: 5.050801753997803\n",
|
257 |
+
"193: loss: 3.7556490898132324\n",
|
258 |
+
"194: loss: 0.002467694692313671\n",
|
259 |
+
"195: loss: 0.002025420544669032\n",
|
260 |
+
"196: loss: 0.001494809053838253\n",
|
261 |
+
"197: loss: 0.0010330628138035536\n",
|
262 |
+
"198: loss: 0.0006917425780557096\n",
|
263 |
+
"199: loss: 0.0004644835426006466\n",
|
264 |
+
"200: loss: 4.2029547691345215\n",
|
265 |
+
"200: valid loss 0.00025821171584539115\n",
|
266 |
+
"201: loss: 4.2771501541137695\n",
|
267 |
+
"202: loss: 3.7102839946746826\n",
|
268 |
+
"203: loss: 3.7408058643341064\n",
|
269 |
+
"204: loss: 0.0003981325135100633\n",
|
270 |
+
"205: loss: 0.0005507581518031657\n",
|
271 |
+
"206: loss: 4.065332889556885\n",
|
272 |
+
"207: loss: 0.0011804178357124329\n",
|
273 |
+
"208: loss: 0.0017080714460462332\n",
|
274 |
+
"209: loss: 0.0021062048617750406\n",
|
275 |
+
"210: loss: 0.0021494715474545956\n",
|
276 |
+
"211: loss: 6.465389251708984\n",
|
277 |
+
"212: loss: 0.0029505854472517967\n",
|
278 |
+
"213: loss: 3.367213010787964\n",
|
279 |
+
"214: loss: 0.27502918243408203\n",
|
280 |
+
"215: loss: 0.9933775663375854\n",
|
281 |
+
"216: loss: 0.810478925704956\n",
|
282 |
+
"217: loss: 0.4562891721725464\n",
|
283 |
+
"218: loss: 0.24387648701667786\n",
|
284 |
+
"219: loss: 0.11290910840034485\n",
|
285 |
+
"220: loss: 0.019248925149440765\n",
|
286 |
+
"221: loss: 0.0021138046868145466\n",
|
287 |
+
"222: loss: 5.169565200805664\n",
|
288 |
+
"223: loss: 0.0008601757581345737\n",
|
289 |
+
"224: loss: 3.9269232749938965\n",
|
290 |
+
"225: loss: 0.0007863161154091358\n",
|
291 |
+
"226: loss: 0.00024547570501454175\n",
|
292 |
+
"227: loss: 4.449281215667725\n",
|
293 |
+
"228: loss: 0.00019524114031810313\n",
|
294 |
+
"229: loss: 5.162830829620361\n",
|
295 |
+
"230: loss: 0.0005567128537222743\n",
|
296 |
+
"231: loss: 4.195521831512451\n",
|
297 |
+
"232: loss: 3.7389187812805176\n",
|
298 |
+
"233: loss: 5.919421672821045\n",
|
299 |
+
"234: loss: 6.7034173011779785\n",
|
300 |
+
"235: loss: 5.353506088256836\n",
|
301 |
+
"236: loss: 2.4018566608428955\n",
|
302 |
+
"237: loss: 3.7457311153411865\n",
|
303 |
+
"238: loss: 0.17652225494384766\n",
|
304 |
+
"239: loss: 4.564880847930908\n",
|
305 |
+
"240: loss: 0.027039170265197754\n",
|
306 |
+
"241: loss: 0.005270962603390217\n",
|
307 |
+
"242: loss: 0.0015485308831557631\n",
|
308 |
+
"243: loss: 0.0010360399028286338\n",
|
309 |
+
"244: loss: 0.0007773903198540211\n",
|
310 |
+
"245: loss: 6.206174850463867\n",
|
311 |
+
"246: loss: 6.409456253051758\n",
|
312 |
+
"247: loss: 0.04051050543785095\n",
|
313 |
+
"248: loss: 0.0017684113699942827\n",
|
314 |
+
"249: loss: 0.00044090740266256034\n",
|
315 |
+
"250: loss: 5.761023044586182\n",
|
316 |
+
"251: loss: 0.00016311556100845337\n",
|
317 |
+
"252: loss: 0.0001715785765554756\n",
|
318 |
+
"253: loss: 0.00019523760420270264\n",
|
319 |
+
"254: loss: 0.00023307953961193562\n",
|
320 |
+
"255: loss: 0.00028373271925374866\n",
|
321 |
+
"256: loss: 4.927147388458252\n",
|
322 |
+
"257: loss: 4.228280544281006\n",
|
323 |
+
"258: loss: 0.0011933923233300447\n",
|
324 |
+
"259: loss: 0.005215882323682308\n",
|
325 |
+
"260: loss: 0.0013388781808316708\n",
|
326 |
+
"261: loss: 4.206026554107666\n",
|
327 |
+
"262: loss: 0.0034830207005143166\n",
|
328 |
+
"263: loss: 4.173500061035156\n",
|
329 |
+
"264: loss: 0.007450783159583807\n",
|
330 |
+
"265: loss: 4.5892510414123535\n",
|
331 |
+
"266: loss: 0.006880312692373991\n",
|
332 |
+
"267: loss: 4.572935104370117\n",
|
333 |
+
"268: loss: 0.002904222346842289\n",
|
334 |
+
"269: loss: 3.2348222732543945\n",
|
335 |
+
"270: loss: 4.376621723175049\n",
|
336 |
+
"271: loss: 3.573988914489746\n",
|
337 |
+
"272: loss: 0.0010127610294148326\n",
|
338 |
+
"273: loss: 9.308874130249023\n",
|
339 |
+
"274: loss: 4.688360214233398\n",
|
340 |
+
"275: loss: 3.9581832885742188\n",
|
341 |
+
"276: loss: 0.01065391581505537\n",
|
342 |
+
"277: loss: 0.0067514535039663315\n",
|
343 |
+
"278: loss: 0.003611961379647255\n",
|
344 |
+
"279: loss: 0.001811509020626545\n",
|
345 |
+
"280: loss: 0.0009013370145112276\n",
|
346 |
+
"281: loss: 4.266546726226807\n",
|
347 |
+
"282: loss: 5.132745742797852\n",
|
348 |
+
"283: loss: 0.000957090116571635\n",
|
349 |
+
"284: loss: 0.0015025322791188955\n",
|
350 |
+
"285: loss: 6.258731842041016\n",
|
351 |
+
"286: loss: 5.029386043548584\n",
|
352 |
+
"287: loss: 0.007954631000757217\n",
|
353 |
+
"288: loss: 0.0050008054822683334\n",
|
354 |
+
"289: loss: 0.001655810745432973\n",
|
355 |
+
"290: loss: 5.501289367675781\n",
|
356 |
+
"291: loss: 4.655749797821045\n",
|
357 |
+
"292: loss: 4.383106231689453\n",
|
358 |
+
"293: loss: 0.000304496381431818\n",
|
359 |
+
"294: loss: 0.0003326725563965738\n",
|
360 |
+
"295: loss: 0.00035310350358486176\n",
|
361 |
+
"296: loss: 5.683162212371826\n",
|
362 |
+
"297: loss: 0.0004622728156391531\n",
|
363 |
+
"298: loss: 4.067113399505615\n",
|
364 |
+
"299: loss: 0.0008154112147167325\n",
|
365 |
+
"300: loss: 0.00108420941978693\n",
|
366 |
+
"300: valid loss 0.0013179074740037322\n",
|
367 |
+
"301: loss: 0.0013179074740037322\n",
|
368 |
+
"302: loss: 4.358561992645264\n",
|
369 |
+
"303: loss: 5.026749610900879\n",
|
370 |
+
"304: loss: 0.002862808993086219\n",
|
371 |
+
"305: loss: 0.003396229352802038\n",
|
372 |
+
"306: loss: 5.530904293060303\n",
|
373 |
+
"307: loss: 0.0035779180470854044\n",
|
374 |
+
"308: loss: 0.003205555956810713\n",
|
375 |
+
"309: loss: 4.112671852111816\n",
|
376 |
+
"310: loss: 3.6920313835144043\n",
|
377 |
+
"311: loss: 0.0026951604522764683\n",
|
378 |
+
"312: loss: 0.0026851999573409557\n",
|
379 |
+
"313: loss: 3.3092551231384277\n",
|
380 |
+
"314: loss: 0.0024079573340713978\n",
|
381 |
+
"315: loss: 0.0022026696242392063\n",
|
382 |
+
"316: loss: 0.0018284200923517346\n",
|
383 |
+
"317: loss: 0.0014258958399295807\n",
|
384 |
+
"318: loss: 0.0010761057492345572\n",
|
385 |
+
"319: loss: 0.0008039181702770293\n",
|
386 |
+
"320: loss: 0.0006038622814230621\n",
|
387 |
+
"321: loss: 0.00046244796249084175\n",
|
388 |
+
"322: loss: 5.89370059967041\n",
|
389 |
+
"323: loss: 0.00031747910543344915\n",
|
390 |
+
"324: loss: 0.00028221303364261985\n",
|
391 |
+
"325: loss: 0.00025451104738749564\n",
|
392 |
+
"326: loss: 0.00023175252135843039\n",
|
393 |
+
"327: loss: 0.00021364034910220653\n",
|
394 |
+
"328: loss: 3.906613826751709\n",
|
395 |
+
"329: loss: 3.844726085662842\n",
|
396 |
+
"330: loss: 0.00023705456987954676\n",
|
397 |
+
"331: loss: 0.0002663657069206238\n",
|
398 |
+
"332: loss: 0.0002947220054920763\n",
|
399 |
+
"333: loss: 6.28004264831543\n",
|
400 |
+
"334: loss: 0.0003821635036729276\n",
|
401 |
+
"335: loss: 3.633335828781128\n",
|
402 |
+
"336: loss: 0.0005681345355696976\n",
|
403 |
+
"337: loss: 6.994467735290527\n",
|
404 |
+
"338: loss: 7.915759086608887\n",
|
405 |
+
"339: loss: 0.0026061832904815674\n",
|
406 |
+
"340: loss: 0.0048998151905834675\n",
|
407 |
+
"341: loss: 0.004243680741637945\n",
|
408 |
+
"342: loss: 0.0025005636271089315\n",
|
409 |
+
"343: loss: 4.005818843841553\n",
|
410 |
+
"344: loss: 0.0011636920971795917\n",
|
411 |
+
"345: loss: 0.0009634271846152842\n",
|
412 |
+
"346: loss: 0.0008427661377936602\n",
|
413 |
+
"347: loss: 0.0007607618463225663\n",
|
414 |
+
"348: loss: 0.0006956492434255779\n",
|
415 |
+
"349: loss: 4.547393798828125\n",
|
416 |
+
"350: loss: 0.0006480301963165402\n",
|
417 |
+
"351: loss: 0.0006520788883790374\n",
|
418 |
+
"352: loss: 0.0006446384941227734\n",
|
419 |
+
"353: loss: 4.283820629119873\n",
|
420 |
+
"354: loss: 0.0007140468223951757\n",
|
421 |
+
"355: loss: 0.000788742327131331\n",
|
422 |
+
"356: loss: 0.0008332571596838534\n",
|
423 |
+
"357: loss: 0.0008390303701162338\n",
|
424 |
+
"358: loss: 0.000806896947324276\n",
|
425 |
+
"359: loss: 4.646646976470947\n",
|
426 |
+
"360: loss: 0.0021708165295422077\n",
|
427 |
+
"361: loss: 0.0009108624653890729\n",
|
428 |
+
"362: loss: 3.9582133293151855\n",
|
429 |
+
"363: loss: 3.3569955825805664\n",
|
430 |
+
"364: loss: 0.002499263733625412\n",
|
431 |
+
"365: loss: 4.646510601043701\n",
|
432 |
+
"366: loss: 0.0032457842025905848\n",
|
433 |
+
"367: loss: 0.0033331059385091066\n",
|
434 |
+
"368: loss: 0.00275675137527287\n",
|
435 |
+
"369: loss: 0.0020243506878614426\n",
|
436 |
+
"370: loss: 4.458893775939941\n",
|
437 |
+
"371: loss: 5.930361270904541\n",
|
438 |
+
"372: loss: 4.287806510925293\n",
|
439 |
+
"373: loss: 3.365216016769409\n",
|
440 |
+
"374: loss: 0.011499284766614437\n",
|
441 |
+
"375: loss: 0.0031067240051925182\n",
|
442 |
+
"376: loss: 0.003569819498807192\n",
|
443 |
+
"377: loss: 0.0032246895134449005\n",
|
444 |
+
"378: loss: 0.0023426800034940243\n",
|
445 |
+
"379: loss: 0.0016774036921560764\n",
|
446 |
+
"380: loss: 0.0010665183654055\n",
|
447 |
+
"381: loss: 0.0007539619691669941\n",
|
448 |
+
"382: loss: 3.873556137084961\n",
|
449 |
+
"383: loss: 0.08063449710607529\n",
|
450 |
+
"384: loss: 0.0005400768714025617\n",
|
451 |
+
"385: loss: 0.000518861401360482\n",
|
452 |
+
"386: loss: 0.00048329788842238486\n",
|
453 |
+
"387: loss: 4.2107648849487305\n",
|
454 |
+
"388: loss: 4.465734481811523\n",
|
455 |
+
"389: loss: 0.000529197626747191\n",
|
456 |
+
"390: loss: 3.872891664505005\n",
|
457 |
+
"391: loss: 5.214785099029541\n",
|
458 |
+
"392: loss: 4.345657825469971\n",
|
459 |
+
"393: loss: 0.0016826370265334845\n",
|
460 |
+
"394: loss: 0.0024580529425293207\n",
|
461 |
+
"395: loss: 0.002994671929627657\n",
|
462 |
+
"396: loss: 0.002981696743518114\n",
|
463 |
+
"397: loss: 0.002537172520533204\n",
|
464 |
+
"398: loss: 0.001975367311388254\n",
|
465 |
+
"399: loss: 0.0014994062948971987\n",
|
466 |
+
"400: loss: 0.0011500928085297346\n",
|
467 |
+
"400: valid loss 0.0009022268350236118\n",
|
468 |
+
"401: loss: 5.212808132171631\n",
|
469 |
+
"402: loss: 0.0008533270447514951\n",
|
470 |
+
"403: loss: 0.0008498210809193552\n",
|
471 |
+
"404: loss: 0.0008541711140424013\n",
|
472 |
+
"405: loss: 3.912627696990967\n",
|
473 |
+
"406: loss: 0.0008917151135392487\n",
|
474 |
+
"407: loss: 0.0009278871002607048\n",
|
475 |
+
"408: loss: 3.4623196125030518\n",
|
476 |
+
"409: loss: 0.0011483340058475733\n",
|
477 |
+
"410: loss: 0.0014651089441031218\n",
|
478 |
+
"411: loss: 3.501060962677002\n",
|
479 |
+
"412: loss: 4.905694484710693\n",
|
480 |
+
"413: loss: 0.0025538327172398567\n",
|
481 |
+
"414: loss: 0.0019650040194392204\n",
|
482 |
+
"415: loss: 0.001453581964597106\n",
|
483 |
+
"416: loss: 4.282127857208252\n",
|
484 |
+
"417: loss: 0.001117513864301145\n",
|
485 |
+
"418: loss: 3.2745401859283447\n",
|
486 |
+
"419: loss: 3.0665171146392822\n",
|
487 |
+
"420: loss: 0.001583368401043117\n",
|
488 |
+
"421: loss: 0.0018978181760758162\n",
|
489 |
+
"422: loss: 5.070369720458984\n",
|
490 |
+
"423: loss: 0.0025998111814260483\n",
|
491 |
+
"424: loss: 0.0028609540313482285\n",
|
492 |
+
"425: loss: 2.7316229343414307\n",
|
493 |
+
"426: loss: 0.003324385266751051\n",
|
494 |
+
"427: loss: 0.00243724649772048\n",
|
495 |
+
"428: loss: 0.0020084292627871037\n",
|
496 |
+
"429: loss: 0.001639676047489047\n",
|
497 |
+
"430: loss: 0.0012756038922816515\n",
|
498 |
+
"431: loss: 0.0010202551493421197\n",
|
499 |
+
"432: loss: 0.0008382818195968866\n",
|
500 |
+
"433: loss: 3.9101459980010986\n",
|
501 |
+
"434: loss: 3.4464950561523438\n",
|
502 |
+
"435: loss: 4.598957538604736\n",
|
503 |
+
"436: loss: 6.656869888305664\n",
|
504 |
+
"437: loss: 2.557544469833374\n",
|
505 |
+
"438: loss: 1.769715666770935\n",
|
506 |
+
"439: loss: 0.8786362409591675\n",
|
507 |
+
"440: loss: 0.09529905021190643\n",
|
508 |
+
"441: loss: 3.9526867866516113\n",
|
509 |
+
"442: loss: 3.4567954540252686\n",
|
510 |
+
"443: loss: 0.28547608852386475\n",
|
511 |
+
"444: loss: 0.1331639289855957\n",
|
512 |
+
"445: loss: 0.01748904585838318\n",
|
513 |
+
"446: loss: 3.7364015579223633\n",
|
514 |
+
"447: loss: 1.6454107761383057\n",
|
515 |
+
"448: loss: 0.007931341417133808\n",
|
516 |
+
"449: loss: 0.0017749288817867637\n",
|
517 |
+
"450: loss: 3.6518070697784424\n",
|
518 |
+
"451: loss: 3.056483507156372\n",
|
519 |
+
"452: loss: 0.0008364453678950667\n",
|
520 |
+
"453: loss: 0.0009152528364211321\n",
|
521 |
+
"454: loss: 0.0009797721868380904\n",
|
522 |
+
"455: loss: 4.194733142852783\n",
|
523 |
+
"456: loss: 0.0013897174503654242\n",
|
524 |
+
"457: loss: 0.0018761098617687821\n",
|
525 |
+
"458: loss: 0.0020015202462673187\n",
|
526 |
+
"459: loss: 9.263550758361816\n",
|
527 |
+
"460: loss: 0.0025061527267098427\n",
|
528 |
+
"461: loss: 0.003998400643467903\n",
|
529 |
+
"462: loss: 0.0031979954801499844\n",
|
530 |
+
"463: loss: 0.0009064731420949101\n",
|
531 |
+
"464: loss: 3.1668450832366943\n",
|
532 |
+
"465: loss: 6.006053924560547\n",
|
533 |
+
"466: loss: 0.0006406777538359165\n",
|
534 |
+
"467: loss: 0.0009267539135180414\n",
|
535 |
+
"468: loss: 0.0012060123262926936\n",
|
536 |
+
"469: loss: 0.0013315295800566673\n",
|
537 |
+
"470: loss: 3.5539376735687256\n",
|
538 |
+
"471: loss: 3.4590916633605957\n",
|
539 |
+
"472: loss: 0.0017678193980827928\n",
|
540 |
+
"473: loss: 0.00218581547960639\n",
|
541 |
+
"474: loss: 0.0025737383402884007\n",
|
542 |
+
"475: loss: 2.97592830657959\n",
|
543 |
+
"476: loss: 0.0032222135923802853\n",
|
544 |
+
"477: loss: 0.0020487091969698668\n",
|
545 |
+
"478: loss: 3.0420033931732178\n",
|
546 |
+
"479: loss: 0.001554043497890234\n",
|
547 |
+
"480: loss: 0.001528518507257104\n",
|
548 |
+
"481: loss: 0.001422215485945344\n",
|
549 |
+
"482: loss: 0.0012641653884202242\n",
|
550 |
+
"483: loss: 0.0010866222437471151\n",
|
551 |
+
"484: loss: 7.149199962615967\n",
|
552 |
+
"485: loss: 0.0010687584290280938\n",
|
553 |
+
"486: loss: 0.0012197017204016447\n",
|
554 |
+
"487: loss: 0.001343191834166646\n",
|
555 |
+
"488: loss: 0.0013996028574183583\n",
|
556 |
+
"489: loss: 0.001371717662550509\n",
|
557 |
+
"490: loss: 3.68569278717041\n",
|
558 |
+
"491: loss: 0.0014253916451707482\n",
|
559 |
+
"492: loss: 0.001504680491052568\n",
|
560 |
+
"493: loss: 0.0014929386088624597\n",
|
561 |
+
"494: loss: 0.0013759569264948368\n",
|
562 |
+
"495: loss: 3.385620355606079\n",
|
563 |
+
"496: loss: 0.0012212302535772324\n",
|
564 |
+
"497: loss: 0.0011952322674915195\n",
|
565 |
+
"498: loss: 3.1083197593688965\n",
|
566 |
+
"499: loss: 8.146794319152832\n",
|
567 |
+
"500: loss: 3.8151681423187256\n",
|
568 |
+
"500: valid loss 3.2241313457489014\n",
|
569 |
+
"501: loss: 0.002565972041338682\n",
|
570 |
+
"502: loss: 4.1275224685668945\n",
|
571 |
+
"503: loss: 0.004586916882544756\n",
|
572 |
+
"504: loss: 3.6200292110443115\n",
|
573 |
+
"505: loss: 0.004917770624160767\n",
|
574 |
+
"506: loss: 0.0035543786361813545\n",
|
575 |
+
"507: loss: 0.002198878675699234\n",
|
576 |
+
"508: loss: 3.9696688652038574\n",
|
577 |
+
"509: loss: 0.0012150105321779847\n",
|
578 |
+
"510: loss: 3.0237858295440674\n",
|
579 |
+
"511: loss: 0.0016711285570636392\n",
|
580 |
+
"512: loss: 0.0017911652103066444\n",
|
581 |
+
"513: loss: 0.001645330572500825\n",
|
582 |
+
"514: loss: 3.3689823150634766\n",
|
583 |
+
"515: loss: 0.0014145843451842666\n",
|
584 |
+
"516: loss: 0.0013438486494123936\n",
|
585 |
+
"517: loss: 0.0011701782932505012\n",
|
586 |
+
"518: loss: 0.0009688445716165006\n",
|
587 |
+
"519: loss: 0.0007915324531495571\n",
|
588 |
+
"520: loss: 4.113221645355225\n",
|
589 |
+
"521: loss: 0.0006360645638778806\n",
|
590 |
+
"522: loss: 0.0006149905384518206\n",
|
591 |
+
"523: loss: 8.360527038574219\n",
|
592 |
+
"524: loss: 0.0006234433385543525\n",
|
593 |
+
"525: loss: 0.0006739232921972871\n",
|
594 |
+
"526: loss: 0.0007281479192897677\n",
|
595 |
+
"527: loss: 0.000767726160120219\n",
|
596 |
+
"528: loss: 0.000772368221078068\n",
|
597 |
+
"529: loss: 0.0007228502072393894\n",
|
598 |
+
"530: loss: 0.0006368369213305414\n",
|
599 |
+
"531: loss: 3.732311725616455\n",
|
600 |
+
"532: loss: 5.932078838348389\n",
|
601 |
+
"533: loss: 3.5892159938812256\n",
|
602 |
+
"534: loss: 5.249965667724609\n",
|
603 |
+
"535: loss: 7.211183071136475\n",
|
604 |
+
"536: loss: 4.0714263916015625\n",
|
605 |
+
"537: loss: 3.1499719619750977\n",
|
606 |
+
"538: loss: 0.1844794750213623\n",
|
607 |
+
"539: loss: 3.4192230701446533\n",
|
608 |
+
"540: loss: 0.011980107054114342\n",
|
609 |
+
"541: loss: 0.010612019337713718\n",
|
610 |
+
"542: loss: 0.0045662750490009785\n",
|
611 |
+
"543: loss: 0.005457601509988308\n",
|
612 |
+
"544: loss: 0.015783555805683136\n",
|
613 |
+
"545: loss: 0.0013816619757562876\n",
|
614 |
+
"546: loss: 8.18481731414795\n",
|
615 |
+
"547: loss: 0.0006438567652367055\n",
|
616 |
+
"548: loss: 0.000572906865272671\n",
|
617 |
+
"549: loss: 10.10994815826416\n",
|
618 |
+
"550: loss: 0.003346000798046589\n",
|
619 |
+
"551: loss: 0.0006713962065987289\n",
|
620 |
+
"552: loss: 0.00026078836526721716\n",
|
621 |
+
"553: loss: 11.756505012512207\n",
|
622 |
+
"554: loss: 7.101832389831543\n",
|
623 |
+
"555: loss: 0.00021459207346197218\n",
|
624 |
+
"556: loss: 0.00025998923229053617\n",
|
625 |
+
"557: loss: 0.0003112201811745763\n",
|
626 |
+
"558: loss: 14.851192474365234\n",
|
627 |
+
"559: loss: 0.0004224810691084713\n",
|
628 |
+
"560: loss: 0.00047494613681919873\n",
|
629 |
+
"561: loss: 0.000519308028742671\n",
|
630 |
+
"562: loss: 0.0005509845213964581\n",
|
631 |
+
"563: loss: 0.0005668219528160989\n",
|
632 |
+
"564: loss: 14.569344520568848\n",
|
633 |
+
"565: loss: 6.4913740158081055\n",
|
634 |
+
"566: loss: 0.0008433411712758243\n",
|
635 |
+
"567: loss: 8.495502471923828\n",
|
636 |
+
"568: loss: 0.0019402098841965199\n",
|
637 |
+
"569: loss: 0.0035519124940037727\n",
|
638 |
+
"570: loss: 0.006841914728283882\n",
|
639 |
+
"571: loss: 4.089066982269287\n",
|
640 |
+
"572: loss: 5.491721153259277\n",
|
641 |
+
"573: loss: 3.87937331199646\n",
|
642 |
+
"574: loss: 0.03460773825645447\n",
|
643 |
+
"575: loss: 0.015647828578948975\n",
|
644 |
+
"576: loss: 0.002720448188483715\n",
|
645 |
+
"577: loss: 6.188972473144531\n",
|
646 |
+
"578: loss: 0.0008381525985896587\n",
|
647 |
+
"579: loss: 0.0008579537970945239\n",
|
648 |
+
"580: loss: 0.0008331844583153725\n",
|
649 |
+
"581: loss: 7.444668769836426\n",
|
650 |
+
"582: loss: 0.0013645365834236145\n",
|
651 |
+
"583: loss: 0.0018909723730757833\n",
|
652 |
+
"584: loss: 4.148159503936768\n",
|
653 |
+
"585: loss: 6.465692043304443\n",
|
654 |
+
"586: loss: 0.0040971520356833935\n",
|
655 |
+
"587: loss: 0.015496809035539627\n",
|
656 |
+
"588: loss: 0.0011185817420482635\n",
|
657 |
+
"589: loss: 0.00048535081441514194\n",
|
658 |
+
"590: loss: 0.0002821610542014241\n",
|
659 |
+
"591: loss: 0.00022055530280340463\n",
|
660 |
+
"592: loss: 0.0002070294285658747\n",
|
661 |
+
"593: loss: 0.00021876658138353378\n",
|
662 |
+
"594: loss: 0.00024527875939384103\n",
|
663 |
+
"595: loss: 0.00028197691426612437\n",
|
664 |
+
"596: loss: 0.00031235843198373914\n",
|
665 |
+
"597: loss: 0.00032129406463354826\n",
|
666 |
+
"598: loss: 0.000305092049529776\n",
|
667 |
+
"599: loss: 6.581624507904053\n",
|
668 |
+
"600: loss: 0.0004181505355518311\n",
|
669 |
+
"600: valid loss 0.001562803634442389\n",
|
670 |
+
"601: loss: 0.001562803634442389\n",
|
671 |
+
"602: loss: 0.0008329854463227093\n",
|
672 |
+
"603: loss: 8.43118953704834\n",
|
673 |
+
"604: loss: 0.00018880203424487263\n",
|
674 |
+
"605: loss: 6.225329399108887\n",
|
675 |
+
"606: loss: 0.0001953585451701656\n",
|
676 |
+
"607: loss: 0.00031005332130007446\n",
|
677 |
+
"608: loss: 6.243394374847412\n",
|
678 |
+
"609: loss: 0.002007008297368884\n",
|
679 |
+
"610: loss: 0.2842656672000885\n",
|
680 |
+
"611: loss: 0.002102950122207403\n",
|
681 |
+
"612: loss: 0.0013235295191407204\n",
|
682 |
+
"613: loss: 0.0012432391522452235\n",
|
683 |
+
"614: loss: 0.0011076040100306273\n",
|
684 |
+
"615: loss: 0.0009366637095808983\n",
|
685 |
+
"616: loss: 0.0007713991799391806\n",
|
686 |
+
"617: loss: 0.0006266268319450319\n",
|
687 |
+
"618: loss: 0.0005072436179034412\n",
|
688 |
+
"619: loss: 0.00041213506483472884\n",
|
689 |
+
"620: loss: 0.0003370844351593405\n",
|
690 |
+
"621: loss: 0.0002783465606626123\n",
|
691 |
+
"622: loss: 6.750359535217285\n",
|
692 |
+
"623: loss: 4.032569408416748\n",
|
693 |
+
"624: loss: 4.749107360839844\n",
|
694 |
+
"625: loss: 5.599199295043945\n",
|
695 |
+
"626: loss: 4.851316452026367\n",
|
696 |
+
"627: loss: 0.0012356003280729055\n",
|
697 |
+
"628: loss: 0.0019876735750585794\n",
|
698 |
+
"629: loss: 0.0022025934886187315\n",
|
699 |
+
"630: loss: 0.09389199316501617\n",
|
700 |
+
"631: loss: 0.0011942394776269794\n",
|
701 |
+
"632: loss: 0.0008771757711656392\n",
|
702 |
+
"633: loss: 0.000724500569049269\n",
|
703 |
+
"634: loss: 4.850365161895752\n",
|
704 |
+
"635: loss: 6.96458101272583\n",
|
705 |
+
"636: loss: 3.944305658340454\n",
|
706 |
+
"637: loss: 1.573992133140564\n",
|
707 |
+
"638: loss: 0.006376080680638552\n",
|
708 |
+
"639: loss: 0.004621799103915691\n",
|
709 |
+
"640: loss: 0.008686978369951248\n",
|
710 |
+
"641: loss: 0.002786734839901328\n",
|
711 |
+
"642: loss: 0.0012673415476456285\n",
|
712 |
+
"643: loss: 0.0008905518334358931\n"
|
713 |
+
]
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"ename": "KeyboardInterrupt",
|
717 |
+
"evalue": "",
|
718 |
+
"output_type": "error",
|
719 |
+
"traceback": [
|
720 |
+
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
721 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
722 |
+
"Cell \u001b[1;32mIn[3], line 78\u001b[0m\n\u001b[0;32m 72\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m semantic_transformer, trainer, wav2vec\n\u001b[0;32m 73\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 78\u001b[0m \u001b[43mtrain_semantic_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
723 |
+
"Cell \u001b[1;32mIn[3], line 69\u001b[0m, in \u001b[0;36mtrain_semantic_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 52\u001b[0m semantic_transformer \u001b[38;5;241m=\u001b[39m SemanticTransformer(\n\u001b[0;32m 53\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 54\u001b[0m dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[0;32m 55\u001b[0m depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m6\u001b[39m,\n\u001b[0;32m 56\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 57\u001b[0m )\n\u001b[0;32m 59\u001b[0m trainer \u001b[38;5;241m=\u001b[39m SemanticTransformerTrainer(\n\u001b[0;32m 60\u001b[0m transformer\u001b[38;5;241m=\u001b[39msemantic_transformer,\n\u001b[0;32m 61\u001b[0m wav2vec\u001b[38;5;241m=\u001b[39mwav2vec,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 66\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 67\u001b[0m )\n\u001b[1;32m---> 69\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 70\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(semantic_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msemantic_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 71\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave semantic_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
724 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1000\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 997\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 999\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1000\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1001\u001b[0m log_fn(logs)\n\u001b[0;32m 1003\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
725 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:944\u001b[0m, in \u001b[0;36mSemanticTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 941\u001b[0m data_kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_tuple_to_kwargs(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdl_iter))\n\u001b[0;32m 943\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[1;32m--> 944\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_wrapper\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdata_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_loss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 946\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mbackward(loss \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every)\n\u001b[0;32m 948\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n",
|
726 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
727 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
728 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\audiolm_pytorch.py:1480\u001b[0m, in \u001b[0;36mSemanticTransformerWrapper.forward\u001b[1;34m(self, semantic_token_ids, raw_wave, text, text_embeds, return_loss, **kwargs)\u001b[0m\n\u001b[0;32m 1478\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(raw_wave)\n\u001b[0;32m 1479\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(text_embeds)\n\u001b[1;32m-> 1480\u001b[0m text_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio_conditioner\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mraw_wave\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnamespace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43msemantic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1482\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exists(semantic_token_ids):\n\u001b[0;32m 1483\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwav2vec), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mVQWav2Vec must be be provided if given raw wave for training\u001b[39m\u001b[38;5;124m'\u001b[39m\n",
|
729 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
730 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
731 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:872\u001b[0m, in \u001b[0;36mMuLaNEmbedQuantizer.forward\u001b[1;34m(self, wavs, texts, namespace)\u001b[0m\n\u001b[0;32m 869\u001b[0m \u001b[38;5;66;03m# sound and language live in joint embedding space because of contrastive learning\u001b[39;00m\n\u001b[0;32m 871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(wavs):\n\u001b[1;32m--> 872\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmulan\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_audio_latents\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 873\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m exists(texts):\n\u001b[0;32m 874\u001b[0m latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmulan\u001b[38;5;241m.\u001b[39mget_text_latents(texts)\n",
|
732 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:732\u001b[0m, in \u001b[0;36mMuLaN.get_audio_latents\u001b[1;34m(self, wavs, return_all_layers)\u001b[0m\n\u001b[0;32m 727\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_audio_latents\u001b[39m(\n\u001b[0;32m 728\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m 729\u001b[0m wavs,\n\u001b[0;32m 730\u001b[0m return_all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m 731\u001b[0m ):\n\u001b[1;32m--> 732\u001b[0m audio_embeds, audio_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maudio\u001b[49m\u001b[43m(\u001b[49m\u001b[43mwavs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 733\u001b[0m audio_latents \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maudio_to_latents(audio_embeds)\n\u001b[0;32m 734\u001b[0m out \u001b[38;5;241m=\u001b[39m l2norm(audio_latents)\n",
|
733 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
734 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
735 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:525\u001b[0m, in \u001b[0;36mAudioSpectrogramTransformer.forward\u001b[1;34m(self, x, force_no_patch_dropout, return_all_layers)\u001b[0m\n\u001b[0;32m 521\u001b[0m rel_pos_bias \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdynamic_pos_bias_mlp(rel_dist\u001b[38;5;241m.\u001b[39mfloat())\n\u001b[0;32m 523\u001b[0m \u001b[38;5;66;03m# attention, what else\u001b[39;00m\n\u001b[1;32m--> 525\u001b[0m x, all_layers \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mrel_pos_bias\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_all_layers\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[38;5;66;03m# final global average and norm (most recent papers show this is superior to CLS token)\u001b[39;00m\n\u001b[0;32m 529\u001b[0m x \u001b[38;5;241m=\u001b[39m reduce(x, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mb n d -> b d\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmean\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
736 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
737 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
738 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\musiclm_pytorch\\musiclm_pytorch.py:247\u001b[0m, in \u001b[0;36mTransformer.forward\u001b[1;34m(self, x, rel_pos_bias, mask, return_all_layers)\u001b[0m\n\u001b[0;32m 245\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m attn, ff \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[0;32m 246\u001b[0m x \u001b[38;5;241m=\u001b[39m attn(x, rel_pos_bias \u001b[38;5;241m=\u001b[39m rel_pos_bias, mask \u001b[38;5;241m=\u001b[39m mask) \u001b[38;5;241m+\u001b[39m x\n\u001b[1;32m--> 247\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[43mff\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m x\n\u001b[0;32m 248\u001b[0m layers\u001b[38;5;241m.\u001b[39mappend(x)\n\u001b[0;32m 250\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m return_all_layers:\n",
|
739 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
740 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
741 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\container.py:217\u001b[0m, in \u001b[0;36mSequential.forward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[0;32m 216\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[1;32m--> 217\u001b[0m \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 218\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
|
742 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
743 |
+
"File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
744 |
+
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
745 |
+
]
|
746 |
+
}
|
747 |
+
],
|
748 |
+
"source": [
|
749 |
+
"audio_transformer = AudioSpectrogramTransformer(\n",
|
750 |
+
" dim = 512,\n",
|
751 |
+
" depth = 6,\n",
|
752 |
+
" heads = 8,\n",
|
753 |
+
" dim_head = 64,\n",
|
754 |
+
" spec_n_fft = 128,\n",
|
755 |
+
" spec_win_length = 24,\n",
|
756 |
+
" spec_aug_stretch_factor = 0.8\n",
|
757 |
+
")\n",
|
758 |
+
"\n",
|
759 |
+
"text_transformer = TextTransformer(\n",
|
760 |
+
" dim = 512,\n",
|
761 |
+
" depth = 6,\n",
|
762 |
+
" heads = 8,\n",
|
763 |
+
" dim_head = 64\n",
|
764 |
+
")\n",
|
765 |
+
"\n",
|
766 |
+
"mulan = MuLaN(\n",
|
767 |
+
" audio_transformer = audio_transformer,\n",
|
768 |
+
" text_transformer = text_transformer\n",
|
769 |
+
")\n",
|
770 |
+
"\n",
|
771 |
+
"# setup the quantizer with the namespaced conditioning embeddings, unique per quantizer as well as namespace (per transformer)\n",
|
772 |
+
"\n",
|
773 |
+
"quantizer = MuLaNEmbedQuantizer(\n",
|
774 |
+
" mulan = mulan, # pass in trained mulan from above\n",
|
775 |
+
" conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024\n",
|
776 |
+
" namespaces = ('semantic', 'coarse', 'fine')\n",
|
777 |
+
")\n",
|
778 |
+
"\n",
|
779 |
+
"# now say you want the conditioning embeddings for semantic transformer\n",
|
780 |
+
"\n",
|
781 |
+
"wavs = torch.randn(2, 1024)\n",
|
782 |
+
"conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers\n",
|
783 |
+
"\n",
|
784 |
+
"# SemanticTransformer\n",
|
785 |
+
"def train_semantic_transformer():\n",
|
786 |
+
" wav2vec = HubertWithKmeans(\n",
|
787 |
+
" checkpoint_path=checkpoint_path,\n",
|
788 |
+
" kmeans_path=kmeans_path\n",
|
789 |
+
" )\n",
|
790 |
+
"\n",
|
791 |
+
"\n",
|
792 |
+
" if torch.cuda.is_available():\n",
|
793 |
+
" semantic_transformer = SemanticTransformer(\n",
|
794 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
795 |
+
" dim=1024,\n",
|
796 |
+
" depth=6,\n",
|
797 |
+
" audio_text_condition=True\n",
|
798 |
+
" ).cuda()\n",
|
799 |
+
" else:\n",
|
800 |
+
" semantic_transformer = SemanticTransformer(\n",
|
801 |
+
" num_semantic_tokens=wav2vec.codebook_size,\n",
|
802 |
+
" dim=1024,\n",
|
803 |
+
" depth=6,\n",
|
804 |
+
" audio_text_condition=True\n",
|
805 |
+
" )\n",
|
806 |
+
"\n",
|
807 |
+
" trainer = SemanticTransformerTrainer(\n",
|
808 |
+
" transformer=semantic_transformer,\n",
|
809 |
+
" wav2vec=wav2vec,\n",
|
810 |
+
" audio_conditioner=quantizer,\n",
|
811 |
+
" folder=audio_output_dir,\n",
|
812 |
+
" batch_size=batch_size,\n",
|
813 |
+
" data_max_length=data_max_length,\n",
|
814 |
+
" num_train_steps=num_train_steps\n",
|
815 |
+
" )\n",
|
816 |
+
"\n",
|
817 |
+
" trainer.train()\n",
|
818 |
+
" torch.save(semantic_transformer.state_dict(), 'semantic_transformer.pth')\n",
|
819 |
+
" print(\"save semantic_transformer.pth\")\n",
|
820 |
+
" del semantic_transformer, trainer, wav2vec\n",
|
821 |
+
" gc.collect()\n",
|
822 |
+
"\n",
|
823 |
+
"\n",
|
824 |
+
"\n",
|
825 |
+
"\n",
|
826 |
+
"train_semantic_transformer()"
|
827 |
+
]
|
828 |
+
}
|
829 |
+
],
|
830 |
+
"metadata": {
|
831 |
+
"kernelspec": {
|
832 |
+
"display_name": "myenv",
|
833 |
+
"language": "python",
|
834 |
+
"name": "python3"
|
835 |
+
},
|
836 |
+
"language_info": {
|
837 |
+
"codemirror_mode": {
|
838 |
+
"name": "ipython",
|
839 |
+
"version": 3
|
840 |
+
},
|
841 |
+
"file_extension": ".py",
|
842 |
+
"mimetype": "text/x-python",
|
843 |
+
"name": "python",
|
844 |
+
"nbconvert_exporter": "python",
|
845 |
+
"pygments_lexer": "ipython3",
|
846 |
+
"version": "3.11.2"
|
847 |
+
}
|
848 |
+
},
|
849 |
+
"nbformat": 4,
|
850 |
+
"nbformat_minor": 2
|
851 |
+
}
|