{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "HiY1rH4fuPeF", "outputId": "7c4cb24c-0a80-41fa-9f3e-7a79d7231921" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.44.2)\n", "Collecting datasets\n", " Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.0+cu121)\n", "Collecting faiss-cpu\n", " Downloading faiss_cpu-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.4 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.16.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.24.7)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.9.11)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.5)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in 
/usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.6)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", "Collecting multiprocess<0.70.17 (from datasets)\n", " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)\n", " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.10)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in 
/usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", "Requirement already satisfied: yarl<2.0,>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.0)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4.0)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.8.30)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from yarl<2.0,>=1.12.0->aiohttp->datasets) (0.2.0)\n", "Downloading 
datasets-3.1.0-py3-none-any.whl (480 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m21.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading faiss_cpu-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (27.5 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m27.5/27.5 MB\u001b[0m \u001b[31m39.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m801.7 kB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m11.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m13.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: xxhash, fsspec, faiss-cpu, dill, multiprocess, datasets\n", " Attempting uninstall: fsspec\n", " Found existing installation: fsspec 2024.10.0\n", " Uninstalling fsspec-2024.10.0:\n", " Successfully uninstalled fsspec-2024.10.0\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed datasets-3.1.0 dill-0.3.8 faiss-cpu-1.9.0 fsspec-2024.9.0 multiprocess-0.70.16 xxhash-3.5.0\n" ] } ], "source": [
    "# Pin the versions recorded in this cell's install log so a fresh runtime reproduces the run.\n",
    "# torch is left unpinned so the runtime's preinstalled build is reused.\n",
    "%pip install transformers==4.44.2 datasets==3.1.0 torch faiss-cpu==1.9.0\n"
  ] },
  { "cell_type": "code", "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"Amod/mental_health_counseling_conversations\")\n"
  ], "metadata": { "id": "Zh2FSQW-uWkg" }, "execution_count": 4, "outputs": [] },
  { "cell_type": "code", "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Load dataset\n",
    "dataset = load_dataset(\"Amod/mental_health_counseling_conversations\")\n",
    "\n",
    "# Split once with Dataset.train_test_split (80% train / 20% validation).\n",
    "# Both halves come from the same call, so they cannot drift apart.\n",
    "split = dataset['train'].train_test_split(test_size=0.2, seed=42)\n",
    "train_data = split['train']\n",
    "val_data = split['test']\n",
    "\n",
    "\n",
    "# Load FLAN-T5 Small model and tokenizer\n",
    "model_name = \"google/flan-t5-small\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
    "\n",
    "# Preprocess the dataset\n",
    "def preprocess_data(examples):\n",
    "    \"\"\"Tokenize a batch of Context/Response pairs for seq2seq training.\n",
    "\n",
    "    Args:\n",
    "        examples: batch dict with 'Context' and 'Response' entries.\n",
    "\n",
    "    Returns:\n",
    "        The tokenized contexts (padded/truncated to 512 tokens) plus a\n",
    "        'labels' entry holding the response token ids (padded/truncated\n",
    "        to 128 tokens) with padding positions replaced by -100.\n",
    "    \"\"\"\n",
    "    inputs = tokenizer(examples['Context'], padding=\"max_length\", truncation=True, max_length=512)\n",
    "    targets = tokenizer(examples['Response'], padding=\"max_length\", truncation=True, max_length=128)\n",
    "    # -100 is the label value the loss skips; without this the model is\n",
    "    # also trained (and scored) on predicting the padding tokens.\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "    inputs['labels'] = [\n",
    "        [token if token != pad_id else -100 for token in sequence]\n",
    "        for sequence in targets['input_ids']\n",
    "    ]\n",
    "    return inputs\n",
    "\n",
    "# Apply preprocessing to the train and validation datasets and drop the\n",
    "# raw text columns (Context, Response), which are no longer needed\n",
    "train_data = train_data.map(preprocess_data, batched=True, remove_columns=train_data.column_names)\n",
    "val_data = val_data.map(preprocess_data, batched=True, remove_columns=val_data.column_names)\n",
    "\n",
    "\n",
    "# Setup training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    eval_strategy=\"epoch\",  # replaces the deprecated `evaluation_strategy`\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=2,\n",
    "    per_device_eval_batch_size=2,\n",
    "    num_train_epochs=5,\n",
    "    logging_dir='./logs',\n",
    "    logging_steps=10,\n",
    "    save_strategy=\"epoch\",  # kept equal to eval_strategy, as load_best_model_at_end needs\n",
    "    save_total_limit=2,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model='loss',\n",
    ")\n",
    "\n",
    "# Initialize Trainer\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_data,\n",
    "    eval_dataset=val_data,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "\n",
    "# Train the model\n",
    "trainer.train()\n",
    "\n",
    "# Save the fine-tuned model\n",
    "trainer.save_model(\"fine_tuned_flan_t5\")\n",
    "\n",
    "# Now the model is fine-tuned and saved, we can use it for RAG"
  ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 487 }, "id": "uZ2imdKguqeb", "outputId": "2ad6eefe-2d44-4e23-8456-ad294a750aba" }, "execution_count": 17, "outputs": [ { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1525: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "
---|---|---|
1 | \n", "3.433000 | \n", "3.051346 | \n", "
2 | \n", "3.104000 | \n", "2.949355 | \n", "
3 | \n", "3.121100 | \n", "2.920628 | \n", "
4 | \n", "3.126200 | \n", "2.906628 | \n", "
"
],
"text/plain": [
" "
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Score the fine-tuned model on the training split. The \"train\" prefix keeps\n",
"# these metrics from being reported under the default validation-style `eval_` keys.\n",
"trainer.evaluate(eval_dataset=train_data, metric_key_prefix=\"train\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 124
},
"id": "sLp70Cgw7YTW",
"outputId": "7410e2b7-d4a9-459d-8e36-18d36b155493"
},
"execution_count": 25,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" Epoch \n",
" Training Loss \n",
" Validation Loss \n",
" \n",
" \n",
" 1 \n",
" 3.433000 \n",
" 3.051346 \n",
" \n",
" \n",
" 2 \n",
" 3.104000 \n",
" 2.949355 \n",
" \n",
" \n",
" 3 \n",
" 3.121100 \n",
" 2.920628 \n",
" \n",
" \n",
" 4 \n",
" 3.126200 \n",
" 2.906628 \n",
" \n",
" \n",
" \n",
"5 \n",
" 3.051300 \n",
" 2.902490 \n",
"