{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "gpuClass": "standard", "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# Fine-tune GPT2 using HuggingFace & PyTorch" ], "metadata": { "id": "2K_YzZvVxv81" } }, { "cell_type": "code", "source": [ "!pip install --quiet transformers==4.2.2" ], "metadata": { "id": "F4DGSHU_e915" }, "execution_count": 1, "outputs": [] }, { "cell_type": "markdown", "source": [ "Based off of [Philipp Schmid's](https://www.philschmid.de/philipp-schmid) [notebook](https://colab.research.google.com/github/philschmid/fine-tune-GPT-2/blob/master/Fine_tune_a_non_English_GPT_2_Model_with_Huggingface.ipynb#scrollTo=laDp891gO25V) with data from the [Trump Twitter Archive](https://www.thetrumparchive.com/?results=1).\n", "\n", "- GPT2 [Model Card](https://huggingface.co/gpt2)\n", "- [HuggingFace's Fine-tuning Docs](https://huggingface.co/learn/nlp-course/chapter3/3?fw=pt)" ], "metadata": { "id": "lw58eJhpyCww" } }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "iwZxNbIIbzbR" }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "import pandas as pd\n", "import json\n", "from transformers import (\n", " TextDataset,\n", " DataCollatorForLanguageModeling,\n", " AutoTokenizer,\n", " AutoModelWithLMHead,\n", " get_linear_schedule_with_warmup,\n", " Trainer,\n", " TrainingArguments,\n", " pipeline\n", ")\n", "from sklearn.model_selection import train_test_split\n", "from tqdm.auto import tqdm\n", "import torch\n", "from pathlib import Path" ] }, { "cell_type": "code", "source": [ "model_name = \"gpt2-medium\"\n", "\n", "if model_name == \"gpt2\":\n", " model_size = \"124M\"\n", "elif model_name == \"gpt2-medium\":\n", " model_size = \"355M\"\n", "elif model_name == \"gpt2-large\":\n", " model_size = \"774M\"\n", "elif model_name 
== \"gpt2-xl\":\n", " model_size = \"1.5B\"" ], "metadata": { "id": "GxBa9kFFsHaM" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "# define some params for model\n", "batch_size = 8\n", "epochs = 15\n", "learning_rate = 5e-4\n", "epsilon = 1e-8\n", "warmup_steps = 1e2\n", "sample_every = 100 # produce sample output every 100 steps\n", "max_length = 140 # max length used in generate method of model" ], "metadata": { "id": "1hFiQUbNcANl" }, "execution_count": 4, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Fetch / Load Data & Preprocess" ], "metadata": { "id": "9PY6SSKlcJNq" } },
{ "cell_type": "code", "source": [
"tweets_path = Path(\"./data/tweets.json\")\n",
"train_path = Path(\"./data/train_tweets.csv\")\n",
"dev_path = Path(\"./data/dev_tweets.csv\")\n",
"\n",
"# fetch the raw tweet dump only when it is not already on disk\n",
"if not tweets_path.exists():\n",
"    # -p: do not complain when ./data already exists\n",
"    !mkdir -p data\n",
"    !wget -O ./data/tweets.json \"https://drive.google.com/uc?export=download&id=16wm-2NTKohhcA26w-kaWfhLIGwl_oX95\"\n",
"\n",
"# the train/dev CSVs act as a cache: they are only rebuilt when one of them is missing\n",
"if not (train_path.exists() and dev_path.exists()):\n",
"    with open(tweets_path, 'rb') as f:\n",
"        # read json file into dict and then parse into df\n",
"        as_dict = json.loads(f.read())\n",
"    df = pd.DataFrame(as_dict)\n",
"\n",
"    # filter df by !retweet\n",
"    df = df[df['isRetweet'] == \"f\"]\n",
"\n",
"    # filter df to only text\n",
"    def is_multimedia(tweet: str):\n",
"        \"\"\"Return \"t\" when the tweet text starts with a t.co link, else \"f\".\"\"\"\n",
"        if tweet.startswith('https://t.co/'):\n",
"            return \"t\"\n",
"        else:\n",
"            return \"f\"\n",
"\n",
"    df['isMultimedia'] = df['text'].apply(lambda x : is_multimedia(x))\n",
"    df = df[df['isMultimedia'] == \"f\"]\n",
"    df = df.reset_index(drop=True)\n",
"\n",
"    # decode the HTML-escaped ampersand; only the full '&amp;' entity is matched,\n",
"    # because stripping every bare 'amp' would mangle words such as 'campaign'\n",
"    def remove_amp(tweet):\n",
"        \"\"\"Return `tweet` with every '&amp;' entity replaced by a plain '&'.\"\"\"\n",
"        return tweet.replace('&amp;', '&')\n",
"    df['text'] = df['text'].apply(remove_amp)\n",
"\n",
"    # rename 'text' column to 'labels'\n",
"    # df = df.rename(columns={'text': 'labels'})\n",
"\n",
"    # create train, validation splits (fixed seed so the cached split is reproducible)\n",
"    train_data, dev_data = train_test_split(df[['text']], test_size=0.15, random_state=42)\n",
"\n",
"    train_data.to_csv(train_path, index=False, header=None)\n",
"    dev_data.to_csv(dev_path, index=False, header=None)" ], "metadata": { "id": "aLQVWQ_dcB2h" }, "execution_count": 5, "outputs": [] },
{ "cell_type": "code", "source": [
"# create tokenized datasets\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    model_name, \n",
"    pad_token='<|endoftext|>'\n",
")\n",
"\n",
"# custom load_dataset function because there are no labels\n",
"def load_dataset(train_path, dev_path, tokenizer, block_size=128):\n",
"    \"\"\"Build the train/dev `TextDataset`s and the matching data collator.\n",
"\n",
"    Args:\n",
"        train_path: path of the training text file.\n",
"        dev_path: path of the validation text file.\n",
"        tokenizer: tokenizer used to encode both files and by the collator.\n",
"        block_size: `block_size` handed to each `TextDataset` (default 128).\n",
"\n",
"    Returns:\n",
"        Tuple `(train_dataset, dev_dataset, data_collator)`.\n",
"    \"\"\"\n",
"    train_dataset = TextDataset(\n",
"        tokenizer=tokenizer,\n",
"        file_path=train_path,\n",
"        block_size=block_size)\n",
"\n",
"    dev_dataset = TextDataset(\n",
"        tokenizer=tokenizer,\n",
"        file_path=dev_path,\n",
"        block_size=block_size)\n",
"\n",
"    # mlm=False: the collator applies no masked-LM token masking\n",
"    data_collator = DataCollatorForLanguageModeling(\n",
"        tokenizer=tokenizer, mlm=False,\n",
"    )\n",
"    return train_dataset, dev_dataset, data_collator\n",
"\n",
"train_dataset, dev_dataset, data_collator = load_dataset(train_path, dev_path, tokenizer)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "DAd1-0nLfcej", "outputId": "bc4e47c7-2fc9-47a4-add6-9a573568eb4c" }, "execution_count": 6, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/transformers/data/datasets/language_modeling.py:54: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. 
You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_mlm.py\n", " warnings.warn(\n" ] } ] }, { "cell_type": "markdown", "source": [ "## Finetune Model" ], "metadata": { "id": "6szJYteUf9L3" } },
{ "cell_type": "code", "source": [
"# AutoModelWithLMHead will pick GPT-2 weights from name\n",
"model = AutoModelWithLMHead.from_pretrained(model_name, cache_dir=Path('cache').resolve())\n",
"\n",
"# necessary because of additional bos, eos, pad tokens to embeddings\n",
"model.resize_token_embeddings(len(tokenizer))\n",
"\n",
"# create optimizer and learning rate schedule \n",
"# NOTE(review): the Trainer in the next cell is built without `optimizers=`, so this\n",
"# optimizer/scheduler pair is currently not the one used for training - confirm intent.\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=epsilon)\n",
"\n",
"# one optimizer step consumes one batch, so count batches per epoch (rounded up),\n",
"# not individual examples\n",
"steps_per_epoch = (len(train_dataset) + batch_size - 1) // batch_size\n",
"training_steps = steps_per_epoch * epochs\n",
"\n",
"# adjust learning rate during training\n",
"scheduler = get_linear_schedule_with_warmup(optimizer, \n",
"    num_warmup_steps = int(warmup_steps), \n",
"    num_training_steps = training_steps)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Nv-bFNB1f68X", "outputId": "610e42fc-4fc2-4ceb-eb38-f00423fb5594" }, "execution_count": 7, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.9/dist-packages/transformers/models/auto/modeling_auto.py:921: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n", " warnings.warn(\n" ] } ] },
{ "cell_type": "code", "source": [
"# training configuration; checkpoints and the final model are written to output_dir\n",
"training_args = TrainingArguments(\n",
"    output_dir=f\"./{model_name}-{model_size}-trump\",\n",
"    overwrite_output_dir=True,\n",
"    num_train_epochs=epochs,\n",
"    per_device_train_batch_size=batch_size,\n",
"    per_device_eval_batch_size=batch_size,\n",
"    evaluation_strategy=\"steps\", # without this the default is \"no\" and eval_steps is ignored\n",
"    eval_steps = 400, # n update steps between two evaluations\n",
"    save_steps=800, # n steps per model save \n",
"    warmup_steps=500, # n warmup steps for learning rate scheduler\n",
"    remove_unused_columns=False,\n",
"    prediction_loss_only=True\n",
")\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    data_collator=data_collator,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=dev_dataset,\n",
")" ], "metadata": { "id": "5OvNyCQagD1I" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "# train & save model run\n", "trainer.train()\n", "trainer.save_model()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "Ni9OyHY5gQLw", "outputId": "f0322248-f504-405d-d2e9-1b5646e8946c" }, "execution_count": 9, "outputs": [ { "data": { "text/html": [ "\n", "
Step | \n", "Training Loss | \n", "
---|---|
500 | \n", "3.622700 | \n", "
1000 | \n", "3.301600 | \n", "
1500 | \n", "3.145200 | \n", "
2000 | \n", "2.932000 | \n", "
2500 | \n", "2.925000 | \n", "
3000 | \n", "2.777100 | \n", "
3500 | \n", "2.661500 | \n", "
4000 | \n", "2.668100 | \n", "
4500 | \n", "2.482500 | \n", "
5000 | \n", "2.455600 | \n", "
5500 | \n", "2.443600 | \n", "
6000 | \n", "2.266700 | \n", "
6500 | \n", "2.271600 | \n", "
7000 | \n", "2.228200 | \n", "
7500 | \n", "2.108600 | \n", "
8000 | \n", "2.133900 | \n", "
8500 | \n", "2.017700 | \n", "
9000 | \n", "1.985300 | \n", "
9500 | \n", "1.999300 | \n", "
10000 | \n", "1.859000 | \n", "
10500 | \n", "1.869600 | \n", "
"
],
"text/plain": [
" "
]
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"## Generate tweets"
],
"metadata": {
"id": "vJLUI-tSgtaX"
}
},
{
"cell_type": "code",
"source": [
"# `config=` expects a model id or a PretrainedConfig, not a plain dict of generation\n",
"# options, so `max_length` is applied at call time instead.\n",
"generator = pipeline(\"text-generation\", model=f\"./{model_name}-{model_size}-trump\", tokenizer=tokenizer)\n",
"\n",
"def trump(prompt, **generate_kwargs):\n",
"    \"\"\"Generate a continuation of `prompt` with the fine-tuned model.\n",
"\n",
"    Args:\n",
"        prompt: text the generation starts from.\n",
"        **generate_kwargs: forwarded to the pipeline call; `max_length` falls back\n",
"            to the notebook-level `max_length` setting when not given.\n",
"\n",
"    Returns:\n",
"        Whatever the text-generation pipeline returns for `prompt`.\n",
"    \"\"\"\n",
"    generate_kwargs.setdefault(\"max_length\", max_length)\n",
"    return generator(prompt, **generate_kwargs)"
],
"metadata": {
"id": "-qyyt5O8TqON"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title\n",
"# give Trump a prompt\n",
"result = trump('The democrats have')"
],
"metadata": {
"id": "gnVtF1K_h473",
"collapsed": true,
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"trump('Why does the lying news media')"
],
"metadata": {
"id": "H02NvY6lEPTJ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1c9263ae-e166-4008-f0e1-d6cd26e6109c"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[{'generated_text': 'Why does the lying news media refuse to state that Cruz poll numbers, as opposed to others, are the highest of any GOP? He beat @RealBenCarson!\"\\n\"\"\"\"\"Donald Trump to run for PGA Grand regressor\"\"\"\" http'}]"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"source": [
"trump(\"Today I'll be\")"
],
"metadata": {
"id": "n8BoiGLGEScg",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "81e35bc4-7005-42bf-c93b-8f70918ab802"
},
"execution_count": 13,
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"data": {
"text/plain": [
"[{'generated_text': \"Today I'll be rallying w/ @FEMA, First Responders, Law Enforcement, and First Responders of Puerto Rico to help those most affected by the #IrmaFlood.https://t.co/gsFSghkmdM\"}]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"source": [
"trump(\"The democrats have\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "snXJFbPCEooG",
"outputId": "79b70812-0eab-4f41-80ce-6e30fb028ebe"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[{'generated_text': 'The democrats have made life so difficult for your favorite President and Vice President. Many thousands of jobs have been lost. Would rather make a deal with Russia than play games. Great power for the U.S.A.\"\\n\"... and the U'}]"
]
},
"metadata": {},
"execution_count": 14
}
]
}
]
}\n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" Step \n",
" Training Loss \n",
" \n",
" \n",
" 500 \n",
" 3.622700 \n",
" \n",
" \n",
" 1000 \n",
" 3.301600 \n",
" \n",
" \n",
" 1500 \n",
" 3.145200 \n",
" \n",
" \n",
" 2000 \n",
" 2.932000 \n",
" \n",
" \n",
" 2500 \n",
" 2.925000 \n",
" \n",
" \n",
" 3000 \n",
" 2.777100 \n",
" \n",
" \n",
" 3500 \n",
" 2.661500 \n",
" \n",
" \n",
" 4000 \n",
" 2.668100 \n",
" \n",
" \n",
" 4500 \n",
" 2.482500 \n",
" \n",
" \n",
" 5000 \n",
" 2.455600 \n",
" \n",
" \n",
" 5500 \n",
" 2.443600 \n",
" \n",
" \n",
" 6000 \n",
" 2.266700 \n",
" \n",
" \n",
" 6500 \n",
" 2.271600 \n",
" \n",
" \n",
" 7000 \n",
" 2.228200 \n",
" \n",
" \n",
" 7500 \n",
" 2.108600 \n",
" \n",
" \n",
" 8000 \n",
" 2.133900 \n",
" \n",
" \n",
" 8500 \n",
" 2.017700 \n",
" \n",
" \n",
" 9000 \n",
" 1.985300 \n",
" \n",
" \n",
" 9500 \n",
" 1.999300 \n",
" \n",
" \n",
" 10000 \n",
" 1.859000 \n",
" \n",
" \n",
" 10500 \n",
" 1.869600 \n",
" \n",
" \n",
" 11000 \n",
" 1.858000 \n",
" \n",
" \n",
" 11500 \n",
" 1.759300 \n",
" \n",
" \n",
" 12000 \n",
" 1.765400 \n",
" \n",
" \n",
" 12500 \n",
" 1.732600 \n",
" \n",
" \n",
" 13000 \n",
" 1.670400 \n",
" \n",
" \n",
" 13500 \n",
" 1.689000 \n",
" \n",
" \n",
" 14000 \n",
" 1.619500 \n",
" \n",
" \n",
" 14500 \n",
" 1.611100 \n",
" \n",
" \n",
" 15000 \n",
" 1.619800 \n",
" \n",
" \n",
" 15500 \n",
" 1.539300 \n",
" \n",
" \n",
" 16000 \n",
" 1.550200 \n",
" \n",
" \n",
" 16500 \n",
" 1.539100 \n",
" \n",
" \n",
" 17000 \n",
" 1.491500 \n",
" \n",
" \n",
" 17500 \n",
" 1.507000 \n",
" \n",
" \n",
" 18000 \n",
" 1.479400 \n",
" \n",
" \n",
" 18500 \n",
" 1.462600 \n",
" \n",
" \n",
" 19000 \n",
" 1.464000 \n",
" \n",
" \n",
" 19500 \n",
" 1.442600 \n",
" \n",
" \n",
" \n",
"20000 \n",
" 1.439300 \n",
"