{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard", "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# Fintune GPT2 using HuggingFace & PyTorch" ], "metadata": { "id": "2K_YzZvVxv81" } }, { "cell_type": "code", "source": [ "!pip install --quiet transformers==4.2.2" ], "metadata": { "id": "F4DGSHU_e915" }, "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "source": [ "Based off of [Philipp Schmid's](https://www.philschmid.de/philipp-schmid) [notebook](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb#scrollTo=laDp891gO25V) with data from the [Trump Twitter Archive](https://www.thetrumparchive.com/?results=1).\n", "\n", "- GPT2 [Model Card](https://huggingface.co/gpt2)\n", "-[HuggingFace's Finetuning Docs](https://huggingface.co/learn/nlp-course/chapter3/3?fw=pt)" ], "metadata": { "id": "lw58eJhpyCww" } }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "iwZxNbIIbzbR" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import pandas as pd\n", "import json\n", "from transformers import (\n", " TextDataset,\n", " DataCollatorForLanguageModeling,\n", " AutoTokenizer,\n", " AutoModelWithLMHead,\n", " get_linear_schedule_with_warmup,\n", " Trainer,\n", " TrainingArguments,\n", " pipeline\n", ")\n", "from sklearn.model_selection import train_test_split\n", "from tqdm.auto import tqdm\n", "import torch\n", "from pathlib import Path" ] }, { "cell_type": "code", "source": [ "model_name = \"gpt2-medium\"\n", "\n", "if model_name == \"gpt2\":\n", " model_size = \"124M\"\n", "elif model_name == \"gpt2-medium\":\n", " model_size = \"355M\"\n", "elif model_name == \"gpt2-large\":\n", " model_size = \"774M\"\n", "elif model_name == \"gpt2-xl\":\n", " model_size = \"1.5B\"" ], "metadata": { "id": "GxBa9kFFsHaM" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "# define some params for model\n", "batch_size = 8\n", "epochs = 15\n", "learning_rate = 5e-4\n", "epsilon = 1e-8\n", "warmup_steps = 1e2\n", "sample_every = 100 # produce sample output every 100 steps\n", "max_length = 140 # max length used in generate method of model" ], "metadata": { "id": "1hFiQUbNcANl" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Fetch / Load Data & Preprocess" ], "metadata": { "id": "9PY6SSKlcJNq" } }, { "cell_type": "code", "source": [ "tweets_path = Path(\"./data/tweets.json\")\n", "train_path = Path(\"./data/train_tweets.csv\")\n", "dev_path = Path(\"./data/dev_tweets.csv\")\n", "\n", "# fetch data if !exists already\n", "if not tweets_path.exists():\n", " !mkdir data\n", " !wget -O ./data/tweets.json \"https://drive.google.com/uc?export=download&id=16wm-2NTKohhcA26w-kaWfhLIGwl_oX95\"\n", "\n", "if not (train_path.exists() and dev_path.exists()):\n", " with open(tweets_path, 'rb') as f:\n", " # read json file into dict and then parse into df\n", " as_dict = json.loads(f.read())\n", " df = pd.DataFrame(as_dict)\n", " \n", " # filter df by !retweet\n", " df = df[df['isRetweet'] == \"f\"]\n", "\n", " # filter df to only text\n", " def is_multimedia(tweet: str):\n", " if tweet.startswith('https://t.co/'):\n", " return \"t\"\n", " else:\n", " return \"f\"\n", "\n", " df['isMultimedia'] = df['text'].apply(lambda x : is_multimedia(x))\n", " df = df[df['isMultimedia'] == \"f\"]\n", " df = df.reset_index(drop=True)\n", "\n", " # filter tweets to remove 'amp;'\n", " def remove_amp(tweet):\n", " tweet = tweet.replace('amp;', '')\n", " tweet = tweet.replace('amp', '')\n", " return tweet\n", " df['text'] = df['text'].apply(lambda x: remove_amp(x))\n", "\n", " # rename 'text' column to 'labels'\n", " # df = df.rename(columns={'text': 'labels'})\n", " \n", " # create train, validation splits\n", " train_data, dev_data = train_test_split(df[['text']], test_size=0.15) \n", " \n", " train_data.to_csv(train_path, index=False, header=None)\n", " dev_data.to_csv(dev_path, index=False, header=None)" ], "metadata": { "id": "aLQVWQ_dcB2h" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "# create tokenized datasets\n", "tokenizer = AutoTokenizer.from_pretrained(\n", " model_name, \n", " pad_token='<|endoftext|>'\n", ")\n", "\n", "# custom load_dataset function because there are no labels\n", "def load_dataset(train_path, dev_path, tokenizer):\n", " block_size = 128\n", " # block_size = tokenizer.model_max_length\n", " \n", " train_dataset = TextDataset(\n", " tokenizer=tokenizer,\n", " file_path=train_path,\n", " block_size=block_size)\n", " \n", " dev_dataset = TextDataset(\n", " tokenizer=tokenizer,\n", " file_path=dev_path,\n", " block_size=block_size) \n", " \n", " data_collator = DataCollatorForLanguageModeling(\n", " tokenizer=tokenizer, mlm=False,\n", " )\n", " return train_dataset, dev_dataset, data_collator\n", "\n", "train_dataset, dev_dataset, data_collator = load_dataset(train_path, dev_path, tokenizer)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DAd1-0nLfcej", "outputId": "bc4e47c7-2fc9-47a4-add6-9a573568eb4c" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/transformers/data/datasets/language_modeling.py:54: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py\n", " warnings.warn(\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Finetune Model" ], "metadata": { "id": "6szJYteUf9L3" } }, { "cell_type": "code", "source": [ "# AutoModelWithLMHead will pick GPT-2 weights from name\n", "model = AutoModelWithLMHead.from_pretrained(model_name, cache_dir=Path('cache').resolve())\n", "\n", "# necessary because of additional bos, eos, pad tokens to embeddings\n", "model.resize_token_embeddings(len(tokenizer))\n", "\n", "# create optimizer and learning rate schedule \n", "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=epsilon)\n", "\n", "training_steps = len(train_dataset) * epochs\n", "\n", "# adjust learning rate during training\n", "scheduler = get_linear_schedule_with_warmup(optimizer, \n", " num_warmup_steps = warmup_steps, \n", " num_training_steps = training_steps)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Nv-bFNB1f68X", "outputId": "610e42fc-4fc2-4ceb-eb38-f00423fb5594" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/transformers/models/auto/modeling_auto.py:921: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n", " warnings.warn(\n" ] } ] }, { "cell_type": "code", "source": [ "training_args = TrainingArguments(\n", " output_dir=f\"./{model_name}-{model_size}-trump\",\n", " overwrite_output_dir=True,\n", " num_train_epochs=epochs,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " eval_steps = 400, # n update steps between two evaluations\n", " save_steps=800, # n steps per model save \n", " warmup_steps=500, # n warmup steps for learning rate scheduler\n", " remove_unused_columns=False,\n", " prediction_loss_only=True\n", ")\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " data_collator=data_collator,\n", " train_dataset=train_dataset,\n", " eval_dataset=dev_dataset,\n", ")" ], "metadata": { "id": "5OvNyCQagD1I" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "# train & save model run\n", "trainer.train()\n", "trainer.save_model()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "Ni9OyHY5gQLw", "outputId": "f0322248-f504-405d-d2e9-1b5646e8946c" }, "execution_count": 9, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " \n", " [10685/20460 2:24:09 < 2:11:54, 1.24 it/s, Epoch 7.83/15]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5003.622700
10003.301600
15003.145200
20002.932000
25002.925000
30002.777100
35002.661500
40002.668100
45002.482500
50002.455600
55002.443600
60002.266700
65002.271600
70002.228200
75002.108600
80002.133900
85002.017700
90001.985300
95001.999300
100001.859000
105001.869600

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "\n", "

\n", " \n", " \n", " \n", " [20460/20460 4:36:05, Epoch 15/15]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining Loss
5003.622700
10003.301600
15003.145200
20002.932000
25002.925000
30002.777100
35002.661500
40002.668100
45002.482500
50002.455600
55002.443600
60002.266700
65002.271600
70002.228200
75002.108600
80002.133900
85002.017700
90001.985300
95001.999300
100001.859000
105001.869600
110001.858000
115001.759300
120001.765400
125001.732600
130001.670400
135001.689000
140001.619500
145001.611100
150001.619800
155001.539300
160001.550200
165001.539100
170001.491500
175001.507000
180001.479400
185001.462600
190001.464000
195001.442600
200001.439300

" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "source": [ "## Generate tweets" ], "metadata": { "id": "vJLUI-tSgtaX" } }, { "cell_type": "code", "source": [ "trump = pipeline(\"text-generation\", model=f\"./{model_name}-{model_size}-trump\", tokenizer=tokenizer, config={\"max_length\":max_length})" ], "metadata": { "id": "-qyyt5O8TqON" }, "execution_count": 10, "outputs": [] }, { "cell_type": "code", "source": [ "#@title\n", "# give Trump a prompt\n", "result = trump('The democrats have')" ], "metadata": { "id": "gnVtF1K_h473", "collapsed": true, "cellView": "form" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "trump('Why does the lying news media')" ], "metadata": { "id": "H02NvY6lEPTJ", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "1c9263ae-e166-4008-f0e1-d6cd26e6109c" }, "execution_count": 12, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[{'generated_text': 'Why does the lying news media refuse to state that Cruz poll numbers, as opposed to others, are the highest of any GOP? He beat @RealBenCarson!\"\\n\"\"\"\"\"Donald Trump to run for PGA Grand regressor\"\"\"\" http'}]" ] }, "metadata": {}, "execution_count": 12 } ] }, { "cell_type": "code", "source": [ "trump(\"Today I'll be\")" ], "metadata": { "id": "n8BoiGLGEScg", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "81e35bc4-7005-42bf-c93b-8f70918ab802" }, "execution_count": 13, "outputs": [ { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "data": { "text/plain": [ "[{'generated_text': \"Today I'll be rallying w/ @FEMA, First Responders, Law Enforcement, and First Responders of Puerto Rico to help those most affected by the #IrmaFlood.https://t.co/gsFSghkmdM\"}]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ] }, { "cell_type": "code", "source": [ "trump(\"The democrats have\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "snXJFbPCEooG", "outputId": "79b70812-0eab-4f41-80ce-6e30fb028ebe" }, "execution_count": 14, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[{'generated_text': 'The democrats have made life so difficult for your favorite President and Vice President. Many thousands of jobs have been lost. Would rather make a deal with Russia than play games. Great power for the U.S.A.\"\\n\"... and the U'}]" ] }, "metadata": {}, "execution_count": 14 } ] } ] }