{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5ede69a1",
   "metadata": {},
   "source": [
    "# Classification Challenge using CatBoost\n",
    "\n",
    "## INF2179 Fall 2021\n",
    "### Hamid Yuksel\n",
    "\n",
    "This submission uses [CatBoost](https://catboost.ai/).\n",
    "CatBoost was chosen for its listed benefits, mainly that it requires less hyperparameter tuning and less preprocessing of categorical and text features. It is also fast and fairly easy to set up.\n",
    "\n",
    "The notebook trains a multi-class CatBoost model that predicts `Genre` from the `Lyric` text, estimates its accuracy on a random holdout sample, and writes predictions for the last 5,000 rows of `data.csv` to `pred.csv`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c1e0f2",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7d4f9a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the required libraries into the environment of the running kernel.\n",
    "# %pip is used instead of `! pip3` so the packages land where this kernel can import them.\n",
    "# The original run of this notebook used catboost 1.0.3 and ipywidgets 7.6.3.\n",
    "%pip install -q catboost ipywidgets\n",
    "! jupyter nbextension enable --py widgetsnbextension"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee82451e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Importing required libraries\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from catboost import CatBoostClassifier, Pool\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.neural_network import MLPClassifier  # kept from the original import list; not used below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2e8a6d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives here\n",
    "DATA_PATH = 'data.csv'\n",
    "N_LABELLED = 50000       # leading rows of the file used for training + holdout\n",
    "N_HOLDOUT = 5000         # rows sampled out of the labelled pool to estimate accuracy\n",
    "N_TEST = 5000            # trailing rows of the file that are predicted and submitted\n",
    "HOLDOUT_SEED = 1         # seed for the holdout sample\n",
    "MODEL_SEED = 0           # CatBoost seed, stated explicitly (0 is believed to be the library default)\n",
    "TEXT_FEATURES = ['Lyric']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5f7b3c8",
   "metadata": {},
   "source": [
    "## Load and split the data\n",
    "\n",
    "The first `N_LABELLED` rows form the labelled pool, from which a random holdout sample is removed; the last `N_TEST` rows are the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1853f85d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reading data\n",
    "df = pd.read_csv(DATA_PATH)\n",
    "\n",
    "# Splitting (each split gets its own name so re-running this cell is idempotent)\n",
    "labelled = df.head(N_LABELLED)\n",
    "holdout_set = labelled.sample(N_HOLDOUT, random_state=HOLDOUT_SEED)  # pick holdout observations randomly\n",
    "training = labelled.drop(holdout_set.index)  # remove holdout rows from the training data\n",
    "testing = df.tail(N_TEST)\n",
    "\n",
    "# Every row must land in exactly one of the three splits\n",
    "assert len(training) + len(holdout_set) + len(testing) == len(df)\n",
    "\n",
    "# Looking at counts per genre\n",
    "df['Genre'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9a2c4b6",
   "metadata": {},
   "source": [
    "## Features and labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0a20f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Splitting each set into features (X) and labels (y)\n",
    "train_y = training['Genre']\n",
    "train_X = training.drop(columns='Genre')\n",
    "\n",
    "holdout_y = holdout_set['Genre']\n",
    "holdout_X = holdout_set.drop(columns='Genre')\n",
    "\n",
    "test_X = testing.drop(columns='Genre')\n",
    "\n",
    "# Show a preview only; the full frame has N_TEST rows\n",
    "test_X.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1b3d5e7",
   "metadata": {},
   "source": [
    "## Train the classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c8eb420",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a classifier\n",
    "train_dataset = Pool(data=train_X,\n",
    "                     label=train_y,\n",
    "                     text_features=TEXT_FEATURES)\n",
    "\n",
    "model = CatBoostClassifier(iterations=100,\n",
    "                           learning_rate=1,\n",
    "                           depth=3,\n",
    "                           loss_function='MultiClass',\n",
    "                           random_seed=MODEL_SEED,\n",
    "                           verbose=10)  # log every 10th iteration instead of all 100\n",
    "\n",
    "model.fit(train_dataset);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8c6e4f0",
   "metadata": {},
   "source": [
    "## Estimate accuracy on the holdout set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f770a8f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate accuracy\n",
    "holdout_pred = model.predict(holdout_X).ravel()  # flatten the predictions to one label per row\n",
    "estimated_accuracy = accuracy_score(holdout_y, holdout_pred)\n",
    "pd.Series([estimated_accuracy]).to_csv('ea.csv', index=False, header=False)\n",
    "estimated_accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9d7f5a3",
   "metadata": {},
   "source": [
    "## Predict the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035f58c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict testing set\n",
    "# The predictions stay in `test_pred`; to_csv returns None when given a path, so its result is not assigned.\n",
    "test_pred = model.predict(test_X).ravel()\n",
    "pd.Series(test_pred).to_csv('pred.csv', index=False, header=False)\n",
    "test_pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2137c958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check the number of instances of each genre in the test predictions\n",
    "np.unique(test_pred, return_counts=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0e8a6b4",
   "metadata": {},
   "source": [
    "## Recorded results\n",
    "\n",
    "Values recorded from the original execution of this notebook (catboost 1.0.3, Python 3.8.8):\n",
    "\n",
    "- Genre counts in `data.csv`: Rock 24,486; Pop 16,251; Hip Hop 9,263; unknown 5,000.\n",
    "- Training loss (`learn`) fell from 0.8096 at iteration 0 to 0.6640 at iteration 99.\n",
    "- Estimated accuracy on the holdout set: **0.6794**.\n",
    "- Predicted genres for the 5,000 test rows: Hip Hop 802; Pop 1,307; Rock 2,891.\n",
    "\n",
    "Cell outputs were cleared because they consisted mostly of long install and training logs and contained machine-specific paths. Restart the kernel and run all cells to regenerate them."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}