# %% Imports
from transformers import AutoModelForTokenClassification
from transformers import AutoTokenizer

from datasets import load_dataset
from pprint import pprint
from collections import Counter
import random
import evaluate
import numpy as np

import os
from huggingface_hub import login
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorForTokenClassification

# %% Checkpoint and Hugging Face Hub credentials
# Pretrained checkpoint that is fine-tuned for token classification below
checkpoint = "bert-base-cased"

# File whose first line holds the Hub access token used to upload the model
HF_TOKEN_PATH = "/home/hf/hf-course/chapter7/hf-token.txt"

# Read only the first line inside a context manager so the file handle is
# closed, and strip the line terminator so it never ends up inside the token.
with open(HF_TOKEN_PATH, "r") as token_file:
    hf_token = token_file.readline().strip()

# Fail early with a clear message instead of exporting an empty token
if not hf_token:
    raise ValueError(
        f"No Hugging Face token found on the first line of {HF_TOKEN_PATH}")
os.environ["HF_TOKEN"] = hf_token

# %% Load the dataset
dataset = load_dataset("louisguitton/dev-ner-ontonotes")
dataset
# %% Peek at one randomly drawn training example
random_example = dataset["train"].shuffle().take(1)[0]
pprint(random_example)

# %% Label inventory and distribution
# Every annotated entity type across the training split, duplicates kept so
# that the Counter below reflects how often each type occurs.
entity_types = [
    entity_type
    for record in dataset["train"]
    for entity_type in record["entities-suggestion"]["label"]
]

# Unique types in alphabetical order
entities = sorted(set(entity_types))

# BIO tag set: "O" first, then a B-/I- pair for each entity type
final_entities = ["O"] + [
    f"{prefix}-{entity_type}"
    for entity_type in entities
    for prefix in ("B", "I")
]
print(final_entities)
print(Counter(entity_types))

# %% Tag <-> integer id lookups, with ids following the order of final_entities
id2label = dict(enumerate(final_entities))
label2id = {tag: tag_id for tag_id, tag in id2label.items()}
# %% Create the tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# %% Have a look at the tokenizer
pprint(tokenizer)

# %% Tokenize one sample and check what all is returned
# return_offsets_mapping=True additionally returns, for every token, the
# character span of the original text it was produced from.
output = tokenizer(dataset["train"][0]["text"], return_offsets_mapping=True)

# %% Fields returned by the tokenizer
output.keys()

# %% Have a look at the entities annotated for that same first sample
# Index the row first and then the column: this reads a single record instead
# of pulling the whole "entities-suggestion" column just to show its first
# element, and it mirrors how the text of the sample is accessed above.
dataset["train"][0]["entities-suggestion"]
# %% Token/label alignment helpers
def in_span(source_start, source_end, target_start, target_end):
    """
    Function to check if the target span is contained within the source span

    Returns True when [target_start, target_end] lies entirely inside
    [source_start, source_end], False otherwise.
    """
    return (target_start >= source_start) and (target_end <= source_end)


def _bio_tags_for_offsets(offsets, entity_starts, entity_ends, entity_labels):
    """
    Build one BIO tag string per token offset.

    offsets: (start, end) character offsets of the tokens, special tokens excluded.
    entity_starts / entity_ends / entity_labels: parallel lists describing the
    annotated character spans and their entity types.

    A token receives an entity tag only when it is fully contained in the span
    (see in_span). The first tagged token of a span gets "B-", later tokens of
    the same span get "I-". Everything else is "O".
    """
    # Sweep spans left to right; sorting makes the single pass independent of
    # the order in which the annotations are stored.
    # NOTE(review): spans are assumed not to overlap each other - confirm for the dataset.
    spans = sorted(
        zip(entity_starts, entity_ends, entity_labels),
        key=lambda span: (span[0], span[1]),
    )

    tags = []
    span_idx = 0
    # Index of the span that received the most recent entity tag. Comparing
    # against it (instead of "was the previous tag O?") keeps two entities
    # that directly follow each other from being fused into B-X I-Y.
    last_tagged_span = None

    for start, end in offsets:
        # Drop every span that is already behind this token. Span ends are
        # exclusive character offsets, so a token starting exactly at the end
        # of a span is past it; a loop is needed because more than one span
        # may have ended since the previous token.
        while span_idx < len(spans) and start >= spans[span_idx][1]:
            span_idx += 1

        tag = "O"
        if span_idx < len(spans):
            span_start, span_end, span_label = spans[span_idx]
            if in_span(span_start, span_end, start, end):
                prefix = "I-" if last_tagged_span == span_idx else "B-"
                tag = prefix + span_label
                last_tagged_span = span_idx
        tags.append(tag)

    return tags


def tokenize_and_create_labels(example):
    """
    Function to tokenize the example and subsequently create labels. The labels provided will not be aligned with the tokens (after wordpiece tokenization); hence this step.

    example: a batch (as passed by `Dataset.map(..., batched=True)`) with a
    "text" list and an "entities-suggestion" list whose items hold the
    parallel "start", "end" and "label" lists of the annotated spans.

    Returns the tokenizer output (offset mapping included) extended with a
    "labels" entry: one list of label ids per sample, with -100 at the first
    and last position.
    """
    outputs = tokenizer(
        example["text"], truncation=True, return_offsets_mapping=True)

    output_labels = []

    # Do for all the samples in the batch
    for sample_offsets, annotation in zip(
            outputs["offset_mapping"], example["entities-suggestion"]):
        # Do not take the first and last offsets as they belong to a special token (CLS and SEP respectively)
        tags = _bio_tags_for_offsets(
            sample_offsets[1:-1],
            annotation["start"],
            annotation["end"],
            annotation["label"],
        )

        # The first and last tokens are reserved for special words [CLS] and [SEP], hence pad their labels with -100
        output_labels.append(
            [-100] + [label2id[tag] for tag in tags] + [-100])

    outputs["labels"] = output_labels

    return outputs

# %% Tokenize and label every split, dropping the original columns
tokenized_dataset = dataset.map(tokenize_and_create_labels, batched=True,
                                remove_columns=dataset["train"].column_names)
# %% Create a sample of 5 items for the sake of visualization
samples = dataset["train"].shuffle(seed=43).take(5).map(
    tokenize_and_create_labels, batched=True)

# %% Visualize a few samples from the dataset randomly
# random.randint includes its upper bound, so randint(0, len(samples)) could
# return len(samples) and index past the end; randrange excludes it.
idx = random.randrange(len(samples))

# Fetch the chosen row once and reuse it for tokens, labels and annotations
sample = samples[idx]

ip_tokens = [tokenizer.decode([x]) for x in sample["input_ids"]]
labels = sample["labels"]

# Print tokens and their tags as two aligned rows: each column is as wide as
# the longer of the token and its tag, plus two spaces of padding.
token_op, lbl_op = "", ""
for token, lbl in zip(ip_tokens, labels):
    # Ids without a tag (the -100 placeholders) are shown as SPECIAL
    lbl = id2label.get(lbl, "SPECIAL")
    width = max(len(token), len(lbl)) + 2
    token_op += f"{token:<{width}}"
    lbl_op += f"{lbl:<{width}}"

print(token_op)
print(lbl_op)
print(f"Number of tokens: {len(ip_tokens)}, Number of Labels: {len(labels)}")
print("Entities Annotated: ", sample["entities-suggestion"])

# %% Drop the offset mappings
# We need to remove the offset mappings as it would not be possible to collate data without dropping this column
tokenized_dataset = tokenized_dataset.remove_columns(
    column_names=["offset_mapping"])
# %% Data collator: pads a batch as and when necessary
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Collate the first two training samples to see how the labels get padded
first_two_samples = [tokenized_dataset["train"][row] for row in (0, 1)]
batch = data_collator(first_two_samples)
batch["labels"]

# %% Evaluation metric
metric = evaluate.load("seqeval")


def compute_metrics(eval_preds):
    """
    Turn raw evaluation output into overall seqeval scores.

    eval_preds unpacks into (logits, labels). The predicted id of every token
    is the argmax over the last axis of the logits. Positions whose gold label
    is -100 are left out of both the references and the predictions.

    Returns a dict with the overall precision, recall, f1 and accuracy.
    """
    logits, labels = eval_preds

    # Most probable tag id for every token
    predictions = np.argmax(logits, axis=-1)

    true_labels, true_predictions = [], []
    for predicted_row, gold_row in zip(predictions, labels):
        # (gold tag, predicted tag) pairs for the positions that are scored
        scored_pairs = [
            (id2label[gold_id], id2label[predicted_id])
            for predicted_id, gold_id in zip(predicted_row, gold_row)
            if gold_id != -100
        ]
        true_labels.append([gold_tag for gold_tag, _ in scored_pairs])
        true_predictions.append(
            [predicted_tag for _, predicted_tag in scored_pairs])

    all_metrics = metric.compute(
        predictions=true_predictions, references=true_labels)

    # Report the overall scores only, not the per-entity breakdown
    return {
        name: all_metrics[f"overall_{name}"]
        for name in ("precision", "recall", "f1", "accuracy")
    }

# %% Create a model for token classification on top of pretrained BERT model
model = AutoModelForTokenClassification.from_pretrained(
    checkpoint, id2label=id2label, label2id=label2id)

# %% Check the classifier architecture
model.classifier
# %% Number of labels in the model config next to the sizes of both label lookups
(model.config.num_labels, len(label2id), len(id2label))

# %% Login to huggingface for uploading the generated model
login(token=os.getenv("HF_TOKEN"))

# %% Training configuration
args = TrainingArguments(
    output_dir="dev-ner-ontonote-bert-finetuned",
    # Optimisation
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # Evaluate and checkpoint once per epoch
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # Upload the result to the Hub
    push_to_hub=True,
)

# %% Columns that remain in the tokenized dataset
tokenized_dataset
"execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:131: FutureWarning: 'Repository' (from 'huggingface_hub.repository') is deprecated and will be removed from version '1.0'. Please prefer the http-based alternatives instead. Given its large adoption in legacy code, the complete removal is only planned on next major release.\n", "For more details, please read https://huggingface.co/docs/huggingface_hub/concepts/git_vs_http.\n", " warnings.warn(warning_message, FutureWarning)\n", "/home/hf/hf-course/chapter7/dev-ner-ontonote-bert-finetuned is already a clone of https://huggingface.co/ElisonSherton/dev-ner-ontonote-bert-finetuned. Make sure you pull the latest changes with `repo.git_pull()`.\n", "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Precision | \n", "Recall | \n", "F1 | \n", "Accuracy | \n", "
---|---|---|---|---|---|---|
1 | \n", "No log | \n", "0.111329 | \n", "0.757552 | \n", "0.797257 | \n", "0.776898 | \n", "0.968852 | \n", "
2 | \n", "0.281100 | \n", "0.055888 | \n", "0.873178 | \n", "0.908711 | \n", "0.890590 | \n", "0.984724 | \n", "
3 | \n", "0.281100 | \n", "0.035979 | \n", "0.914701 | \n", "0.947770 | \n", "0.930942 | \n", "0.990416 | \n", "
4 | \n", "0.063000 | \n", "0.027458 | \n", "0.933327 | \n", "0.960033 | \n", "0.946492 | \n", "0.992793 | \n", "
5 | \n", "0.063000 | \n", "0.024083 | \n", "0.940449 | \n", "0.966845 | \n", "0.953464 | \n", "0.993742 | \n", "
"
],
"text/plain": [
"