File size: 38,226 Bytes
d2f9514 |
1 |
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4","authorship_tag":"ABX9TyNUFphd1WJDBGh7ucucmaIf"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"fS2d3a6P9izs","executionInfo":{"status":"ok","timestamp":1687067545711,"user_tz":-330,"elapsed":19201,"user":{"displayName":"Growth Ai","userId":"13899479888288343300"}},"outputId":"b61c9237-6961-4497-adc2-431e8013f7aa"},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","source":["import os\n","print(os.getcwd())"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"vR_XPy3p9w-l","executionInfo":{"status":"ok","timestamp":1687067545713,"user_tz":-330,"elapsed":27,"user":{"displayName":"Growth Ai","userId":"13899479888288343300"}},"outputId":"cc48de3e-f973-4f1d-e1de-186649edef36"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["/content\n"]}]},{"cell_type":"code","source":["!pwd\n","path_to_mount = '/content/drive/My Drive/Colab Notebooks/ChatBot/'\n","\n","# Change current working directory\n","os.chdir(path_to_mount)\n","!ls"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"rA-YBdj295JN","executionInfo":{"status":"ok","timestamp":1687067545714,"user_tz":-330,"elapsed":18,"user":{"displayName":"Growth Ai","userId":"13899479888288343300"}},"outputId":"70a51df2-679e-4ca6-c65a-9c23dcd86796"},"execution_count":3,"outputs":[{"output_type":"stream","name":"stdout","text":["/content\n","'ChatBot with Attention.ipynb'\t data output.tsv\n"]}]},{"cell_type":"code","source":["# PyTorch\n","import torch\n","import torch.nn as nn\n","from torch import optim\n","import torch.nn.functional as F\n","\n","# Etc\n","from 
# Select the compute device once; models and tensors created later in the
# notebook are placed on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Running on GPU" if device.type == 'cuda' else "Running on CPU")

# Glob pattern (relative to `path_to_mount`) matching the dialogue files.
dialogues_regex_folder_path = "data/dialogues/*.txt"

# Absolute paths of every dialogue file.
list_of_files = glob.glob(path_to_mount + dialogues_regex_folder_path)
print(list_of_files[:3])   # Visualize the first 3
print(len(list_of_files))  # 47 in the recorded run

# Parsing: each file holds one JSON object per line.
list_of_dicts = []
for filename in list_of_files:
    # Explicit encoding so parsing does not depend on the platform default.
    with open(filename, encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            list_of_dicts.append(json.loads(line))

# Visualize a few parsed dialogues.
print(list_of_dicts[0])
print(list_of_dicts[1].keys())  # call the method; the original printed the bound method object
print(list_of_dicts[332])
print(list_of_dicts[:3])

# Keep only the useful field ('turns') of every dialogue.
new_list_of_dicts = [
    {key: value for key, value in old_dict.items() if key == 'turns'}
    for old_dict in list_of_dicts
]
print(len(new_list_of_dicts))

list_of_dicts = new_list_of_dicts
print(list_of_dicts[:2])
you arent programmed to help me win at least one game?', \"I'm sorry, I'm not.\", 'ok, we will have to get that changed', 'Is there anything else I can help you with?', 'how does one play zork', \"It's a pretty complicated computer game. I can't go through every step, do you have any specific questions?\", \"nope'\"]}, {'turns': ['Hello how may I help you?', 'I need some help', 'Some help with what?', 'How to Play Catch?', 'I can help you with that.', 'okay', 'The game is technically played with 2 or more people. The participants throw an object back and forth (usually a ball) until they no longer want to play the game.', \"Wht's the rule?\", \"There really are no rules. It's not a competitive game unless you are playing a modified version of the game.\", 'okay thanks!', \"It's pretty much for fun and to practice throwing the object where you want it to go, and catching the object when it is thrown back to you.\"]}]\n"]}]},{"cell_type":"markdown","source":["#Data Augmentation and Preparation"],"metadata":{"id":"_UvxI28fBU3u"}},{"cell_type":"code","source":["# Init matrices\n","questions = []\n","answers = []\n","\n","matrix_greetings = [\"Hey\", \"Hi\"]\n","\n","\n","matrix_byes = [\"Ok\", \"Okie\", \"Bye\"]\n","\n","# For each dictionary in the list\n","for dictionary in list_of_dicts:\n"," matrix_QA = dictionary['turns']\n","\n"," # Append a first random greeting\n"," questions.append(random.choice(matrix_greetings))\n","\n"," # In order to split the QAs to 2 matrices (questions & answers),\n"," # we will use a flag to indicate if the sentence\n"," # is given from the bot or from the user\n"," bot_flag = True # Init\n","\n"," # For each Q/A in the matrix\n"," for sentence in matrix_QA:\n","\n"," if bot_flag == True:\n"," answers.append(sentence) # Used for bot's answers\n"," bot_flag = False # Switch\n"," continue\n"," else:\n"," questions.append(sentence) # Used for user's questions\n"," bot_flag = True # Switch\n"," continue\n","\n","\n"," if bot_flag == 
SOS_TOKEN = 0  # Start of sentence
EOS_TOKEN = 1  # End of sentence


class QA_Lang:
    """Vocabulary for one side of the dialogue (questions or answers).

    Attributes:
        word2index: dictionary mapping each known word to its index.
        index2word: dictionary mapping each index back to its word; indices
            0 and 1 are reserved for the SOS and EOS markers.
        n_words: number of entries in ``index2word`` (markers included).
    """

    def __init__(self):
        self.word2index = {}
        self.index2word = {SOS_TOKEN: 'SOS', EOS_TOKEN: 'EOS'}  # Reserved markers
        self.n_words = 2  # The two reserved markers are already counted

    def add_sentence(self, sentence):
        """Register every not-yet-seen word of ``sentence``.

        Words are obtained with ``sentence.split(' ')``, the same tokenisation
        used when sentences are later converted to index tensors.
        """
        for word in sentence.split(' '):
            if word in self.word2index:
                continue
            # New word: give it the next free index.
            self.word2index[word] = self.n_words
            self.index2word[self.n_words] = word
            self.n_words += 1


# Preprocessing helper function
def preprocess_text(sentence):
    """Normalise a sentence for the vocabulary.

    Steps:
      1. lowercase and strip surrounding whitespace;
      2. decompose accented characters (NFD) and drop the combining marks;
      3. put a space in front of each of ``. ! ?`` so they become separate tokens;
      4. replace every run of characters other than ASCII letters and
         ``. ! ?`` with a single space;
      5. strip again, because step 4 can leave a leading or trailing space
         (e.g. when the sentence starts or ends with digits), which would
         otherwise produce empty-string tokens on ``split(' ')``.

    Returns:
        The cleaned sentence; it may be the empty string.
    """
    sentence = sentence.lower().strip()

    # Unicode -> plain letters: keep everything except combining marks ('Mn').
    sentence = ''.join(
        char for char in unicodedata.normalize('NFD', sentence)
        if unicodedata.category(char) != 'Mn'
    )

    # Separate sentence punctuation, then collapse everything else to spaces.
    sentence = re.sub(r"([.!?])", r" \1", sentence)
    sentence = re.sub(r"[^a-zA-Z.!?]+", r" ", sentence)

    return sentence.strip()
# Reading helper function
def readQA(data_path=None):
    """Read the tab-separated question/answer file and clean every sentence.

    Args:
        data_path: path of the TSV file. Defaults to ``data/dataset.tsv``
            under the current working directory.
            NOTE(review): the export cell writes ``output.tsv``, not
            ``data/dataset.tsv`` -- confirm which file is intended.

    Returns:
        ``(questions, answers, pairs)`` where ``questions`` and ``answers`` are
        empty ``QA_Lang`` objects and ``pairs`` is a list of
        ``[question, answer]`` lists already passed through
        ``preprocess_text``. Lines that do not contain exactly two
        tab-separated columns are skipped.
    """
    print('Reading lines from file...')

    if data_path is None:
        data_path = os.path.join(os.getcwd(), "data", "dataset.tsv")

    # Context manager so the file handle is always closed.
    with open(data_path, encoding='utf-8') as data_file:
        lines = data_file.read().strip().split('\n')

    TAB_CHARACTER = '\t'

    pairs = []
    for line in lines:
        columns = line.split(TAB_CHARACTER)
        if len(columns) != 2:
            # Malformed row (or an empty file); it could not form a Q/A pair.
            continue
        pairs.append([preprocess_text(sentence) for sentence in columns])

    questions = QA_Lang()
    answers = QA_Lang()

    return questions, answers, pairs


MAX_LENGTH = 25


# NOTE: this name shadows the built-in `filter` for the rest of the notebook.
def filter(pairs, max_length=MAX_LENGTH):
    """Keep the pairs whose question and answer are both shorter than ``max_length``.

    Length is the number of items produced by ``split(' ')``. Entries with
    fewer than two elements are dropped instead of raising ``IndexError``.

    Args:
        pairs: iterable of ``[question, answer]`` sequences.
        max_length: exclusive upper bound on the word count of each side.

    Returns:
        A new list with the pairs that passed, in their original order.
    """
    return [
        pair for pair in pairs
        if len(pair) >= 2
        and len(pair[0].split(' ')) < max_length
        and len(pair[1].split(' ')) < max_length
    ]


def prepare_data():
    """Read, filter and index the dataset.

    Returns:
        ``(questions, answers, pairs)``: the two populated vocabulary objects
        and the filtered list of sentence pairs.
    """
    # Read sentence pairs
    questions, answers, pairs = readQA()
    print(f"Read {len(pairs)} sentence pairs")

    # Drop the pairs that are too long
    pairs = filter(pairs)
    print(f"Filtered down to {len(pairs)} sentence pairs")

    # Count words and populate the 'language' objects
    for question, answer in pairs:
        questions.add_sentence(question)
        answers.add_sentence(answer)

    print(f"The questions object is defined by {questions.n_words} words")
    print(f"The answers object is defined by {answers.n_words} words")

    return questions, answers, pairs
# SEQ2SEQ MODEL

class EncoderRNN(nn.Module):
    """GRU encoder that consumes a question one word index at a time.

    For every input word it returns an output vector and the updated hidden
    state; the final hidden state is used to initialise the decoder.

    Args:
        input_size: size of the questions vocabulary (number of embeddings).
        hidden_size: dimensionality of the embeddings and of the GRU state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Maps a word index to a dense vector of size `hidden_size`.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode one word.

        Args:
            input: tensor holding a single word index.
            hidden: previous hidden state, shaped ``(1, 1, hidden_size)``.

        Returns:
            ``(output, hidden)`` as returned by the GRU for this step.
        """
        # Reshape to (seq_len=1, batch=1, hidden_size) as expected by the GRU.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def init_hidden(self):
        """Return an all-zero initial hidden state shaped ``(1, 1, hidden_size)``.

        The tensor is created on the device that holds the module's parameters
        instead of relying on a notebook-global ``device`` variable.
        """
        return torch.zeros(1, 1, self.hidden_size,
                           device=self.embedding.weight.device)
The result would contain information about that specific part of\n","the input sequence, and thus help the decoder choose the right output words.\n","\n","To calculate the attention weights, we'll use a feed-forward layer that uses\n","the decoder's input and hidden state as inputs.\n","\n","We will have to choose a max sentence length (input length, for encoder outputs),\n","wherein sentences of the max length will use all attention weights, while shorter\n","sentences would only use the first few.\n","\"\"\"\n","class AttnDecoderRNN(nn.Module):\n"," def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):\n"," # Initialize the constructor\n"," super(AttnDecoderRNN, self).__init__()\n"," self.hidden_size = hidden_size\n"," self.output_size = output_size\n"," self.dropout_p = dropout_p\n"," self.max_length = max_length\n","\n"," self.embedding = nn.Embedding(self.output_size, self.hidden_size)\n"," # Combine Fully Connected Layer\n"," self.attention = nn.Linear(self.hidden_size * 2, self.max_length)\n"," self.attention_combine = nn.Linear(self.hidden_size * 2,\n"," self.hidden_size)\n"," # Use dropout\n"," self.dropout = nn.Dropout(self.dropout_p)\n","\n","\n"," self.gru = nn.GRU(self.hidden_size, self.hidden_size)\n"," self.out = nn.Linear(self.hidden_size, self.output_size)\n","\n"," def forward(self, input, hidden, encoder_outputs):\n"," # Forward passes as from the repo\n"," embedded = self.embedding(input).view(1, 1, -1)\n"," embedded = self.dropout(embedded)\n","\n"," attention_weights = F.softmax(self.attention(torch.cat((embedded[0],\n"," hidden[0]), 1)),\n"," dim=1)\n","\n"," attention_applied = torch.bmm(attention_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))\n","\n"," output = torch.cat((embedded[0], attention_applied[0]), 1)\n"," output = self.attention_combine(output).unsqueeze(0)\n","\n"," # Follow with a ReLU activation function after dropout\n"," output = F.relu(output)\n","\n"," # Then, use the GRU\n"," output, 
hidden = self.gru(output, hidden)\n","\n"," # And use softmax as the activation function\n"," output = F.log_softmax(self.out(output[0]), dim=1)\n","\n"," return output, hidden, attention_weights\n","\n"," def init_hidden(self):\n"," return torch.zeros(1, 1, self.hidden_size, device=device)"],"metadata":{"id":"3dW0ncSfjuNB","executionInfo":{"status":"ok","timestamp":1687067570288,"user_tz":-330,"elapsed":10,"user":{"displayName":"Growth Ai","userId":"13899479888288343300"}}},"execution_count":20,"outputs":[]},{"cell_type":"code","source":["#Neural Network Preprocessing\n","\n","def tensor_from_sentence(lang, sentence):\n"," \"\"\"\n"," Given an input sentence and a 'language' object,\n"," it creates an appropriate tensor with the EOS_TOKEN in the end.\n"," \"\"\"\n","\n"," # For each sentence, get a list of the word indices\n"," indices = [lang.word2index[word] for word in sentence.split(' ')]\n"," indices.append(EOS_TOKEN) # That will help the decoder know when to stop\n","\n"," # Convert to a PyTorch tensor\n"," sentence_tensor = torch.tensor(indices, dtype=torch.long, device=device).view(-1, 1)\n","\n"," return sentence_tensor\n","\n","def tensors_from_pair(pair):\n"," \"\"\"\n"," Given our 2D dataset as a list, it calls the 'tensor_from_sentence' method\n"," and returns the appropriate input/target tensors\n"," \"\"\"\n","\n"," input_tensor = tensor_from_sentence(questions, pair[0])\n"," target_tensor = tensor_from_sentence(answers, pair[1])\n","\n"," return (input_tensor, target_tensor)"],"metadata":{"id":"wxsbYL6DnJ9U","executionInfo":{"status":"ok","timestamp":1687067570289,"user_tz":-330,"elapsed":10,"user":{"displayName":"Growth Ai","userId":"13899479888288343300"}}},"execution_count":21,"outputs":[]},{"cell_type":"code","source":["##### DISPLAY HELPERS\n","\"\"\"\n","Helper functions for printing time elapsed and estimated remaining time for\n","training.\n","\"\"\"\n","import time\n","import math\n","\n","def as_minutes(s):\n"," m = math.floor(s / 
def as_minutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60

    return '%dm %ds' % (minutes, seconds)


def time_since(since, percent):
    """Return ``'<elapsed> (- <estimated remaining>)'``.

    Args:
        since: start timestamp as returned by ``time.time()``.
        percent: fraction of the work already done; must be non-zero.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed

    return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))


# Training helper method
def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion, max_length=None):
    """Run one training step on a single question/answer pair.

    - Runs the input sentence through the encoder, keeping every output and
      the last hidden state.
    - The decoder receives the SOS token as its first input and the last
      encoder hidden state as its first hidden state.
    - Teacher forcing is always used: the real target word is fed as the next
      decoder input.

    Args:
        input_tensor: column tensor of question word indices (EOS included).
        target_tensor: column tensor of answer word indices (EOS included).
        encoder, decoder: the two networks.
        encoder_optimizer, decoder_optimizer: their optimizers.
        criterion: loss applied to each decoder output and target index.
        max_length: number of rows of the encoder-outputs buffer. ``None``
            (the default) uses ``decoder.max_length`` when present so the
            buffer always matches the decoder's attention width, falling back
            to the notebook-level ``MAX_LENGTH``.

    Returns:
        The loss of this step averaged over the target length.
    """
    if max_length is None:
        max_length = getattr(decoder, 'max_length', MAX_LENGTH)

    # Work on the device of the data instead of a notebook-global device.
    work_device = input_tensor.device

    encoder_hidden = encoder.init_hidden()

    # Set gradients to zero
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    # One row per encoder position; unused rows stay zero.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size,
                                  device=work_device)

    loss = 0

    # Encode the input word by word, carrying the hidden state along.
    for position in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[position],
                                                 encoder_hidden)
        encoder_outputs[position] = encoder_output[0, 0]

    # Decoder starts from the SOS token and the encoder's last hidden state.
    decoder_input = torch.tensor([[SOS_TOKEN]], device=work_device)
    decoder_hidden = encoder_hidden

    for step in range(target_length):
        decoder_output, decoder_hidden, decoder_attention = decoder(
            decoder_input, decoder_hidden, encoder_outputs)

        loss += criterion(decoder_output, target_tensor[step])

        decoder_input = target_tensor[step]  # Teacher forcing

    # Compute gradients for every trainable parameter
    loss.backward()

    # Update parameters
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length
def train_iters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):
    """Call ``train()`` for ``n_iters`` randomly sampled training pairs.

    Creates one SGD optimizer per network and an ``nn.NLLLoss`` criterion,
    then prints elapsed/remaining time and the average loss every
    ``print_every`` iterations.

    Args:
        encoder, decoder: the networks to train.
        n_iters: number of single-pair training steps.
        print_every: reporting period in iterations; values below 1 are
            treated as 1 (``n_iters // 15`` is 0 for small ``n_iters`` and
            would otherwise raise ``ZeroDivisionError``).
        learning_rate: learning rate of both SGD optimizers.
    """
    print_every = max(1, int(print_every))

    start = time.time()   # Get start time
    print_loss_total = 0  # Reset after each print_every

    # Set optimizers
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)

    # Set the cost function; it is applied to the decoder's log-probabilities.
    criterion = nn.NLLLoss()

    for i in range(1, n_iters + 1):
        # Sample a pair (with replacement) and convert it only when needed,
        # instead of materialising all n_iters tensor pairs up front.
        input_tensor, target_tensor = tensors_from_pair(random.choice(pairs))

        loss = train(input_tensor, target_tensor, encoder, decoder,
                     encoder_optimizer, decoder_optimizer, criterion)

        print_loss_total += loss

        # Print progress
        if i % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0  # Reset
            print('%s (%d %d%%) %.4f' % (time_since(start, i / n_iters),
                                         i, i / n_iters * 100, print_loss_avg))
# Inference helper method
def inference(encoder, decoder, sentence, max_length=MAX_LENGTH):
    """
    Returns the decoded string after doing a greedy forward pass in the
    seq2seq model.

    Both modules are switched to eval mode for the duration of the call and
    put back into their previous mode afterwards. torch.no_grad() only stops
    gradient tracking; it does not disable dropout, and the decoder in this
    notebook is built with dropout_p=0.2, so without eval mode the replies
    are presumably stochastic (assuming the decoder applies nn.Dropout).

    Args:
        encoder: Encoder module; must provide init_hidden() and hidden_size.
        decoder: Attention decoder called as
            decoder(input, hidden, encoder_outputs).
        sentence: Raw user text; it is passed through preprocess_text().
        max_length: Number of rows in the encoder-output buffer and the
            maximum number of decoding steps. Inputs with more tokens than
            this are truncated instead of raising IndexError.

    Returns:
        The predicted words joined by single spaces. The string is empty if
        EOS is predicted at the first step.
    """

    # Remember the caller's modes so a later training run is not affected.
    encoder_was_training = encoder.training
    decoder_was_training = decoder.training
    encoder.eval()
    decoder.eval()

    try:
        with torch.no_grad():  # Stop autograd from tracking history on Tensors

            sentence = preprocess_text(sentence)  # Preprocess sentence

            # Token tensor built by tensor_from_sentence for the question vocabulary
            input_tensor = tensor_from_sentence(questions, sentence)

            # encoder_outputs has only max_length rows, so never encode more
            # steps than that (writing past the end raised IndexError before).
            input_length = min(input_tensor.size(0), max_length)

            # Init encoder hidden state
            encoder_hidden = encoder.init_hidden()

            # Init encoder outputs; rows beyond input_length stay zero
            encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

            # Forward pass in the encoder, one token at a time
            for step in range(input_length):
                encoder_output, encoder_hidden = encoder(input_tensor[step],
                                                         encoder_hidden)
                encoder_outputs[step] += encoder_output[0, 0]

            # Start of sentence token
            decoder_input = torch.tensor([[SOS_TOKEN]], device=device)

            # Decoder's initial hidden state is encoder's last hidden state
            decoder_hidden = encoder_hidden

            # Init the results array
            decoded_words = []

            # Forward pass in the decoder (greedy: always take the top-1 token)
            for _ in range(max_length):
                decoder_output, decoder_hidden, decoder_attention = decoder(
                    decoder_input, decoder_hidden, encoder_outputs)

                top_i = decoder_output.topk(1)[1]
                token_id = top_i.item()

                if token_id == EOS_TOKEN:  # If EOS is predicted
                    break  # Stop and return the sentence to the user

                # Append prediction by using index2word
                decoded_words.append(answers.index2word[token_id])

                # Use prediction as next input
                decoder_input = top_i.squeeze().detach()

            return ' '.join(decoded_words)  # Return the predicted sentence string
    finally:
        # Restore whatever mode the modules were in before the call.
        encoder.train(encoder_was_training)
        decoder.train(decoder_was_training)
# File names for the serialized modules, relative to the current working
# directory.
encoder_name, decoder_name = 'encoder_serialized.pt', 'decoder_serialized.pt'

# Serialize the encoder/decoder objects in your local directory.
# NOTE(review): torch.save on a whole module pickles the object, so the
# model classes presumably must be importable when these files are loaded.
print('Saving model...')
for module, filename in ((encoder, encoder_name), (attention_decoder, decoder_name)):
    torch.save(module, filename)