{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b7d2c1a",
   "metadata": {},
   "source": [
    "# Fine-tuning LaBSE-en-ru for NER on NEREL\n",
    "\n",
    "Fine-tunes `cointegrated/LaBSE-en-ru` as a token classifier on the `surdan/nerel_short` dataset, evaluating on the `dev` split after every epoch with seqeval (precision, recall, F1, accuracy).\n",
    "\n",
    "**Setup**\n",
    "- Requires `transformers`, `datasets`, `numpy` and `seqeval`; install the last one with `%pip install seqeval` if it is missing.\n",
    "- A CUDA GPU is strongly recommended; verify with `torch.cuda.is_available()` and `torch.cuda.device_count()`.\n",
    "- Expects `id_to_label_map.pickle` (label id -> label name) in the working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590c3f48",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "import transformers\n",
    "from transformers import Trainer\n",
    "from datasets import load_metric\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import TrainingArguments\n",
    "from transformers import AutoModelForTokenClassification\n",
    "from transformers import DataCollatorForTokenClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44d7c35c",
   "metadata": {},
   "source": [
    "## Helper functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c9e36d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def align_labels_with_tokens(labels, word_ids, label_all_tokens=True):\n",
    "    \"\"\"Expand one example's word-level label ids to its sub-word tokens.\n",
    "\n",
    "    Args:\n",
    "        labels: label ids for the example, one per input word.\n",
    "        word_ids: ``BatchEncoding.word_ids(i)`` for the same example; ``None``\n",
    "            marks special tokens, otherwise the index of the source word.\n",
    "        label_all_tokens: if True (the default, original behaviour) every\n",
    "            sub-token inherits its word's label; if False only the first\n",
    "            sub-token of each word is labelled and the rest get -100.\n",
    "\n",
    "    Returns:\n",
    "        A list aligned with ``word_ids``; special tokens (and, when\n",
    "        ``label_all_tokens`` is False, continuation sub-tokens) are -100,\n",
    "        the index that ``compute_metrics`` and the token-classification\n",
    "        loss ignore.\n",
    "\n",
    "    NOTE(review): if the tag set is BIO, labelling every sub-token copies a\n",
    "    ``B-`` tag onto continuation pieces, which seqeval counts as extra\n",
    "    entities; ``label_all_tokens=False`` avoids that - confirm the scheme.\n",
    "    \"\"\"\n",
    "    aligned = []\n",
    "    previous_word_id = None\n",
    "    for word_id in word_ids:\n",
    "        if word_id is None:\n",
    "            aligned.append(-100)\n",
    "        elif label_all_tokens or word_id != previous_word_id:\n",
    "            aligned.append(labels[word_id])\n",
    "        else:\n",
    "            aligned.append(-100)\n",
    "        previous_word_id = word_id\n",
    "    return aligned\n",
    "\n",
    "\n",
    "def tokenize_and_align_labels(examples):\n",
    "    \"\"\"Tokenize a batch of pre-split sentences and attach aligned labels.\n",
    "\n",
    "    Intended for ``Dataset.map(..., batched=True)``: reads the ``\"sequences\"``\n",
    "    column (list of words per example) and the ``\"ids\"`` column (label id per\n",
    "    word), and adds a ``\"labels\"`` column. Uses the global ``tokenizer``,\n",
    "    which must exist before ``map`` is called.\n",
    "    \"\"\"\n",
    "    tokenized_inputs = tokenizer(\n",
    "        examples[\"sequences\"], truncation=True, is_split_into_words=True\n",
    "    )\n",
    "    tokenized_inputs[\"labels\"] = [\n",
    "        align_labels_with_tokens(labels, tokenized_inputs.word_ids(i))\n",
    "        for i, labels in enumerate(examples[\"ids\"])\n",
    "    ]\n",
    "    return tokenized_inputs\n",
    "\n",
    "\n",
    "def compute_metrics(eval_preds):\n",
    "    \"\"\"Turn ``Trainer`` evaluation output into seqeval overall metrics.\n",
    "\n",
    "    ``eval_preds`` is unpacked as ``(logits, labels)``; predictions are the\n",
    "    argmax over the last axis. Positions whose gold label is -100 are dropped\n",
    "    from both sequences before ids are mapped to tag names.\n",
    "\n",
    "    Relies on two globals defined later in the notebook: ``label_names``\n",
    "    (tag name at index = label id) and ``metric`` (the seqeval metric).\n",
    "    \"\"\"\n",
    "    logits, labels = eval_preds\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "\n",
    "    # Remove ignored positions (special tokens) and convert ids to tag names.\n",
    "    true_labels = [\n",
    "        [label_names[tag_id] for tag_id in label if tag_id != -100]\n",
    "        for label in labels\n",
    "    ]\n",
    "    true_predictions = [\n",
    "        [label_names[pred_id] for pred_id, tag_id in zip(prediction, label) if tag_id != -100]\n",
    "        for prediction, label in zip(predictions, labels)\n",
    "    ]\n",
    "    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)\n",
    "    return {\n",
    "        \"precision\": all_metrics[\"overall_precision\"],\n",
    "        \"recall\": all_metrics[\"overall_recall\"],\n",
    "        \"f1\": all_metrics[\"overall_f1\"],\n",
    "        \"accuracy\": all_metrics[\"overall_accuracy\"],\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8760e709",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8c723f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_datasets = load_dataset(\"surdan/nerel_short\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e540a898",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a4947d1",
   "metadata": {},
   "source": [
    "## Preprocess data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8829557e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = \"cointegrated/LaBSE-en-ru\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6c13ad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea2c1a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenized_datasets = raw_datasets.map(\n",
    "    tokenize_and_align_labels,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9b5b9b1",
   "metadata": {},
   "source": [
    "## Init training pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b24d86e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pickle.load can execute arbitrary code: only load a trusted, locally produced file.\n",
    "with open('id_to_label_map.pickle', 'rb') as f:\n",
    "    map_id_to_label = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d90a6d9",
   "metadata": {},
"outputs": [],
   "source": [
    "data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d890df2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order labels by numeric id: compute_metrics indexes label_names with raw\n",
    "# label ids, so label_names[i] must be the name of class i. Relying on the\n",
    "# pickled dict's insertion order could silently mislabel every metric.\n",
    "id_label_pairs = sorted((int(k), v) for k, v in map_id_to_label.items())\n",
    "assert [i for i, _ in id_label_pairs] == list(range(len(id_label_pairs))), \"label ids must be 0..N-1\"\n",
    "\n",
    "id2label = {i: name for i, name in id_label_pairs}\n",
    "label2id = {name: i for i, name in id_label_pairs}\n",
    "label_names = [name for _, name in id_label_pairs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31bcfd6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForTokenClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84497580",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The classification head must have one output per label name.\n",
    "assert model.config.num_labels == len(label_names)\n",
    "model.config.num_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ccfbf74",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    \"LaBSE_ner_nerel\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    save_strategy=\"no\",  # no intermediate checkpoints; the final model is saved explicitly below\n",
    "    learning_rate=2e-5,\n",
    "    num_train_epochs=25,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=False,\n",
    "    per_device_train_batch_size=4,  # adjust to the memory of your GPU\n",
    "    seed=42,  # the TrainingArguments default, made explicit for reproducibility\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c798d567",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1348d188",
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = load_metric(\"seqeval\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cff0367",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=tokenized_datasets[\"train\"],\n",
    "    eval_dataset=tokenized_datasets[\"dev\"],\n",
    "    data_collator=data_collator,\n",
    "    compute_metrics=compute_metrics,\n",
    "    tokenizer=tokenizer,\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "576a10f4",
   "metadata": {},
   "outputs": [],
   "source": [
"trainer.save_model(\"LaBSE_nerel_last_checkpoint\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf_env",
   "language": "python",
   "name": "hf_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}