{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e0102cb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "100"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "from transformers import AdamW\n",
    "import pandas as pd\n",
    "import torch\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "# from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler\n",
    "\n",
    "pl.seed_everything(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ec5ec2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME='t5-base'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8044c622",
   "metadata": {},
   "outputs": [],
   "source": [
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "INPUT_MAX_LEN = 128\n",
    "OUTPUT_MAX_LEN = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6390f2de",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, model_max_length=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8eec35d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class T5Model(pl.LightningModule):\n",
    "    \"\"\"LightningModule wrapper around a pretrained T5 conditional-generation model.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Start from the pretrained weights identified by MODEL_NAME.\n",
    "        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        \"\"\"Run the wrapped model and return its ``(loss, logits)`` pair.\"\"\"\n",
    "        output = self.model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels,\n",
    "        )\n",
    "        return output.loss, output.logits\n",
    "\n",
    "    def _batch_loss(self, batch):\n",
    "        \"\"\"Return the loss for a batch holding 'input_ids', 'attention_mask' and 'target'.\"\"\"\n",
    "        loss, _ = self(batch[\"input_ids\"], batch[\"attention_mask\"], batch[\"target\"])\n",
    "        return loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the loss for one training batch and log it as 'train_loss'.\"\"\"\n",
    "        loss = self._batch_loss(batch)\n",
    "        self.log(\"train_loss\", loss, prog_bar=True, logger=True)\n",
    "        return {'loss': loss}\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"Compute the loss for one validation batch and log it as 'val_loss'.\"\"\"\n",
    "        loss = self._batch_loss(batch)\n",
    "        self.log(\"val_loss\", loss, prog_bar=True, logger=True)\n",
    "        return {'val_loss': loss}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"Optimize every parameter with AdamW at a fixed learning rate of 1e-4.\"\"\"\n",
    "        return AdamW(self.parameters(), lr=0.0001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e9d96844",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Lightning automatically upgraded your loaded checkpoint from v1.9.3 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file F:\\Projects & Open_source\\Chatbot_T5_kaggle\\best-model.ckpt`\n"
     ]
    }
   ],
   "source": [
    "train_model = T5Model.load_from_checkpoint('best-model.ckpt',map_location=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3449943f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_model.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0e9f1058",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_question(question):\n",
    "    \"\"\"Generate the chatbot's reply to ``question`` with the loaded T5 model.\n",
    "\n",
    "    Args:\n",
    "        question: The user's message as plain text.\n",
    "\n",
    "    Returns:\n",
    "        The decoded beam-search output as a single string, special tokens removed.\n",
    "    \"\"\"\n",
    "    # Encode the message, padded/truncated to INPUT_MAX_LEN tokens.\n",
    "    inputs_encoding = tokenizer(\n",
    "        question,\n",
    "        add_special_tokens=True,\n",
    "        max_length=INPUT_MAX_LEN,\n",
    "        padding='max_length',\n",
    "        truncation='only_first',\n",
    "        return_attention_mask=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "\n",
    "    # The checkpoint is loaded with map_location=DEVICE, so the inputs must be moved\n",
    "    # to the model's device; otherwise generate() fails on a CUDA machine.\n",
    "    model_device = train_model.device\n",
    "\n",
    "    generate_ids = train_model.model.generate(\n",
    "        input_ids=inputs_encoding[\"input_ids\"].to(model_device),\n",
    "        attention_mask=inputs_encoding[\"attention_mask\"].to(model_device),\n",
    "        # The reply length is bounded by the output limit, not the input limit.\n",
    "        max_length=OUTPUT_MAX_LEN,\n",
    "        num_beams=4,\n",
    "        num_return_sequences=1,\n",
    "        no_repeat_ngram_size=2,\n",
    "        early_stopping=True,\n",
    "    )\n",
    "\n",
    "    preds = [\n",
    "        tokenizer.decode(\n",
    "            gen_id,\n",
    "            skip_special_tokens=True,\n",
    "            clean_up_tokenization_spaces=True,\n",
    "        )\n",
    "        for gen_id in generate_ids\n",
    "    ]\n",
    "\n",
    "    return \"\".join(preds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ee38a88c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ques: hi, how are you doing?\n",
      "BOT: i'm so glad you're doing well.\n"
     ]
    }
   ],
   "source": [
    "ques = \"hi, how are you doing?\"\n",
    "print(\"Ques: \",ques)\n",
    "print(\"BOT: \",generate_question(ques))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "22aa4414",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on local URL: http://127.0.0.1:7861\n",
      "\n",
      "To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       ""
      ],
      "text/plain": [
       ""
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gradio as gr\n",
    "import random\n",
    "import time\n",
    "\n",
    "# Streaming chat UI: the reply is revealed one character at a time.\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot()\n",
    "    chatbot.style(height=400)\n",
    "    msg = gr.Textbox(info=\"Press \\'Enter\\' to send\")\n",
    "    clear = gr.Button(\"Clear\")\n",
    "\n",
    "    def user(user_message, history):\n",
    "        \"\"\"Clear the textbox and append the user's message with a pending (None) reply.\"\"\"\n",
    "        return \"\", history + [[user_message, None]]\n",
    "\n",
    "    def bot(history):\n",
    "        \"\"\"Fill in the reply to the latest user message, yielding after every character.\"\"\"\n",
    "        bot_message = generate_question(history[-1][0])\n",
    "        history[-1][1] = \"\"\n",
    "        for character in bot_message:\n",
    "            history[-1][1] += character\n",
    "            time.sleep(0.05)\n",
    "            yield history\n",
    "\n",
    "    # First record the user's message, then stream the bot's answer into the chat.\n",
    "    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(\n",
    "        bot, chatbot, chatbot\n",
    "    )\n",
    "    clear.click(lambda: None, None, chatbot, queue=True)\n",
    "\n",
    "demo.queue(concurrency_count=2)\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fef38bdc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import random\n",
    "import time\n",
    "\n",
    "# Non-streaming chat UI: the whole reply is shown at once, in bold.\n",
    "with gr.Blocks() as demo:\n",
    "    chatbot = gr.Chatbot()\n",
    "    msg = gr.Textbox(placeholder='Got any spare time...let\\'s chat!!!')\n",
    "    msg.style(show_copy_button=True)\n",
    "    clear = gr.Button(\"Clear\")\n",
    "\n",
    "    def respond(message, chat_history):\n",
    "        \"\"\"Append (message, bold reply) to the history and clear the textbox.\"\"\"\n",
    "        bot_message = generate_question(message)\n",
    "        bot_message = \"**\"+bot_message+\"**\"\n",
    "        chat_history.append((message, bot_message))\n",
    "        time.sleep(1)\n",
    "        return \"\", chat_history\n",
    "\n",
    "    msg.submit(respond, [msg, chatbot], [msg, chatbot])\n",
    "    clear.click(lambda: None, None, chatbot, queue=False)\n",
    "\n",
    "demo.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a86d446a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}