{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForTokenClassification\n", "from transformers import AutoTokenizer\n", "\n", "from datasets import load_dataset\n", "from pprint import pprint\n", "from collections import Counter\n", "import random\n", "import evaluate\n", "import numpy as np\n", "\n", "import os\n", "from huggingface_hub import login\n", "from transformers import TrainingArguments, Trainer\n", "from transformers import DataCollatorForTokenClassification" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Define the checkpoint and get access to the huggingface token for uploading the model to huggingface hub\n", "checkpoint = \"bert-base-cased\"\n", "os.environ[\"HF_TOKEN\"] = open(\n", " \"/home/hf/hf-course/chapter7/hf-token.txt\", \"r\").readlines()[0]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['text', 'entities', 'entities-suggestion', 'entities-suggestion-metadata', 'external_id', 'metadata'],\n", " num_rows: 8528\n", " })\n", " validation: Dataset({\n", " features: ['text', 'entities', 'entities-suggestion', 'entities-suggestion-metadata', 'external_id', 'metadata'],\n", " num_rows: 8528\n", " })\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the dataset\n", "dataset = load_dataset(\"louisguitton/dev-ner-ontonotes\")\n", "dataset" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'entities': [],\n", " 'entities-suggestion': {'end': [30],\n", " 'label': ['PERSON'],\n", " 'score': [1.0],\n", " 'start': [23],\n", " 'text': ['Camilla']},\n", " 'entities-suggestion-metadata': {'agent': 'gold_labels',\n", " 'score': None,\n", " 'type': None},\n", " 'external_id': None,\n", " 'metadata': '{}',\n", " 'text': 'The horse is basically Camilla /.'}\n" ] } ], "source": [ "# Have a look at one sample example in the dataset\n", "pprint(dataset[\"train\"].shuffle().take(1)[0])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['O', 'B-CARDINAL', 'I-CARDINAL', 'B-DATE', 'I-DATE', 'B-EVENT', 'I-EVENT', 'B-FAC', 'I-FAC', 'B-GPE', 'I-GPE', 'B-LANGUAGE', 'I-LANGUAGE', 'B-LAW', 'I-LAW', 'B-LOC', 'I-LOC', 'B-MONEY', 'I-MONEY', 'B-NORP', 'I-NORP', 'B-ORDINAL', 'I-ORDINAL', 'B-ORG', 'I-ORG', 'B-PERCENT', 'I-PERCENT', 'B-PERSON', 'I-PERSON', 'B-PRODUCT', 'I-PRODUCT', 'B-QUANTITY', 'I-QUANTITY', 'B-TIME', 'I-TIME', 'B-WORK_OF_ART', 'I-WORK_OF_ART']\n", "Counter({'GPE': 2268, 'PERSON': 2020, 'ORG': 1740, 'DATE': 1507, 'CARDINAL': 938, 'NORP': 847, 'MONEY': 274, 'ORDINAL': 232, 'TIME': 214, 'LOC': 204, 'PERCENT': 177, 'EVENT': 143, 'WORK_OF_ART': 142, 'FAC': 115, 'QUANTITY': 100, 'PRODUCT': 72, 'LAW': 40, 'LANGUAGE': 33})\n" ] } ], "source": [ "# Have a look at the distribution of all the labels\n", "entity_types = []\n", "\n", "for element in dataset[\"train\"]:\n", " entity_types.extend(element[\"entities-suggestion\"][\"label\"])\n", "\n", "entities = sorted(set(entity_types))\n", "final_entities = [\"O\"]\n", "for entity in entities:\n", " final_entities.extend([f\"B-{entity}\", f\"I-{entity}\"])\n", "print(final_entities)\n", "print(Counter(entity_types))" ] }, { "cell_type": "code", "execution_count": 7, 
"metadata": {}, "outputs": [], "source": [ "# Create a couple of dictionaries to map all the entities to integer ids and vice versa\n", "id2label = {i: label for i, label in enumerate(final_entities)}\n", "label2id = {v: k for k, v in id2label.items()}" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] } ], "source": [ "# Create the tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BertTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})\n" ] } ], "source": [ "# Have a look at the tokenizer\n", "pprint(tokenizer)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Tokenize one sample and check what all is returned\n", "output = tokenizer(dataset[\"train\"][0][\"text\"], return_offsets_mapping=True)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.keys()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'start': [2, 40, 53, 108, 122],\n", " 'end': [9, 45, 56, 113, 137],\n", " 'label': ['NORP', 'CARDINAL', 'CARDINAL', 'PRODUCT', 'LOC'],\n", " 'text': ['Russian', 'three', '118', 'Kursk', 'the Barents Sea'],\n", " 'score': [1.0, 1.0, 1.0, 1.0, 1.0]}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Have a look at the entities\n", "dataset[\"train\"][\"entities-suggestion\"][0]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def in_span(source_start, source_end, target_start, target_end):\n", " \"\"\"\n", " Function to check if the target span is contained within the source span\n", " \"\"\"\n", " if (target_start >= source_start) and (target_end <= source_end):\n", " return True\n", " return False\n", "\n", "\n", "def tokenize_and_create_labels(example):\n", " \"\"\"\n", " Function to tokenize the example and subsequently create labels. 
The labels provided will not be aligned with the tokens (after wordpiece tokenization); hence this step.\n", " \"\"\"\n", " outputs = tokenizer(\n", " example[\"text\"], truncation=True, return_offsets_mapping=True)\n", "\n", " output_labels = []\n", " n_samples = len(example[\"text\"])\n", "\n", " # Do for all the samples in the batch\n", " for i in range(n_samples):\n", " # Do not take the first and last offsets as they belong to a special token (CLS and SEP respectively)\n", " offsets = outputs[\"offset_mapping\"][i][1:-1]\n", " num_tokens = len(offsets)\n", "\n", " # Entity spans\n", " entity_starts = example[\"entities-suggestion\"][i][\"start\"]\n", " entity_ends = example[\"entities-suggestion\"][i][\"end\"]\n", "\n", " # Labels and their number\n", " text_labels = example[\"entities-suggestion\"][i][\"label\"]\n", " num_entities = len(text_labels)\n", "\n", " labels = []\n", "\n", " entities = example[\"entities-suggestion\"][i]\n", "\n", " # If there are no spans, it will all be a list of Os\n", " if len(entities[\"start\"]) == 0:\n", " labels = [label2id[\"O\"] for _ in range(num_tokens)]\n", " # Otherwise check span by span\n", " else:\n", " idx = 0\n", " source_start, source_end = entity_starts[idx], entity_ends[idx]\n", " previous_label = \"O\"\n", "\n", " for loop_idx, (start, end) in enumerate(offsets):\n", " # By default, the token is an O token\n", " lab = \"O\"\n", "\n", " # While you have not exceeded the number of identities provided\n", " if idx < num_entities:\n", " # While you have not stepped ahead of the next identity span\n", " if start > source_end:\n", " # If you have reached the end of the identities annotated, simply fill in the remainder of the tokens as O\n", " if idx == num_entities - 1:\n", " lab = \"O\"\n", " remainder = [\n", " label2id[\"O\"] for _ in range(num_tokens - loop_idx)\n", " ]\n", " labels.extend(remainder)\n", " break\n", " else:\n", " idx += 1\n", "\n", " # If the idx is refreshed, then consider new span\n", " source_start, source_end = entity_starts[idx], entity_ends[idx]\n", "\n", " # Check if current token is within the source span\n", " if in_span(source_start, source_end, start, end):\n", " # Check if the previous label was an O, if so then this one would begin with a B- else an I-\n", " lab = \"B-\" if previous_label == \"O\" else \"I-\"\n", " lab = lab + text_labels[idx]\n", " else:\n", " lab = \"O\"\n", "\n", " labels.append(label2id[lab])\n", " previous_label = lab\n", " # The first and last tokens are reserved for special words [CLS] and [SEP], hence modify their indices accordingly\n", " output_labels.append([-100] + labels + [-100])\n", " outputs[\"labels\"] = output_labels\n", "\n", " return outputs" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "tokenized_dataset = dataset.map(tokenize_and_create_labels, batched=True,\n", " remove_columns=dataset[\"train\"].column_names)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "14b7a117c7c4418aa3d0d08eb7563add", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/5 [00:00\n", " \n", " \n", " [1335/1335 09:17, Epoch 5/5]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
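" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The `Trainer` call below relies on `compute_metrics`, `model`, `args` and `data_collator`. What follows is a minimal sketch of how these are typically set up for token classification, not a verbatim copy of this notebook's cells: the 5 epochs match the log below, a per-device batch size of 32 is consistent with the 1335 optimisation steps, the output directory is inferred from the Hub repository pushed to at the end, and the learning rate and weight decay are assumed defaults." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch only: assumed definitions, consistent with the imports above and the Trainer call below\n", "metric = evaluate.load(\"seqeval\")\n", "\n", "\n", "def compute_metrics(eval_preds):\n", "    logits, labels = eval_preds\n", "    predictions = np.argmax(logits, axis=-1)\n", "\n", "    # Drop the -100 positions (special tokens) before scoring with seqeval\n", "    true_labels = [[id2label[l] for l in label if l != -100] for label in labels]\n", "    true_predictions = [\n", "        [id2label[p] for (p, l) in zip(prediction, label) if l != -100]\n", "        for prediction, label in zip(predictions, labels)\n", "    ]\n", "\n", "    all_metrics = metric.compute(\n", "        predictions=true_predictions, references=true_labels)\n", "    return {\n", "        \"precision\": all_metrics[\"overall_precision\"],\n", "        \"recall\": all_metrics[\"overall_recall\"],\n", "        \"f1\": all_metrics[\"overall_f1\"],\n", "        \"accuracy\": all_metrics[\"overall_accuracy\"],\n", "    }\n", "\n", "\n", "# Pad inputs and labels together so the -100 positions stay aligned with the tokens\n", "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)\n", "\n", "# Token classification head on top of the BERT checkpoint, wired to the label maps\n", "model = AutoModelForTokenClassification.from_pretrained(\n", "    checkpoint, id2label=id2label, label2id=label2id)\n", "\n", "args = TrainingArguments(\n", "    \"dev-ner-ontonote-bert-finetuned\",  # inferred from the Hub repo pushed to below\n", "    evaluation_strategy=\"epoch\",\n", "    save_strategy=\"epoch\",\n", "    learning_rate=2e-5,  # assumed\n", "    per_device_train_batch_size=32,  # consistent with 1335 steps over 5 epochs\n", "    num_train_epochs=5,\n", "    weight_decay=0.01,  # assumed\n", "    push_to_hub=True,\n", ")" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [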
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossPrecisionRecallF1Accuracy
1No log0.1113290.7575520.7972570.7768980.968852
20.2811000.0558880.8731780.9087110.8905900.984724
30.2811000.0359790.9147010.9477700.9309420.990416
40.0630000.0274580.9333270.9600330.9464920.992793
50.0630000.0240830.9404490.9668450.9534640.993742

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/home/huggingface/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n", "/home/huggingface/lib/python3.10/site-packages/seqeval/metrics/v1.py:57: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n", " _warn_prf(average, modifier, msg_start, len(result))\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=1335, training_loss=0.1388676861252231, metrics={'train_runtime': 562.8544, 'train_samples_per_second': 75.757, 'train_steps_per_second': 2.372, 'total_flos': 1425922860395136.0, 'train_loss': 0.1388676861252231, 'epoch': 5.0})" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = Trainer(\n", " model=model,\n", " args=args,\n", " data_collator=data_collator,\n", " train_dataset=tokenized_dataset[\"train\"],\n", " eval_dataset=tokenized_dataset[\"validation\"],\n", " compute_metrics=compute_metrics,\n", " tokenizer=tokenizer\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "To https://huggingface.co/ElisonSherton/dev-ner-ontonote-bert-finetuned\n", " 41c8386..27067f9 main -> main\n", "\n" ] } ], "source": [ "trainer.push_to_hub(\n", " commit_message=\"🤗 Training of first BERT based NER task completed!!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 4 }