{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Coarse Transformer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Libraries:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from audiolm_pytorch import HubertWithKmeans\n", "from audiolm_pytorch import SemanticTransformer, SemanticTransformerTrainer\n", "from audiolm_pytorch import CoarseTransformer, CoarseTransformerTrainer\n", "from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerTrainer\n", "from audiolm_pytorch import AudioLMSoundStream, AudioLM, MusicLMSoundStream\n", "import gc\n", "from musiclm_pytorch import MuLaNEmbedQuantizer\n", "from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "checkpoint_path = './models/hubert/hubert_base_ls960.pt'\n", "kmeans_path = './models/hubert/hubert_base_ls960_L9_km500.bin'\n", "\n", "audio_output_dir = './audio'\n", "batch_size = 1\n", "data_max_length = 320 * 32\n", "num_train_steps = 1000" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "spectrogram yielded shape of (65, 86), but had to be cropped to (64, 80) to be patchified for transformer\n" ] } ], "source": [ "audio_transformer = AudioSpectrogramTransformer(\n", " dim = 512,\n", " depth = 6,\n", " heads = 8,\n", " dim_head = 64,\n", " spec_n_fft = 128,\n", " spec_win_length = 24,\n", " spec_aug_stretch_factor = 0.8\n", ")\n", "\n", "text_transformer = TextTransformer(\n", " dim = 512,\n", " depth = 6,\n", " heads = 8,\n", " dim_head = 64\n", ")\n", "\n", "mulan = MuLaN(\n", " audio_transformer = audio_transformer,\n", " text_transformer = text_transformer\n", ")\n", "\n", "quantizer = MuLaNEmbedQuantizer(\n", " mulan = mulan, \n", " conditioning_dims = (1024, 1024, 1024), \n", " namespaces = ('semantic', 'coarse', 'fine')\n", ")\n", "wavs = torch.randn(2, 1024)\n", "conds = quantizer(wavs = wavs, namespace = 'semantic')" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n", "ANTLR runtime and generated code versions disagree: 4.9.3!=4.8\n", "training with dataset of 4806 samples and validating with randomly splitted 253 samples\n", "0: loss: 90.55248260498047\n", "0: valid loss 28.765926361083984\n", "0: saving model to results\n", "1: loss: 39.71841812133789\n", "2: loss: 89.22168731689453\n", "3: loss: 64.72769927978516\n", "4: loss: 46.61131286621094\n", "5: loss: 71.61656951904297\n", "6: loss: 51.03081130981445\n", "7: loss: 41.790443420410156\n", "8: loss: 53.92983627319336\n", "9: loss: 34.468536376953125\n", "10: loss: 33.230533599853516\n", "11: loss: 39.82740020751953\n", "12: loss: 25.284324645996094\n", "13: loss: 28.97213363647461\n", "14: loss: 30.330350875854492\n", "15: loss: 29.048341751098633\n", "16: loss: 22.92132568359375\n", "17: loss: 19.784038543701172\n", "18: loss: 24.917173385620117\n", "19: loss: 21.861900329589844\n", "20: loss: 21.64893913269043\n", "21: loss: 19.426795959472656\n", "22: loss: 16.47875213623047\n", "23: loss: 14.150989532470703\n", "24: loss: 16.4312686920166\n", "25: loss: 10.732200622558594\n", "26: loss: 9.64625358581543\n", "27: loss: 13.40906047821045\n", "28: loss: 8.942117691040039\n", "29: loss: 14.944022178649902\n", "30: loss: 17.149667739868164\n", "31: loss: 8.965814590454102\n", "32: loss: 10.492903709411621\n", "33: loss: 11.236382484436035\n", "34: loss: 10.356119155883789\n", "35: loss: 9.816141128540039\n", "36: loss: 11.789191246032715\n", "37: loss: 10.450325012207031\n", "38: loss: 18.911396026611328\n", "39: loss: 8.278931617736816\n", "40: loss: 10.884782791137695\n", "41: loss: 8.885784149169922\n", "42: loss: 9.226049423217773\n", "43: loss: 10.362125396728516\n", "44: loss: 4.0845770835876465\n", "45: loss: 9.664544105529785\n", "46: loss: 9.46312427520752\n", "47: loss: 9.138323783874512\n", "48: loss: 7.396448135375977\n", "49: loss: 7.293612480163574\n", "50: loss: 10.331693649291992\n", "51: loss: 7.775559425354004\n", "52: loss: 7.011277198791504\n", "53: loss: 6.324047565460205\n", "54: loss: 5.501199245452881\n", "55: loss: 4.69442081451416\n", "56: loss: 4.073971748352051\n", "57: loss: 4.142904758453369\n", "58: loss: 4.585968017578125\n", "59: loss: 4.700481414794922\n", "60: loss: 5.152374267578125\n", "61: loss: 8.181085586547852\n", "62: loss: 6.7371416091918945\n", "63: loss: 10.67423152923584\n", "64: loss: 5.926950454711914\n", "65: loss: 5.470860004425049\n", "66: loss: 4.630016803741455\n", "67: loss: 5.366561412811279\n", "68: loss: 11.271105766296387\n", "69: loss: 6.516841411590576\n", "70: loss: 7.9438066482543945\n", "71: loss: 5.358776092529297\n", "72: loss: 5.713461875915527\n", "73: loss: 7.075550556182861\n", "74: loss: 5.229584217071533\n", "75: loss: 5.103419303894043\n", "76: loss: 4.516308307647705\n", "77: loss: 7.4682488441467285\n", "78: loss: 7.275866508483887\n", "79: loss: 5.846785545349121\n", "80: loss: 5.688624382019043\n", "81: loss: 5.150119781494141\n", "82: loss: 4.671944618225098\n", "83: loss: 8.293455123901367\n", "84: loss: 7.202897071838379\n", "85: loss: 4.38778018951416\n", "86: loss: 4.410329818725586\n", "87: loss: 4.341781139373779\n", "88: loss: 4.000961780548096\n", "89: loss: 4.009156703948975\n", "90: loss: 3.562082052230835\n", "91: loss: 3.641108989715576\n", "92: loss: 5.916473388671875\n", "93: loss: 4.046755790710449\n", "94: loss: 6.699942111968994\n", "95: loss: 6.139719009399414\n", "96: loss: 10.71791934967041\n", "97: loss: 4.094853401184082\n", "98: loss: 6.08973503112793\n", "99: loss: 9.11803150177002\n", "100: loss: 8.486052513122559\n", "100: valid loss 4.0021281242370605\n", "101: loss: 4.0021281242370605\n", "102: loss: 3.346961736679077\n", "103: loss: 3.15854549407959\n", "104: loss: 2.5357956886291504\n", "105: loss: 5.492861270904541\n", "106: loss: 2.7623958587646484\n", "107: loss: 2.9482226371765137\n", "108: loss: 6.3801493644714355\n", "109: loss: 4.1293463706970215\n", "110: loss: 3.566096067428589\n", "111: loss: 3.569946527481079\n", "112: loss: 3.762925624847412\n", "113: loss: 6.147146701812744\n", "114: loss: 5.933719635009766\n", "115: loss: 6.800720691680908\n", "116: loss: 2.86614990234375\n", "117: loss: 3.0812878608703613\n", "118: loss: 3.110222101211548\n", "119: loss: 4.000320911407471\n", "120: loss: 3.2422871589660645\n", "121: loss: 3.7775020599365234\n", "122: loss: 3.595900774002075\n", "123: loss: 2.73819637298584\n", "124: loss: 3.4981672763824463\n", "125: loss: 5.3726325035095215\n", "126: loss: 3.0014798641204834\n", "127: loss: 3.5963802337646484\n", "128: loss: 2.8306686878204346\n", "129: loss: 2.5162878036499023\n", "130: loss: 2.685560941696167\n", "131: loss: 6.374442100524902\n", "132: loss: 7.788975715637207\n", "133: loss: 2.897576332092285\n", "134: loss: 3.333127737045288\n", "135: loss: 3.436774253845215\n", "136: loss: 4.979071617126465\n", "137: loss: 4.120012283325195\n", "138: loss: 3.7855355739593506\n", "139: loss: 4.324587345123291\n", "140: loss: 3.4336843490600586\n", "141: loss: 2.6801435947418213\n", "142: loss: 3.359581470489502\n", "143: loss: 5.4692182540893555\n", "144: loss: 5.773078918457031\n", "145: loss: 4.27987813949585\n", "146: loss: 7.247451305389404\n", "147: loss: 6.170166492462158\n", "148: loss: 4.961609840393066\n", "149: loss: 4.028770923614502\n", "150: loss: 2.90120005607605\n", "151: loss: 1.9893661737442017\n", "152: loss: 1.652574062347412\n", "153: loss: 2.374600887298584\n", "154: loss: 2.1045265197753906\n", "155: loss: 6.417508125305176\n", "156: loss: 5.273669719696045\n", "157: loss: 6.238985538482666\n", "158: loss: 3.8025736808776855\n", "159: loss: 6.6854705810546875\n", "160: loss: 2.5476467609405518\n", "161: loss: 6.810393810272217\n", "162: loss: 2.2033159732818604\n", "163: loss: 1.9863100051879883\n", "164: loss: 4.976431369781494\n", "165: loss: 3.899188756942749\n", "166: loss: 4.68454647064209\n", "167: loss: 2.4539690017700195\n", "168: loss: 6.830282688140869\n", "169: loss: 1.7942843437194824\n", "170: loss: 1.242318868637085\n", "171: loss: 5.012855052947998\n", "172: loss: 1.6154134273529053\n", "173: loss: 1.5895756483078003\n", "174: loss: 5.240614891052246\n", "175: loss: 1.8958660364151\n", "176: loss: 2.1411402225494385\n", "177: loss: 5.932228088378906\n", "178: loss: 2.7539122104644775\n", "179: loss: 6.218499660491943\n", "180: loss: 2.991704225540161\n", "181: loss: 3.378645896911621\n", "182: loss: 2.719741106033325\n", "183: loss: 2.5844321250915527\n", "184: loss: 5.851257801055908\n", "185: loss: 2.239989995956421\n", "186: loss: 5.5589141845703125\n", "187: loss: 3.11521053314209\n", "188: loss: 2.5269265174865723\n", "189: loss: 2.181260824203491\n", "190: loss: 1.8941911458969116\n", "191: loss: 5.106175422668457\n", "192: loss: 3.5514838695526123\n", "193: loss: 3.233003854751587\n", "194: loss: 2.55694317817688\n", "195: loss: 6.5134053230285645\n", "196: loss: 6.311967372894287\n", "197: loss: 2.3541362285614014\n", "198: loss: 6.195401668548584\n", "199: loss: 3.013007879257202\n", "200: loss: 2.53104567527771\n", "200: valid loss 1.895339846611023\n", "201: loss: 7.572109699249268\n", "202: loss: 1.946860909461975\n", "203: loss: 1.6077873706817627\n", "204: loss: 1.5050052404403687\n", "205: loss: 1.1216596364974976\n", "206: loss: 1.017206072807312\n", "207: loss: 7.081823825836182\n", "208: loss: 1.1608872413635254\n", "209: loss: 0.728882908821106\n", "210: loss: 0.514722466468811\n", "211: loss: 0.6075964570045471\n", "212: loss: 0.7593868970870972\n", "213: loss: 0.6465023159980774\n", "214: loss: 8.1160888671875\n", "215: loss: 0.8256340622901917\n", "216: loss: 0.5982277393341064\n", "217: loss: 7.202335834503174\n", "218: loss: 4.8967790603637695\n", "219: loss: 2.037604331970215\n", "220: loss: 1.7443571090698242\n", "221: loss: 0.8838777542114258\n", "222: loss: 0.7871264219284058\n", "223: loss: 5.985363483428955\n", "224: loss: 3.6808922290802\n", "225: loss: 4.453125476837158\n", "226: loss: 4.137350559234619\n", "227: loss: 1.5606231689453125\n", "228: loss: 5.764791488647461\n", "229: loss: 1.2394036054611206\n", "230: loss: 1.1438194513320923\n", "231: loss: 0.5560073852539062\n", "232: loss: 5.746810436248779\n", "233: loss: 4.34252405166626\n", "234: loss: 6.079676628112793\n", "235: loss: 4.213600158691406\n", "236: loss: 1.1661522388458252\n", "237: loss: 7.770791053771973\n", "238: loss: 3.6331183910369873\n", "239: loss: 6.657710552215576\n", "240: loss: 4.314018249511719\n", "241: loss: 3.964081048965454\n", "242: loss: 3.4643802642822266\n", "243: loss: 3.2389814853668213\n", "244: loss: 5.009263515472412\n", "245: loss: 5.4173903465271\n", "246: loss: 3.464853048324585\n", "247: loss: 2.690930128097534\n", "248: loss: 5.482550621032715\n", "249: loss: 1.500435709953308\n", "250: loss: 1.207865834236145\n", "251: loss: 6.162202835083008\n", "252: loss: 0.5159206986427307\n", "253: loss: 0.352285772562027\n", "254: loss: 0.28347644209861755\n", "255: loss: 0.2998739182949066\n", "256: loss: 7.412589073181152\n", "257: loss: 1.0271281003952026\n", "258: loss: 0.5622831583023071\n", "259: loss: 6.975170135498047\n", "260: loss: 0.050237879157066345\n", "261: loss: 9.500787734985352\n", "262: loss: 1.1100494861602783\n", "263: loss: 10.5401029586792\n", "264: loss: 7.637964725494385\n", "265: loss: 1.5384433269500732\n", "266: loss: 0.6748937368392944\n", "267: loss: 0.38336750864982605\n", "268: loss: 0.1832476705312729\n", "269: loss: 7.080984115600586\n", "270: loss: 6.806582927703857\n", "271: loss: 6.216980457305908\n", "272: loss: 8.122699737548828\n", "273: loss: 2.344430685043335\n", "274: loss: 5.185897350311279\n", "275: loss: 5.136538982391357\n", "276: loss: 4.847122669219971\n", "277: loss: 3.447641372680664\n", "278: loss: 1.9696052074432373\n", "279: loss: 6.129249095916748\n", "280: loss: 1.4744977951049805\n", "281: loss: 4.836997032165527\n", "282: loss: 4.361396789550781\n", "283: loss: 4.975046157836914\n", "284: loss: 5.6431074142456055\n", "285: loss: 8.127538681030273\n", "286: loss: 7.203218460083008\n", "287: loss: 2.408040761947632\n", "288: loss: 1.7607803344726562\n", "289: loss: 1.1752283573150635\n", "290: loss: 5.39897346496582\n", "291: loss: 0.8753417134284973\n", "292: loss: 6.104700088500977\n", "293: loss: 0.8714774250984192\n", "294: loss: 5.633414268493652\n", "295: loss: 1.0734435319900513\n", "296: loss: 0.5978174209594727\n", "297: loss: 0.6240620613098145\n", "298: loss: 0.3799970746040344\n", "299: loss: 5.793654441833496\n", "300: loss: 4.920631408691406\n", "300: valid loss 0.5733768343925476\n", "301: loss: 0.5733768343925476\n", "302: loss: 0.35356906056404114\n", "303: loss: 6.0288190841674805\n", "304: loss: 0.17994554340839386\n", "305: loss: 6.07096004486084\n", "306: loss: 0.798763632774353\n", "307: loss: 0.30721110105514526\n", "308: loss: 0.35866862535476685\n", "309: loss: 6.664376258850098\n", "310: loss: 10.371112823486328\n", "311: loss: 1.5442111492156982\n", "312: loss: 0.5046924948692322\n", "313: loss: 0.02138896845281124\n", "314: loss: 11.088417053222656\n", "315: loss: 0.2801823616027832\n", "316: loss: 1.6325680017471313\n", "317: loss: 1.042490005493164\n", "318: loss: 0.19980621337890625\n", "319: loss: 6.208798408508301\n", "320: loss: 2.2923152446746826\n", "321: loss: 1.5293265581130981\n", "322: loss: 5.384918212890625\n", "323: loss: 0.5806372165679932\n", "324: loss: 0.11083264648914337\n", "325: loss: 6.474861145019531\n", "326: loss: 6.7361063957214355\n", "327: loss: 6.07684850692749\n", "328: loss: 0.1449495404958725\n", "329: loss: 0.24492450058460236\n", "330: loss: 0.0179277490824461\n", "331: loss: 5.866001605987549\n", "332: loss: 0.14012691378593445\n", "333: loss: 0.14467062056064606\n", "334: loss: 0.01395170483738184\n", "335: loss: 0.04150881618261337\n", "336: loss: 0.07648518681526184\n", "337: loss: 9.367613792419434\n", "338: loss: 8.372873306274414\n", "339: loss: 0.6273093223571777\n", "340: loss: 0.11360179632902145\n", "341: loss: 0.02351052314043045\n", "342: loss: 0.06904540210962296\n", "343: loss: 0.02174321562051773\n", "344: loss: 0.11702124029397964\n", "345: loss: 0.061455100774765015\n", "346: loss: 0.03193430230021477\n", "347: loss: 0.33268794417381287\n", "348: loss: 0.053275030106306076\n", "349: loss: 0.009291582740843296\n", "350: loss: 0.18401774764060974\n", "351: loss: 0.30571281909942627\n", "352: loss: 17.913070678710938\n", "353: loss: 0.2126859426498413\n", "354: loss: 0.6229326128959656\n", "355: loss: 11.214807510375977\n", "356: loss: 0.15888328850269318\n", "357: loss: 0.662460446357727\n", "358: loss: 7.345875263214111\n", "359: loss: 7.803595066070557\n", "360: loss: 1.2322083711624146\n", "361: loss: 0.7014895081520081\n", "362: loss: 0.10298460721969604\n", "363: loss: 8.574231147766113\n", "364: loss: 0.03108447603881359\n", "365: loss: 0.6616091728210449\n", "366: loss: 4.938299655914307\n", "367: loss: 5.479018688201904\n", "368: loss: 6.740688800811768\n", "369: loss: 3.110865831375122\n", "370: loss: 4.795236587524414\n", "371: loss: 1.8502461910247803\n", "372: loss: 3.737464427947998\n", "373: loss: 1.9333598613739014\n", "374: loss: 7.145735740661621\n", "375: loss: 1.3372946977615356\n", "376: loss: 5.683573246002197\n", "377: loss: 1.204305648803711\n", "378: loss: 0.9289284348487854\n", "379: loss: 5.174688339233398\n", "380: loss: 1.458616852760315\n", "381: loss: 0.9457168579101562\n", "382: loss: 0.4627819359302521\n", "383: loss: 0.2658665180206299\n", "384: loss: 4.429558753967285\n", "385: loss: 1.2449607849121094\n", "386: loss: 1.3288488388061523\n", "387: loss: 6.628821849822998\n", "388: loss: 0.4825551211833954\n", "389: loss: 0.6510865688323975\n", "390: loss: 0.36395493149757385\n", "391: loss: 0.18036174774169922\n", "392: loss: 0.3237663209438324\n", "393: loss: 6.840792655944824\n", "394: loss: 1.6587960720062256\n", "395: loss: 7.458000659942627\n", "396: loss: 0.8729283809661865\n", "397: loss: 0.6731876134872437\n", "398: loss: 0.1747300773859024\n", "399: loss: 0.5882076621055603\n", "400: loss: 0.6982569098472595\n", "400: valid loss 0.4763210713863373\n", "401: loss: 0.4763210713863373\n", "402: loss: 0.46096739172935486\n", "403: loss: 4.166454792022705\n", "404: loss: 0.44991931319236755\n", "405: loss: 4.830379009246826\n", "406: loss: 0.5408239364624023\n", "407: loss: 0.2607786953449249\n", "408: loss: 0.13067474961280823\n", "409: loss: 4.062631130218506\n", "410: loss: 5.5028300285339355\n", "411: loss: 1.2942296266555786\n", "412: loss: 1.4390389919281006\n", "413: loss: 5.374651908874512\n", "414: loss: 1.2929461002349854\n", "415: loss: 0.643798291683197\n", "416: loss: 0.6353816986083984\n", "417: loss: 5.8032636642456055\n", "418: loss: 3.3737053871154785\n", "419: loss: 1.8712362051010132\n", "420: loss: 1.0622261762619019\n", "421: loss: 0.8681365847587585\n", "422: loss: 0.6761938333511353\n", "423: loss: 4.074782371520996\n", "424: loss: 0.4106965661048889\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[1;32mIn[4], line 49\u001b[0m\n\u001b[0;32m 46\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m coarse_transformer, trainer, wav2vec, soundstream\n\u001b[0;32m 47\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m---> 49\u001b[0m \u001b[43mtrain_coarse_transformer\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "Cell \u001b[1;32mIn[4], line 43\u001b[0m, in \u001b[0;36mtrain_coarse_transformer\u001b[1;34m()\u001b[0m\n\u001b[0;32m 23\u001b[0m coarse_transformer \u001b[38;5;241m=\u001b[39m CoarseTransformer(\n\u001b[0;32m 24\u001b[0m num_semantic_tokens\u001b[38;5;241m=\u001b[39mwav2vec\u001b[38;5;241m.\u001b[39mcodebook_size,\n\u001b[0;32m 25\u001b[0m codebook_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 29\u001b[0m audio_text_condition\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 30\u001b[0m )\n\u001b[0;32m 32\u001b[0m trainer \u001b[38;5;241m=\u001b[39m CoarseTransformerTrainer(\n\u001b[0;32m 33\u001b[0m transformer\u001b[38;5;241m=\u001b[39mcoarse_transformer,\n\u001b[0;32m 34\u001b[0m codec\u001b[38;5;241m=\u001b[39msoundstream,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 40\u001b[0m num_train_steps\u001b[38;5;241m=\u001b[39mnum_train_steps\n\u001b[0;32m 41\u001b[0m )\n\u001b[1;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 44\u001b[0m torch\u001b[38;5;241m.\u001b[39msave(coarse_transformer\u001b[38;5;241m.\u001b[39mstate_dict(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcoarse_transformer.pth\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m 45\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msave coarse_transformer.pth\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1302\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train\u001b[1;34m(self, log_fn)\u001b[0m\n\u001b[0;32m 1299\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtrain\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_fn \u001b[38;5;241m=\u001b[39m noop):\n\u001b[0;32m 1301\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msteps \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_train_steps:\n\u001b[1;32m-> 1302\u001b[0m logs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1303\u001b[0m log_fn(logs)\n\u001b[0;32m 1305\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprint(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtraining complete\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\audiolm_pytorch\\trainer.py:1244\u001b[0m, in \u001b[0;36mCoarseTransformerTrainer.train_step\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 1238\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mautocast(), context():\n\u001b[0;32m 1239\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_wrapper(\n\u001b[0;32m 1240\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdata_kwargs,\n\u001b[0;32m 1241\u001b[0m return_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 1242\u001b[0m )\n\u001b[1;32m-> 1244\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m/\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad_accum_every\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1246\u001b[0m accum_log(logs, {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss\u001b[38;5;241m.\u001b[39mitem() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgrad_accum_every})\n\u001b[0;32m 1248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m exists(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_grad_norm):\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\accelerate\\accelerator.py:2151\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[1;34m(self, loss, **kwargs)\u001b[0m\n\u001b[0;32m 2149\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[0;32m 2150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2151\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m 515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m 516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m 517\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m 518\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 523\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m 524\u001b[0m )\n\u001b[1;32m--> 525\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 526\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m 527\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m 262\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m 264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m 265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m 266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 267\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 268\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\Sunil\\Mini Project\\MusicLM\\myenv\\Lib\\site-packages\\torch\\autograd\\graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m 742\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m 743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 744\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m 745\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[0;32m 746\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "def train_coarse_transformer():\n", " wav2vec = HubertWithKmeans(\n", " checkpoint_path=checkpoint_path,\n", " kmeans_path=kmeans_path\n", " )\n", " soundstream = MusicLMSoundStream(\n", " codebook_size=1024, # Add this line to specify the codebook size\n", " strides=(3, 4, 5, 8),\n", " target_sample_hz=24000,\n", " rq_num_quantizers=8\n", " )\n", "\n", " if torch.cuda.is_available():\n", " coarse_transformer = CoarseTransformer(\n", " num_semantic_tokens=wav2vec.codebook_size,\n", " codebook_size=1024,\n", " num_coarse_quantizers=4,\n", " dim=1024,\n", " depth=6,\n", " audio_text_condition=True\n", " ).cuda()\n", " else:\n", " coarse_transformer = CoarseTransformer(\n", " num_semantic_tokens=wav2vec.codebook_size,\n", " codebook_size=1024,\n", " num_coarse_quantizers=4,\n", " dim=1024,\n", " depth=6,\n", " audio_text_condition=True\n", " )\n", "\n", " trainer = CoarseTransformerTrainer(\n", " transformer=coarse_transformer,\n", " codec=soundstream,\n", " wav2vec=wav2vec,\n", " audio_conditioner=quantizer,\n", " folder=audio_output_dir,\n", " batch_size=batch_size,\n", " data_max_length=data_max_length,\n", " num_train_steps=num_train_steps\n", " )\n", "\n", " trainer.train()\n", " torch.save(coarse_transformer.state_dict(), 'coarse_transformer.pth')\n", " print(\"save coarse_transformer.pth\")\n", " del coarse_transformer, trainer, wav2vec, soundstream\n", " gc.collect()\n", "\n", "train_coarse_transformer()" ] } ], "metadata": { "kernelspec": { "display_name": "myenv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.2" } }, "nbformat": 4, "nbformat_minor": 2 }