{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyOjIcXwclnDxtt6VzX+P9Fq", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "widgets": { "application/vnd.jupyter.widget-state+json": { "f39f11d1075547bd81c2a71ab4e9d056": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3c55a93ff9394478b696ad604223c406", "IPY_MODEL_ccf1f54ee22249028e2f8cc6d3397079", "IPY_MODEL_5743023322c2401bbea11750794498aa" ], "layout": "IPY_MODEL_593f36aefdb249608bee030b9905e2b0" } }, "3c55a93ff9394478b696ad604223c406": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a820159a6d814a74a1ab3013521a41b4", "placeholder": "", "style": "IPY_MODEL_d894666cefde4a6cb4864f9d6aecfdb8", "value": "100%" } }, "ccf1f54ee22249028e2f8cc6d3397079": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_cfc6ee0c109d428e96ac3778902104ab", "max": 1, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_095602ae14d141c9b74296bc18b390f7", "value": 1 } }, "5743023322c2401bbea11750794498aa": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_caac98cc706143d689e7eaddb25be4b4", "placeholder": "", "style": "IPY_MODEL_88a7cf382ff64b78bbd440e3f82e02ea", "value": " 1/1 [00:00<00:00, 34.69it/s]" } }, "593f36aefdb249608bee030b9905e2b0": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, 
"grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a820159a6d814a74a1ab3013521a41b4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d894666cefde4a6cb4864f9d6aecfdb8": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "cfc6ee0c109d428e96ac3778902104ab": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "095602ae14d141c9b74296bc18b390f7": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": 
"@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "caac98cc706143d689e7eaddb25be4b4": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "88a7cf382ff64b78bbd440e3f82e02ea": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "f2096bde6a524ff9a87c11c118a2c2ee": { "model_module": "@jupyter-widgets/controls", "model_name": "HBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_ae005c5b1ed6445097f27c9eadb44d00", "IPY_MODEL_9fa8126bacd749a3b374996c0bee7891", "IPY_MODEL_9558b9c30ce348928989e241942aecef" ], "layout": "IPY_MODEL_bd9374b595da45498874f2cc62cad206" } }, "ae005c5b1ed6445097f27c9eadb44d00": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_437cfda1d7ac4731a4673735ae8072e1", "placeholder": "", "style": "IPY_MODEL_a7aa8c0ce09c4ee4bbccb2720b0ad4b4", "value": " 0%" } }, "9fa8126bacd749a3b374996c0bee7891": { "model_module": "@jupyter-widgets/controls", "model_name": "FloatProgressModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "danger", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_ce8e7aa1b6154f78931aa32fd38c032f", "max": 1, "min": 0, "orientation": 
"horizontal", "style": "IPY_MODEL_dbd56444ab4a4d1799c5f164fae61e54", "value": 0 } }, "9558b9c30ce348928989e241942aecef": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_0a55eb4a977044a891795ffe23adfb28", "placeholder": "", "style": "IPY_MODEL_5d0e9fb3a94a4f738545a84d470c846f", "value": " 0/1 [00:00<?, ?ba/s]" } }, "bd9374b595da45498874f2cc62cad206": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "437cfda1d7ac4731a4673735ae8072e1": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a7aa8c0ce09c4ee4bbccb2720b0ad4b4": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "ce8e7aa1b6154f78931aa32fd38c032f": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", 
"model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "dbd56444ab4a4d1799c5f164fae61e54": { "model_module": "@jupyter-widgets/controls", "model_name": "ProgressStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "0a55eb4a977044a891795ffe23adfb28": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "5d0e9fb3a94a4f738545a84d470c846f": { "model_module": "@jupyter-widgets/controls", "model_name": "DescriptionStyleModel", "model_module_version": "1.5.0", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } } } } }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Ww6WXCvy-0lg", "outputId": "118a86a5-ed10-4970-f6c0-c6f1b76ff48b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", 
"Requirement already satisfied: transformers in /usr/local/lib/python3.7/dist-packages (4.23.1)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.7/dist-packages (2.6.1)\n", "Requirement already satisfied: evaluate in /usr/local/lib/python3.7/dist-packages (0.3.0)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.64.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.6)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.10.1)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.7/dist-packages (from transformers) (0.13.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2022.6.2)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.8.0)\n", "Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.13.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.1.1)\n", "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.9)\n", "Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.7/dist-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.7/dist-packages (from datasets) (0.70.13)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.7/dist-packages (from datasets) (3.1.0)\n", "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.18.0)\n", "Requirement already satisfied: dill<0.3.6 in /usr/local/lib/python3.7/dist-packages (from datasets) (0.3.5.1)\n", "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.7/dist-packages (from datasets) (2022.8.2)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.7/dist-packages (from datasets) (3.8.3)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from datasets) (1.3.5)\n", "Requirement already satisfied: asynctest==0.13.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (0.13.0)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (22.1.0)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (4.0.2)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.8.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (6.0.2)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement 
already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (1.2.0)\n", "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.7/dist-packages (from aiohttp->datasets) (2.1.1)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2022.9.24)\n", "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n", "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n", "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.25.11)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.9.0)\n", "Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2022.4)\n", "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n" ] } ], "source": [ "!pip install transformers==4.22.1 datasets==2.5.2 evaluate==0.2.2" ] }, { "cell_type": "code", "source": [ "import torch\n", "import numpy as np\n", "\n", "\n", "# 1. Dataset\n", "from datasets import load_dataset\n", "dataset = load_dataset(\"Adapting/abstract-keyphrases\")\n", "\n", "# 2. Model\n", "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Adapting/KeyBartAdapter\")\n", "\n", "model = AutoModelForSeq2SeqLM.from_pretrained(\"Adapting/KeyBartAdapter\", revision = '9c3ed39c6ed5c7e141363e892d77cf8f589d5999')\n", "\n", "\n", "# 3. preprocess dataset\n", "dataset = dataset.shuffle()\n", "def preprocess_function(examples):\n", " inputs = examples['Abstract']\n", " targets = examples['Keywords']\n", " model_inputs = tokenizer(inputs, truncation=True)\n", "\n", " # Set up the tokenizer for targets\n", " with tokenizer.as_target_tokenizer():\n", " labels = tokenizer(targets, truncation=True)\n", "\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs\n", "\n", "tokenized_dataset = dataset.map(\n", " preprocess_function,\n", " batched=True,\n", " remove_columns=dataset[\"train\"].column_names,\n", ")\n", "\n", "# 4. 
evaluation metrics\n", "def compute_metrics(eval_preds):\n", " preds = eval_preds.predictions\n", " labels = eval_preds.label_ids\n", " if isinstance(preds, tuple):\n", " preds = preds[0]\n", " print(preds.shape)\n", " if len(preds.shape) == 3:\n", " preds = preds.argmax(axis=-1)\n", " \n", " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", " # Replace -100 in the labels as we can't decode them.\n", " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", " # Some simple post-processing\n", " decoded_preds = [a.strip().split(';') for a in decoded_preds]\n", " decoded_labels = [a.strip().split(';') for a in decoded_labels]\n", "\n", "\n", " precs, recalls, f_scores = [], [], []\n", " num_match, num_pred, num_gold = [], [], []\n", " for pred, label in zip(decoded_preds, decoded_labels):\n", " pred_set = set(pred)\n", " label_set = set(label)\n", " match_set = label_set.intersection(pred_set)\n", " p = float(len(match_set)) / float(len(pred_set)) if len(pred_set) > 0 else 0.0\n", " r = float(len(match_set)) / float(len(label_set)) if len(label_set) > 0 else 0.0\n", " f1 = float(2 * (p * r)) / (p + r) if (p + r) > 0 else 0.0\n", " precs.append(p)\n", " recalls.append(r)\n", " f_scores.append(f1)\n", " num_match.append(len(match_set))\n", " num_pred.append(len(pred_set))\n", " num_gold.append(len(label_set))\n", " \n", " # print(f'raw_PRED: {raw_pred}')\n", " print(f'PRED: num={len(pred_set)} - {pred_set}')\n", " print(f'GT: num={len(label_set)} - {label_set}')\n", " print(f'p={p}, r={r}, f1={f1}')\n", " print('-' * 20)\n", "\n", " result = {\n", " 'precision@M': np.mean(precs) * 100.0,\n", " 'recall@M': np.mean(recalls) * 100.0,\n", " 'fscore@M': np.mean(f_scores) * 100.0,\n", " 'num_match': np.mean(num_match),\n", " 'num_pred': np.mean(num_pred),\n", " 'num_gold': np.mean(num_gold),\n", " }\n", "\n", " result = {k: round(v, 2) for k, v in result.items()}\n", " return result\n", "\n", "# 5. 
train\n", "from transformers import DataCollatorForSeq2Seq,Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "\n", "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n", "\n", "model_name = 'KeyBartAdapter'\n", "num_epoch = 30\n", "\n", "args = Seq2SeqTrainingArguments(\n", " model_name,\n", " evaluation_strategy=\"epoch\",\n", " save_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=4,\n", " per_device_eval_batch_size=4,\n", " weight_decay=0.01,\n", " save_total_limit=3,\n", " num_train_epochs=num_epoch,\n", " logging_steps=4,\n", " load_best_model_at_end=True,\n", " metric_for_best_model='fscore@M',\n", " predict_with_generate=True,\n", " fp16=torch.cuda.is_available(), # speeds up training on modern GPUs.\n", " # eval_accumulation_steps=10,\n", ")\n", "\n", "trainer = Seq2SeqTrainer(\n", " model,\n", " args,\n", " train_dataset=tokenized_dataset[\"train\"],\n", " eval_dataset=tokenized_dataset[\"train\"],\n", " data_collator=data_collator,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics\n", ")\n", "\n", "trainer.train()\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "f39f11d1075547bd81c2a71ab4e9d056", "3c55a93ff9394478b696ad604223c406", "ccf1f54ee22249028e2f8cc6d3397079", "5743023322c2401bbea11750794498aa", "593f36aefdb249608bee030b9905e2b0", "a820159a6d814a74a1ab3013521a41b4", "d894666cefde4a6cb4864f9d6aecfdb8", "cfc6ee0c109d428e96ac3778902104ab", "095602ae14d141c9b74296bc18b390f7", "caac98cc706143d689e7eaddb25be4b4", "88a7cf382ff64b78bbd440e3f82e02ea", "f2096bde6a524ff9a87c11c118a2c2ee", "ae005c5b1ed6445097f27c9eadb44d00", "9fa8126bacd749a3b374996c0bee7891", "9558b9c30ce348928989e241942aecef", "bd9374b595da45498874f2cc62cad206", "437cfda1d7ac4731a4673735ae8072e1", "a7aa8c0ce09c4ee4bbccb2720b0ad4b4", "ce8e7aa1b6154f78931aa32fd38c032f", "dbd56444ab4a4d1799c5f164fae61e54", "0a55eb4a977044a891795ffe23adfb28", "5d0e9fb3a94a4f738545a84d470c846f" ] }, "id": "OYPmfKRY-6tC", "outputId": "eaf77000-920e-4aec-e38c-1f9f3f3f4c1c" }, "execution_count": null, "outputs": [ { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "WARNING:datasets.builder:Using custom data configuration Adapting--abstract-keyphrases-4811abd1e624c6b0\n", "WARNING:datasets.builder:Found cached dataset csv (/root/.cache/huggingface/datasets/Adapting___csv/Adapting--abstract-keyphrases-4811abd1e624c6b0/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f39f11d1075547bd81c2a71ab4e9d056", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at Adapting/KeyBartAdapter were not used when initializing BartForConditionalGeneration: ['model.decoder.decoder.layers.11.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.5.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.11.self_attn.out_proj.weight', 'model.decoder.decoder.layers.2.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.9.fc1.weight', 'model.decoder.decoder.layers.8.self_attn.out_proj.weight', 'model.decoder.decoder.layers.3.self_attn.out_proj.weight', 'model.decoder.adapters.7.up_proj.weight', 'model.decoder.decoder.layers.0.self_attn.q_proj.bias', 
'model.decoder.decoder.layers.1.final_layer_norm.bias', 'model.decoder.decoder.layers.5.self_attn.q_proj.weight', 'model.decoder.decoder.layers.11.fc2.weight', 'model.decoder.decoder.layers.3.self_attn.q_proj.weight', 'model.decoder.decoder.layers.6.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.3.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.11.self_attn.v_proj.weight', 'model.decoder.decoder.layers.1.self_attn.out_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.2.self_attn.k_proj.bias', 'model.decoder.decoder.layers.9.final_layer_norm.weight', 'model.decoder.adapters.11.down_proj.weight', 'model.decoder.decoder.layers.11.self_attn.q_proj.weight', 'model.decoder.decoder.layers.11.fc2.bias', 'model.decoder.decoder.layers.11.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.10.final_layer_norm.weight', 'model.decoder.decoder.layers.0.self_attn_layer_norm.weight', 'model.decoder.adapters.5.up_proj.weight', 'model.decoder.decoder.layers.7.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.6.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.11.self_attn.k_proj.weight', 'model.decoder.decoder.layers.3.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.9.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.9.self_attn.v_proj.weight', 'model.decoder.adapters.8.down_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.4.fc1.weight', 'model.decoder.decoder.layers.7.self_attn.out_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.1.self_attn.k_proj.bias', 'model.decoder.decoder.layers.6.self_attn.v_proj.weight', 'model.decoder.decoder.layernorm_embedding.weight', 'model.decoder.decoder.layers.5.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.0.fc2.weight', 'model.decoder.decoder.layers.5.fc1.weight', 'model.decoder.decoder.layers.3.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.7.self_attn.out_proj.bias', 'model.decoder.decoder.layers.9.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.4.self_attn.v_proj.weight', 'model.decoder.decoder.layers.4.self_attn.q_proj.weight', 'model.decoder.adapters.6.down_proj.weight', 'model.decoder.decoder.layers.3.fc2.weight', 'model.decoder.decoder.layers.10.self_attn.out_proj.bias', 'model.decoder.decoder.layers.11.self_attn.out_proj.bias', 'model.decoder.decoder.layers.4.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.2.fc1.bias', 'model.decoder.decoder.layers.4.self_attn.out_proj.weight', 'model.decoder.decoder.layers.1.self_attn.v_proj.weight', 'model.decoder.decoder.layers.2.final_layer_norm.bias', 'model.decoder.decoder.layers.11.fc1.weight', 'model.decoder.adapters.10.layerNorm.weight', 'model.decoder.decoder.layers.3.encoder_attn.q_proj.weight', 'model.decoder.adapters.1.up_proj.weight', 'model.decoder.decoder.layers.10.self_attn.q_proj.bias', 'model.decoder.decoder.layers.6.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.2.self_attn.q_proj.bias', 'model.decoder.decoder.layers.8.fc1.weight', 'model.decoder.decoder.layers.8.self_attn.q_proj.bias', 'model.decoder.decoder.layers.5.encoder_attn.out_proj.bias', 'model.decoder.adapters.2.layerNorm.bias', 'model.decoder.decoder.layers.5.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.7.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.8.encoder_attn.v_proj.weight', 
'model.decoder.decoder.layers.8.self_attn.q_proj.weight', 'model.decoder.decoder.layers.3.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.11.self_attn.q_proj.bias', 'model.decoder.decoder.layers.7.self_attn.k_proj.bias', 'model.decoder.adapters.9.layerNorm.bias', 'model.decoder.decoder.layers.6.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.6.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.4.final_layer_norm.weight', 'model.decoder.decoder.layers.2.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.6.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.0.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.8.self_attn.k_proj.weight', 'model.decoder.decoder.layers.10.encoder_attn_layer_norm.bias', 'model.decoder.adapters.4.up_proj.weight', 'model.decoder.decoder.layers.9.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.2.final_layer_norm.weight', 'model.decoder.decoder.layers.6.fc2.bias', 'model.decoder.decoder.layers.11.encoder_attn_layer_norm.weight', 'model.decoder.adapters.4.layerNorm.weight', 'model.decoder.decoder.layers.9.fc1.bias', 'model.decoder.decoder.layers.2.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.4.self_attn.v_proj.bias', 'model.decoder.decoder.layers.3.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.0.self_attn.k_proj.bias', 'model.decoder.decoder.layers.8.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.10.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.2.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.6.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.11.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.8.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.8.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.11.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.2.fc2.bias', 'model.decoder.decoder.layers.8.self_attn.v_proj.bias', 'model.decoder.decoder.layers.9.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.2.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.10.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.7.fc1.bias', 'model.decoder.decoder.layers.0.final_layer_norm.weight', 'model.decoder.decoder.layers.9.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.1.fc1.bias', 'model.decoder.decoder.layers.9.self_attn.q_proj.weight', 'model.decoder.decoder.layers.6.self_attn.v_proj.bias', 'model.decoder.decoder.layers.6.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.8.final_layer_norm.bias', 'model.decoder.decoder.layers.5.fc2.weight', 'model.decoder.decoder.layers.9.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.7.self_attn.q_proj.weight', 'model.decoder.decoder.layers.10.fc2.weight', 'model.decoder.decoder.layers.10.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.11.self_attn.k_proj.bias', 'model.decoder.decoder.layers.4.fc1.bias', 'model.decoder.decoder.layers.2.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.11.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.9.self_attn.k_proj.weight', 'model.decoder.decoder.layers.2.self_attn.out_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.8.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.5.self_attn.out_proj.bias', 'model.decoder.adapters.8.up_proj.weight', 'model.decoder.decoder.layers.6.self_attn.q_proj.bias', 'model.decoder.decoder.layers.3.fc1.bias', 
'model.decoder.decoder.layers.10.final_layer_norm.bias', 'model.decoder.decoder.layers.1.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.5.self_attn.v_proj.weight', 'model.decoder.adapters.6.layerNorm.bias', 'model.decoder.decoder.layers.1.self_attn.k_proj.weight', 'model.decoder.adapters.2.down_proj.weight', 'model.decoder.decoder.layers.6.self_attn.out_proj.weight', 'model.decoder.decoder.layers.3.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.6.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.0.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.6.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.0.self_attn.v_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn.v_proj.weight', 'model.decoder.adapters.0.layerNorm.bias', 'model.decoder.decoder.layers.6.fc2.weight', 'model.decoder.decoder.layers.5.self_attn.v_proj.bias', 'model.decoder.decoder.layers.1.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.5.self_attn.out_proj.weight', 'model.decoder.decoder.layers.7.final_layer_norm.weight', 'model.decoder.decoder.layers.0.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.8.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.7.self_attn.q_proj.bias', 'model.decoder.adapters.4.layerNorm.bias', 'model.decoder.decoder.layers.0.self_attn.out_proj.weight', 'model.decoder.decoder.layers.4.self_attn.out_proj.bias', 'model.decoder.decoder.layers.0.encoder_attn_layer_norm.weight', 'model.decoder.adapters.3.up_proj.weight', 'model.decoder.decoder.layers.1.fc2.weight', 'model.decoder.decoder.layers.1.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.3.self_attn.out_proj.bias', 'model.decoder.decoder.layers.10.fc2.bias', 'model.decoder.decoder.layers.2.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.9.encoder_attn.v_proj.bias', 'model.decoder.adapters.2.layerNorm.weight', 'model.decoder.adapters.11.layerNorm.bias', 'model.decoder.decoder.layers.0.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.1.self_attn.q_proj.weight', 'model.decoder.decoder.layers.8.fc2.weight', 'model.decoder.decoder.layers.4.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.7.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.2.self_attn.q_proj.weight', 'model.decoder.decoder.layers.5.self_attn.k_proj.bias', 'model.decoder.decoder.layers.0.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.3.final_layer_norm.bias', 'model.decoder.decoder.layers.10.fc1.bias', 'model.decoder.decoder.layers.6.fc1.bias', 'model.decoder.adapters.7.layerNorm.weight', 'model.decoder.decoder.layers.9.self_attn.q_proj.bias', 'model.decoder.decoder.layers.8.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.9.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.9.self_attn.k_proj.bias', 'model.decoder.decoder.layers.0.self_attn.v_proj.bias', 'model.decoder.decoder.layers.0.fc1.bias', 'model.decoder.decoder.layers.10.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.6.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.2.fc2.weight', 'model.decoder.adapters.10.up_proj.weight', 'model.decoder.decoder.layers.4.encoder_attn.out_proj.bias', 'model.decoder.adapters.2.up_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.3.fc1.weight', 'model.decoder.decoder.layers.2.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.10.self_attn.q_proj.weight', 
'model.decoder.adapters.4.down_proj.weight', 'model.decoder.decoder.layers.0.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.3.fc2.bias', 'model.decoder.decoder.layers.2.self_attn.out_proj.bias', 'model.decoder.adapters.10.down_proj.weight', 'model.decoder.decoder.layers.0.fc2.bias', 'model.decoder.decoder.layers.7.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.5.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.0.self_attn.q_proj.weight', 'model.decoder.decoder.layers.11.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.9.final_layer_norm.bias', 'model.decoder.decoder.layers.1.self_attn.q_proj.bias', 'model.decoder.decoder.layers.4.encoder_attn.q_proj.bias', 'model.decoder.adapters.5.layerNorm.bias', 'model.decoder.decoder.layers.4.self_attn.q_proj.bias', 'model.decoder.decoder.layers.11.fc1.bias', 'model.decoder.decoder.layers.0.final_layer_norm.bias', 'model.decoder.decoder.layers.10.self_attn.k_proj.weight', 'model.decoder.decoder.layers.4.final_layer_norm.bias', 'model.decoder.decoder.layers.6.final_layer_norm.bias', 'model.decoder.decoder.layers.5.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.1.self_attn_layer_norm.weight', 'model.decoder.adapters.6.layerNorm.weight', 'model.decoder.decoder.layers.0.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.9.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.11.final_layer_norm.bias', 'model.decoder.decoder.layers.11.encoder_attn.k_proj.weight', 'model.decoder.adapters.11.layerNorm.weight', 'model.decoder.adapters.8.layerNorm.weight', 'model.decoder.adapters.3.down_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.4.fc2.weight', 'model.decoder.decoder.layers.10.self_attn.k_proj.bias', 'model.decoder.decoder.layers.10.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.4.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.7.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.7.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.9.self_attn.out_proj.bias', 'model.decoder.decoder.layers.7.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.10.fc1.weight', 'model.decoder.decoder.layers.7.fc2.weight', 'model.decoder.decoder.layers.8.self_attn.v_proj.weight', 'model.decoder.decoder.layers.11.self_attn.v_proj.bias', 'model.decoder.decoder.layers.9.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.6.self_attn.k_proj.bias', 'model.decoder.decoder.layers.6.self_attn.q_proj.weight', 'model.decoder.decoder.layers.8.self_attn.out_proj.bias', 'model.decoder.decoder.layers.3.self_attn.v_proj.weight', 'model.decoder.decoder.layers.4.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.8.self_attn.k_proj.bias', 'model.decoder.adapters.1.layerNorm.weight', 'model.decoder.decoder.layers.8.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.3.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.10.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.5.self_attn.q_proj.bias', 'model.decoder.adapters.7.layerNorm.bias', 'model.decoder.decoder.layers.4.encoder_attn.k_proj.bias', 'model.decoder.adapters.8.layerNorm.bias', 'model.decoder.decoder.layers.10.self_attn.out_proj.weight', 'model.decoder.decoder.layers.0.encoder_attn.out_proj.bias', 'model.decoder.decoder.layers.7.self_attn.k_proj.weight', 'model.decoder.decoder.layers.5.fc1.bias', 'model.decoder.decoder.layers.0.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.0.self_attn.out_proj.bias', 
'model.decoder.decoder.layers.8.fc1.bias', 'model.decoder.decoder.layers.1.self_attn.v_proj.bias', 'model.decoder.decoder.layers.7.self_attn.v_proj.weight', 'model.decoder.decoder.layers.5.final_layer_norm.bias', 'model.decoder.adapters.0.layerNorm.weight', 'model.decoder.decoder.layers.8.final_layer_norm.weight', 'model.decoder.decoder.layers.3.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.0.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.0.fc1.weight', 'model.decoder.decoder.layers.5.self_attn.k_proj.weight', 'model.decoder.decoder.layers.3.final_layer_norm.weight', 'model.decoder.decoder.layers.7.self_attn.v_proj.bias', 'model.decoder.decoder.embed_tokens.weight', 'model.decoder.decoder.layers.0.self_attn.k_proj.weight', 'model.decoder.decoder.layers.6.fc1.weight', 'model.decoder.decoder.layers.10.self_attn.v_proj.bias', 'model.decoder.decoder.layers.7.fc2.bias', 'model.decoder.decoder.layernorm_embedding.bias', 'model.decoder.adapters.9.up_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.10.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.4.fc2.bias', 'model.decoder.decoder.layers.1.final_layer_norm.weight', 'model.decoder.decoder.layers.8.fc2.bias', 'model.decoder.decoder.layers.6.self_attn.k_proj.weight', 'model.decoder.decoder.layers.1.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.2.self_attn.k_proj.weight', 'model.decoder.decoder.layers.1.fc2.bias', 'model.decoder.adapters.9.down_proj.weight', 'model.decoder.adapters.0.down_proj.weight', 'model.decoder.decoder.layers.8.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.1.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.5.fc2.bias', 'model.decoder.adapters.11.up_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.1.fc1.weight', 'model.decoder.decoder.layers.11.encoder_attn.out_proj.bias', 'model.decoder.adapters.0.up_proj.weight', 'model.decoder.decoder.layers.3.self_attn.k_proj.weight', 'model.decoder.decoder.layers.9.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.7.encoder_attn.k_proj.weight', 'model.decoder.adapters.3.layerNorm.bias', 'model.decoder.decoder.layers.2.encoder_attn.out_proj.weight', 'model.decoder.adapters.5.layerNorm.weight', 'model.decoder.decoder.layers.10.self_attn.v_proj.weight', 'model.decoder.decoder.layers.4.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.1.self_attn.out_proj.bias', 'model.decoder.decoder.layers.9.self_attn.v_proj.bias', 'model.decoder.decoder.layers.10.encoder_attn.q_proj.bias', 'model.decoder.decoder.layers.10.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.11.final_layer_norm.weight', 'model.decoder.decoder.layers.3.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.11.encoder_attn.q_proj.weight', 'model.decoder.adapters.6.up_proj.weight', 'model.decoder.decoder.layers.7.encoder_attn.k_proj.bias', 'model.decoder.adapters.7.down_proj.weight', 'model.decoder.decoder.layers.5.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.8.encoder_attn_layer_norm.bias', 'model.decoder.decoder.layers.2.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.10.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.9.self_attn.out_proj.weight', 'model.decoder.decoder.layers.7.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.4.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.10.encoder_attn_layer_norm.weight', 
'model.decoder.decoder.layers.4.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.3.self_attn.k_proj.bias', 'model.decoder.adapters.3.layerNorm.weight', 'model.decoder.decoder.layers.4.self_attn.k_proj.bias', 'model.decoder.adapters.10.layerNorm.bias', 'model.decoder.decoder.layers.4.self_attn.k_proj.weight', 'model.decoder.decoder.layers.4.self_attn_layer_norm.weight', 'model.decoder.decoder.layers.7.fc1.weight', 'model.decoder.decoder.layers.6.self_attn.out_proj.bias', 'model.decoder.decoder.layers.2.fc1.weight', 'model.decoder.decoder.layers.8.encoder_attn_layer_norm.weight', 'model.decoder.decoder.layers.6.encoder_attn.q_proj.weight', 'model.decoder.decoder.layers.11.self_attn_layer_norm.weight', 'model.decoder.decoder.embed_positions.weight', 'model.decoder.decoder.layers.7.encoder_attn.v_proj.bias', 'model.decoder.decoder.layers.8.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.7.final_layer_norm.bias', 'model.decoder.decoder.layers.11.encoder_attn.out_proj.weight', 'model.decoder.decoder.layers.3.self_attn.v_proj.bias', 'model.decoder.decoder.layers.3.encoder_attn.k_proj.weight', 'model.decoder.decoder.layers.2.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.2.self_attn.v_proj.bias', 'model.decoder.decoder.layers.6.final_layer_norm.weight', 'model.decoder.decoder.layers.2.self_attn.v_proj.weight', 'model.decoder.decoder.layers.2.encoder_attn.k_proj.bias', 'model.decoder.decoder.layers.9.fc2.bias', 'model.decoder.decoder.layers.9.encoder_attn.v_proj.weight', 'model.decoder.decoder.layers.4.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.5.final_layer_norm.weight', 'model.decoder.adapters.1.down_proj.weight', 'model.decoder.decoder.layers.7.self_attn_layer_norm.bias', 'model.decoder.decoder.layers.9.fc2.weight', 'model.decoder.adapters.1.layerNorm.bias', 'model.decoder.decoder.layers.3.self_attn.q_proj.bias', 'model.decoder.decoder.layers.3.encoder_attn.v_proj.weight', 'model.decoder.adapters.5.down_proj.weight', 'model.decoder.adapters.9.layerNorm.weight']\n", "- This IS expected if you are initializing BartForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BartForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f2096bde6a524ff9a87c11c118a2c2ee", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00, ?ba/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "metadata": { "tags": null }, "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.7/dist-packages/transformers/tokenization_utils_base.py:3543: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n", " \"`as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. 
You can tokenize your \"\n", "Using cuda_amp half precision backend\n", "/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " FutureWarning,\n", "***** Running training *****\n", " Num examples = 50\n", " Num Epochs = 30\n", " Instantaneous batch size per device = 4\n", " Total train batch size (w. parallel, distributed & accumulation) = 4\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 390\n", "You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Precision@m | \n", "Recall@m | \n", "Fscore@m | \n", "Num Match | \n", "Num Pred | \n", "Num Gold | \n", "
---|---|---|---|---|---|---|---|---|
1 | \n", "1.484700 | \n", "0.823063 | \n", "55.000000 | \n", "37.330000 | \n", "43.590000 | \n", "1.220000 | \n", "2.240000 | \n", "3.280000 | \n", "
2 | \n", "0.836800 | \n", "0.575592 | \n", "63.070000 | \n", "62.600000 | \n", "61.430000 | \n", "1.960000 | \n", "3.240000 | \n", "3.280000 | \n", "
3 | \n", "0.664900 | \n", "0.431628 | \n", "78.070000 | \n", "68.970000 | \n", "71.810000 | \n", "2.180000 | \n", "2.860000 | \n", "3.280000 | \n", "
4 | \n", "0.390500 | \n", "0.280106 | \n", "90.670000 | \n", "69.630000 | \n", "77.340000 | \n", "2.160000 | \n", "2.420000 | \n", "3.280000 | \n", "
5 | \n", "0.344900 | \n", "0.316271 | \n", "92.170000 | \n", "70.200000 | \n", "78.500000 | \n", "2.240000 | \n", "2.460000 | \n", "3.280000 | \n", "
6 | \n", "0.355400 | \n", "0.223609 | \n", "90.830000 | \n", "80.430000 | \n", "84.580000 | \n", "2.560000 | \n", "2.860000 | \n", "3.280000 | \n", "
7 | \n", "0.343300 | \n", "0.320883 | \n", "88.000000 | \n", "80.430000 | \n", "83.370000 | \n", "2.540000 | \n", "2.920000 | \n", "3.280000 | \n", "
8 | \n", "0.308200 | \n", "0.186449 | \n", "91.500000 | \n", "83.500000 | \n", "86.710000 | \n", "2.640000 | \n", "2.920000 | \n", "3.280000 | \n", "
9 | \n", "0.156900 | \n", "0.259272 | \n", "91.430000 | \n", "86.770000 | \n", "88.500000 | \n", "2.740000 | \n", "3.040000 | \n", "3.280000 | \n", "
10 | \n", "0.130000 | \n", "0.191015 | \n", "89.430000 | \n", "85.700000 | \n", "87.120000 | \n", "2.700000 | \n", "3.080000 | \n", "3.280000 | \n", "
11 | \n", "0.131900 | \n", "0.144332 | \n", "89.930000 | \n", "86.870000 | \n", "88.090000 | \n", "2.740000 | \n", "3.100000 | \n", "3.280000 | \n", "
12 | \n", "0.151100 | \n", "0.160923 | \n", "92.270000 | \n", "85.270000 | \n", "88.070000 | \n", "2.700000 | \n", "2.960000 | \n", "3.280000 | \n", "
13 | \n", "0.104600 | \n", "0.181388 | \n", "91.770000 | \n", "85.530000 | \n", "88.080000 | \n", "2.700000 | \n", "2.980000 | \n", "3.280000 | \n", "
14 | \n", "0.110200 | \n", "0.247938 | \n", "91.430000 | \n", "86.370000 | \n", "88.310000 | \n", "2.720000 | \n", "3.020000 | \n", "3.280000 | \n", "
15 | \n", "0.157600 | \n", "0.240022 | \n", "92.270000 | \n", "87.600000 | \n", "89.380000 | \n", "2.760000 | \n", "3.040000 | \n", "3.280000 | \n", "
16 | \n", "0.126500 | \n", "0.133782 | \n", "93.270000 | \n", "88.330000 | \n", "90.320000 | \n", "2.800000 | \n", "3.040000 | \n", "3.280000 | \n", "
17 | \n", "0.053800 | \n", "0.158040 | \n", "91.430000 | \n", "89.000000 | \n", "89.930000 | \n", "2.820000 | \n", "3.140000 | \n", "3.280000 | \n", "
18 | \n", "0.056000 | \n", "0.250004 | \n", "92.030000 | \n", "87.030000 | \n", "88.970000 | \n", "2.740000 | \n", "3.020000 | \n", "3.280000 | \n", "
19 | \n", "0.083200 | \n", "0.167435 | \n", "93.370000 | \n", "86.770000 | \n", "89.330000 | \n", "2.740000 | \n", "2.980000 | \n", "3.280000 | \n", "
20 | \n", "0.099200 | \n", "0.160180 | \n", "90.200000 | \n", "87.700000 | \n", "88.680000 | \n", "2.760000 | \n", "3.120000 | \n", "3.280000 | \n", "
21 | \n", "0.031500 | \n", "0.144671 | \n", "91.370000 | \n", "87.430000 | \n", "88.940000 | \n", "2.760000 | \n", "3.080000 | \n", "3.280000 | \n", "
22 | \n", "0.062100 | \n", "0.167008 | \n", "90.370000 | \n", "88.100000 | \n", "89.010000 | \n", "2.780000 | \n", "3.140000 | \n", "3.280000 | \n", "
23 | \n", "0.054600 | \n", "0.124082 | \n", "91.270000 | \n", "88.500000 | \n", "89.580000 | \n", "2.800000 | \n", "3.120000 | \n", "3.280000 | \n", "
24 | \n", "0.033500 | \n", "0.136747 | \n", "92.270000 | \n", "87.830000 | \n", "89.510000 | \n", "2.780000 | \n", "3.060000 | \n", "3.280000 | \n", "
25 | \n", "0.076700 | \n", "0.128727 | \n", "93.270000 | \n", "87.170000 | \n", "89.570000 | \n", "2.760000 | \n", "3.000000 | \n", "3.280000 | \n", "
26 | \n", "0.092200 | \n", "0.175439 | \n", "90.270000 | \n", "87.830000 | \n", "88.770000 | \n", "2.780000 | \n", "3.140000 | \n", "3.280000 | \n", "
27 | \n", "0.089700 | \n", "0.185708 | \n", "90.930000 | \n", "88.500000 | \n", "89.430000 | \n", "2.800000 | \n", "3.140000 | \n", "3.280000 | \n", "
28 | \n", "0.022900 | \n", "0.171717 | \n", "91.270000 | \n", "88.500000 | \n", "89.580000 | \n", "2.800000 | \n", "3.120000 | \n", "3.280000 | \n", "
29 | \n", "0.038700 | \n", "0.142286 | \n", "91.270000 | \n", "88.500000 | \n", "89.580000 | \n", "2.800000 | \n", "3.120000 | \n", "3.280000 | \n", "
"
],
"text/plain": [
" "
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"***** Running Evaluation *****\n",
" Num examples = 50\n",
" Batch size = 4\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"(50, 20)\n",
"PRED: num=4 - {'transfer learning', 'transformer', 'natural language processing', 'fine-tuning'}\n",
"GT: num=4 - {'transfer learning', 'transformer', 'natural language processing', 'fine-tuning'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'hypergraph attention network', 'recommender system', 'session-based recommendation system'}\n",
"GT: num=3 - {'hypergraph attention network', 'recommender system', 'session-based recommendation system'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'descriptive text caption', 'generating audio sample', 'auto-regressive generative model'}\n",
"GT: num=3 - {'descriptive text caption', 'generating audio sample', 'auto-regressive generative model'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'human motion generation', 'motion diffusion model', 'diffusion model'}\n",
"GT: num=3 - {'human motion generation', 'motion diffusion model', 'diffusion model'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'spiking neural network', 'self-attention mechanism', 'biological property'}\n",
"GT: num=3 - {'spiking neural network', 'self-attention mechanism', 'biological property'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'knowledge grounding', 'medical question understanding and answering', 'semantic', 'medical question answering system'}\n",
"GT: num=4 - {'knowledge grounding', 'medical question understanding and answering system', 'semantic self-supervision', 'medical question answering system'}\n",
"p=0.5, r=0.5, f1=0.5\n",
"--------------------\n",
"PRED: num=3 - {'synthesize diverse', 'multi-modal image and text representations', 'neural rendering'}\n",
"GT: num=4 - {'natural language descriptions', 'synthesize diverse 3D objects', 'multi-modal image and text representations', 'neural rendering'}\n",
"p=0.6666666666666666, r=0.5, f1=0.5714285714285715\n",
"--------------------\n",
"PRED: num=2 - {'qualitative data', 'qualitative visualization'}\n",
"GT: num=2 - {'qualitative data', 'qualitative visualization'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'graph neural network', 'topology design', 'feature selection and fusion strategy'}\n",
"GT: num=3 - {'graph neural network', 'topology design', 'feature selection and fusion strategy'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'transfer learning', 'multi-task learning', 'unsupervised learning', 'gaussian mixture model'}\n",
"GT: num=4 - {'transfer learning', 'multi-task learning', 'unsupervised learning', 'gaussian mixture model'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'cognitive multilayer network', 'cognition', 'c', 'mental lexicon'}\n",
"GT: num=3 - {'cognitive multilayer network', 'cognitive network', 'mental lexicon'}\n",
"p=0.5, r=0.6666666666666666, f1=0.5714285714285715\n",
"--------------------\n",
"PRED: num=3 - {'spatiotemporal representation learning', 'pretext task', 'pre-training'}\n",
"GT: num=3 - {'spatiotemporal representation learning', 'pretext task', 'pre-training'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'language-family adapter', 'multilingual model', 'machine translation', 'natural language processing'}\n",
"GT: num=4 - {'language-family adapter', 'multilingual model', 'machine translation', 'natural language processing'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'fine-tuning', 'prompt-free', 'sentence transformer'}\n",
"GT: num=3 - {'fine-tuning', 'prompt-free', 'sentence transformer'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'', 'knowledge transferability', 'medical image domain', 'large-scale pre-trained vision language model'}\n",
"GT: num=5 - {'knowledge transferability', 'medical prompt', 'medical image domain', 'domain transfer capability', 'large-scale pre-trained vision language model'}\n",
"p=0.75, r=0.6, f1=0.6666666666666665\n",
"--------------------\n",
"PRED: num=4 - {'bayesian optimization', 'one-', 'stochastic expensive black box function', 'acquisition function'}\n",
"GT: num=4 - {'one-shot hybrid kg', 'bayesian optimization', 'stochastic expensive black box function', 'acquisition function'}\n",
"p=0.75, r=0.75, f1=0.75\n",
"--------------------\n",
"PRED: num=4 - {'convolutional neural network', 'handwritten text recognition', 'end-to-end', 'feature'}\n",
"GT: num=4 - {'feature extraction', 'convolutional neural network', 'handwritten text recognition', 'end-to-end'}\n",
"p=0.75, r=0.75, f1=0.75\n",
"--------------------\n",
"PRED: num=4 - {'speech processing system', 'multilingual supervision', 'robust speech processing', 'multitask supervision'}\n",
"GT: num=4 - {'speech processing system', 'multilingual supervision', 'robust speech processing', 'multitask supervision'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {\"fitts' law\", 'two-component mixture structure', 'expectation-conditional'}\n",
"GT: num=3 - {'expectation-conditional-maximization', \"fitts' law\", 'two-component mixture structure'}\n",
"p=0.6666666666666666, r=0.6666666666666666, f1=0.6666666666666666\n",
"--------------------\n",
"PRED: num=4 - {'maximal k', 'cohesive structure', 'mining maximal subgraph', 'bipartite graph'}\n",
"GT: num=5 - {'mining maximal subgraph', 'bipartite graph', 'maximal k-biplex', 'cohesive structure', 'reverse search framework'}\n",
"p=0.75, r=0.6, f1=0.6666666666666665\n",
"--------------------\n",
"PRED: num=3 - {'operator-algebraic', 'quantum field theory', 'minkowski space'}\n",
"GT: num=3 - {'operator-algebraic', 'quantum field theory', 'minkowski space'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'conditional diffusion transformer', 'data-driven approach', 'optimize neural networks'}\n",
"GT: num=3 - {'conditional diffusion transformer', 'data-driven approach', 'optimize neural networks'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'kg representation learning', 'knowledge graph', 'prompt learning'}\n",
"GT: num=3 - {'kg representation learning', 'knowledge graph', 'prompt learning'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'cross-modal representation', 'vision-language pre-trained model', 'dual encoder'}\n",
"GT: num=5 - {'multi-view contrastive learning', 'contrastive learning', 'dual encoder', 'vision-language pre-trained model', 'cross-modal representation'}\n",
"p=1.0, r=0.6, f1=0.7499999999999999\n",
"--------------------\n",
"PRED: num=3 - {'free-text retrieval', 'knowledge base', 'question answering over knowledge base'}\n",
"GT: num=3 - {'free-text retrieval', 'knowledge base', 'question answering over knowledge base'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'pre-trained conversational model', 'implicit commonsense (cs) knowledge', 'dialogue'}\n",
"GT: num=5 - {'implicit commonsense (cs) knowledge', 'pre-trained conversational model', 'dialogue agent', 'two-way learning', 'external knowledge'}\n",
"p=0.6666666666666666, r=0.4, f1=0.5\n",
"--------------------\n",
"PRED: num=5 - {'quantitative communication', 'information theory', 'communication complexity', 'quantum network', 'quantum computer'}\n",
"GT: num=5 - {'quantitative communication', 'information theory', 'communication complexity', 'quantum network', 'quantum computer'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'phoneme recognition', 'speech recognition', 'phonetic feature extraction'}\n",
"GT: num=3 - {'phoneme recognition', 'speech recognition', 'phonetic feature extraction'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'cross entropy loss', 'soft prompt learning'}\n",
"GT: num=2 - {'cross entropy loss', 'soft prompt learning'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'semantic segmentation', 'convolutional network architecture'}\n",
"GT: num=2 - {'semantic segmentation', 'convolutional network architecture'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'neural machine translation', 'low frequency word prediction', 'token-level contrast', 'representation learning'}\n",
"GT: num=4 - {'token-level contrastive learning', 'neural machine translation', 'low frequency word prediction', 'representation learning'}\n",
"p=0.75, r=0.75, f1=0.75\n",
"--------------------\n",
"PRED: num=3 - {'reinforcement learning', 'markov decision process', 'sequential decoding'}\n",
"GT: num=3 - {'reinforcement learning', 'markov decision process', 'sequential decoding'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'slimmable networks', 'self-supervised learning'}\n",
"GT: num=2 - {'slimmable networks', 'self-supervised learning'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'semi-supervised segmentation network', 'segmentation of images', 'medical AI'}\n",
"GT: num=4 - {'semi-supervised segmentation network', 'segmentation of images', 'contrastive learning', 'medical AI'}\n",
"p=1.0, r=0.75, f1=0.8571428571428571\n",
"--------------------\n",
"PRED: num=2 - {'legal statute identification task', 'citation network'}\n",
"GT: num=2 - {'legal statute identification task', 'citation network'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'multi-dialect automatic speech recognition', 'multi dialectal', 'deep neural network'}\n",
"GT: num=4 - {'acoustic model', 'multi-dialect automatic speech recognition', 'multi-dialectal corpus', 'deep neural network'}\n",
"p=0.6666666666666666, r=0.5, f1=0.5714285714285715\n",
"--------------------\n",
"PRED: num=4 - {'language-vision', 'language', 'deep learning', 'open-source deep learning library'}\n",
"GT: num=4 - {'language-vision', 'language-vision tasks', 'deep learning', 'open-source deep learning library'}\n",
"p=0.75, r=0.75, f1=0.75\n",
"--------------------\n",
"PRED: num=3 - {'task-agnostic', 'semi-parametric language model', 'zero-shot generalization'}\n",
"GT: num=3 - {'task-agnostic', 'semi-parametric language model', 'zero-shot generalization'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'cooperative environmental learning', 'multi-robot system', 'distributed learning'}\n",
"GT: num=3 - {'cooperative environmental learning', 'multi-robot system', 'distributed learning'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'', 'real-time and human-interpretable decision-making', 'time-incremental learning'}\n",
"GT: num=3 - {'decision tree', 'real-time and human-interpretable decision-making', 'time-incremental learning'}\n",
"p=0.6666666666666666, r=0.6666666666666666, f1=0.6666666666666666\n",
"--------------------\n",
"PRED: num=3 - {'Non-referential face image quality assessment method', 'quality assessment method', 'face recognition'}\n",
"GT: num=3 - {'Non-referential face image quality assessment method', 'quality assessment method', 'face recognition'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'knowledge-intensive language tasks', 'benchmark'}\n",
"GT: num=2 - {'knowledge-intensive language tasks', 'benchmark'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=3 - {'semi-supervised learning', 'neural processes', 'image classification task'}\n",
"GT: num=3 - {'semi-supervised learning', 'neural processes', 'image classification task'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'sparse attention and monotonic attention', 'automatic speech recognition'}\n",
"GT: num=2 - {'sparse attention and monotonic attention', 'automatic speech recognition'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=4 - {'information removal', 'privacy', 'deep learning', 'bias'}\n",
"GT: num=4 - {'information removal', 'privacy', 'deep learning', 'bias'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'data budgeting problem', 'data'}\n",
"GT: num=2 - {'data budgeting problem', 'data'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'gender classification algorithms', 'benchmark database'}\n",
"GT: num=2 - {'gender classification algorithms', 'benchmark database'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=5 - {'sparse recovery', 'convex optimization', 'deep learning', 'neural networks', 'relu'}\n",
"GT: num=5 - {'sparse recovery', 'convex optimization', 'relu networks', 'deep learning', 'neural networks'}\n",
"p=0.8, r=0.8, f1=0.8000000000000002\n",
"--------------------\n",
"PRED: num=2 - {'genetic variant', 'boolean relations'}\n",
"GT: num=2 - {'genetic variant', 'boolean relations'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n",
"PRED: num=2 - {'text-to-image diffusion model', 'subject recontextualization'}\n",
"GT: num=2 - {'text-to-image diffusion model', 'subject recontextualization'}\n",
"p=1.0, r=1.0, f1=1.0\n",
"--------------------\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Saving model checkpoint to KeyBartAdapter/checkpoint-390\n",
"Configuration saved in KeyBartAdapter/checkpoint-390/config.json\n",
"Model weights saved in KeyBartAdapter/checkpoint-390/pytorch_model.bin\n",
"tokenizer config file saved in KeyBartAdapter/checkpoint-390/tokenizer_config.json\n",
"Special tokens file saved in KeyBartAdapter/checkpoint-390/special_tokens_map.json\n",
"Deleting older checkpoint [KeyBartAdapter/checkpoint-364] due to args.save_total_limit\n",
"\n",
"\n",
"Training completed. Do not forget to share your model on huggingface.co/models =)\n",
"\n",
"\n",
"Loading best model from KeyBartAdapter/checkpoint-208 (score: 90.32).\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=390, training_loss=0.24197826716953363, metrics={'train_runtime': 2544.9523, 'train_samples_per_second': 0.589, 'train_steps_per_second': 0.153, 'total_flos': 961529921028096.0, 'train_loss': 0.24197826716953363, 'epoch': 30.0})"
]
},
"metadata": {},
"execution_count": 3
}
]
},
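{
"cell_type": "markdown",
"source": [
"The evaluation log above prints, for each test abstract, the predicted and gold keyphrase sets together with the precision, recall and F1 of their overlap. The next cell is a minimal, self-contained sketch of how such per-example scores can be computed from two keyphrase sets; the function name is hypothetical and this is not the notebook's own metric code, which is defined earlier."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch only: per-example precision/recall/F1 over keyphrase sets,\n",
"# mirroring the PRED/GT lines in the evaluation log above. keyphrase_prf is a\n",
"# hypothetical helper, not the notebook's compute_metrics.\n",
"def keyphrase_prf(pred: set, gold: set):\n",
"    matches = pred & gold  # exact string matches between predicted and gold keyphrases\n",
"    p = len(matches) / len(pred) if pred else 0.0\n",
"    r = len(matches) / len(gold) if gold else 0.0\n",
"    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0.0\n",
"    return p, r, f1\n",
"\n",
"pred = {'transfer learning', 'transformer', 'natural language processing', 'fine-tuning'}\n",
"gold = {'transfer learning', 'transformer', 'natural language processing', 'fine-tuning'}\n",
"print(keyphrase_prf(pred, gold))  # (1.0, 1.0, 1.0), as in the first logged example"
],
"metadata": {},
"execution_count": null,
"outputs": []
},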
{
"cell_type": "code",
"source": [
"abstract = '''Large text-to-image models achieved a remarkable leap in the evolution of AI, enabling high-quality and diverse synthesis of images from a given text prompt. However, these models lack the ability to mimic the appearance of subjects in a given reference set and synthesize novel renditions of them in different contexts. In this work, we present a new approach for \"personalization\" of text-to-image diffusion models (specializing them to users' needs). Given as input just a few images of a subject, we fine-tune a pretrained text-to-image model (Imagen, although our method is not limited to a specific model) such that it learns to bind a unique identifier with that specific subject. Once the subject is embedded in the output domain of the model, the unique identifier can then be used to synthesize fully-novel photorealistic images of the subject contextualized in different scenes. By leveraging the semantic prior embedded in the model with a new autogenous class-specific prior preservation loss, our technique enables synthesizing the subject in diverse scenes, poses, views, and lighting conditions that do not appear in the reference images. We apply our technique to several previously-unassailable tasks, including subject recontextualization, text-guided view synthesis, appearance modification, and artistic rendering (all while preserving the subject's key features). Project page: https://dreambooth.github.io/'''"
],
"metadata": {
"id": "jfKXREnThJ_T"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from transformers import Text2TextGenerationPipeline\n",
"pipe = Text2TextGenerationPipeline(model=model,tokenizer=tokenizer,device= 0)\n",
"\n",
"\n",
"pipe(abstract)"
],
"metadata": {
"id": "06Fard5UfHqQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "998a5db2-2026-429b-fffb-71443608ed37"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[{'generated_text': 'text-to-image diffusion model;subject recontextualization'}]"
]
},
"metadata": {},
"execution_count": 5
}
]
},
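{
"cell_type": "markdown",
"source": [
"The pipeline above returns the keyphrases as a single semicolon-separated string. The next cell sketches one way to split that string into a clean set of keyphrases; `parse_keyphrases` is an illustrative helper, not part of the original notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Illustrative post-processing: split the generated 'a;b;c' string into a set of\n",
"# keyphrases. parse_keyphrases is a hypothetical helper added for clarity.\n",
"def parse_keyphrases(generated_text: str) -> set:\n",
"    return {kp.strip() for kp in generated_text.split(';') if kp.strip()}\n",
"\n",
"outputs = pipe(abstract)  # pipe and abstract come from the cells above\n",
"print(parse_keyphrases(outputs[0]['generated_text']))\n",
"# expected: {'text-to-image diffusion model', 'subject recontextualization'}"
],
"metadata": {},
"execution_count": null,
"outputs": []
},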
{
"cell_type": "code",
"source": [
"# # 6. push to hub\n",
"# commit_msg = f'{model_name}_{num_epoch}'\n",
"# tokenizer.push_to_hub(commit_message=commit_msg, repo_id=model_name )\n",
"# model.push_to_hub(commit_message=commit_msg, repo_id=model_name)"
],
"metadata": {
"id": "nYXWw7brJMhQ"
},
"execution_count": null,
"outputs": []
}
]
}\n",
" \n",
"
\n",
" \n",
" \n",
" \n",
" Epoch \n",
" Training Loss \n",
" Validation Loss \n",
" Precision@m \n",
" Recall@m \n",
" Fscore@m \n",
" Num Match \n",
" Num Pred \n",
" Num Gold \n",
" \n",
" \n",
" 1 \n",
" 1.484700 \n",
" 0.823063 \n",
" 55.000000 \n",
" 37.330000 \n",
" 43.590000 \n",
" 1.220000 \n",
" 2.240000 \n",
" 3.280000 \n",
" \n",
" \n",
" 2 \n",
" 0.836800 \n",
" 0.575592 \n",
" 63.070000 \n",
" 62.600000 \n",
" 61.430000 \n",
" 1.960000 \n",
" 3.240000 \n",
" 3.280000 \n",
" \n",
" \n",
" 3 \n",
" 0.664900 \n",
" 0.431628 \n",
" 78.070000 \n",
" 68.970000 \n",
" 71.810000 \n",
" 2.180000 \n",
" 2.860000 \n",
" 3.280000 \n",
" \n",
" \n",
" 4 \n",
" 0.390500 \n",
" 0.280106 \n",
" 90.670000 \n",
" 69.630000 \n",
" 77.340000 \n",
" 2.160000 \n",
" 2.420000 \n",
" 3.280000 \n",
" \n",
" \n",
" 5 \n",
" 0.344900 \n",
" 0.316271 \n",
" 92.170000 \n",
" 70.200000 \n",
" 78.500000 \n",
" 2.240000 \n",
" 2.460000 \n",
" 3.280000 \n",
" \n",
" \n",
" 6 \n",
" 0.355400 \n",
" 0.223609 \n",
" 90.830000 \n",
" 80.430000 \n",
" 84.580000 \n",
" 2.560000 \n",
" 2.860000 \n",
" 3.280000 \n",
" \n",
" \n",
" 7 \n",
" 0.343300 \n",
" 0.320883 \n",
" 88.000000 \n",
" 80.430000 \n",
" 83.370000 \n",
" 2.540000 \n",
" 2.920000 \n",
" 3.280000 \n",
" \n",
" \n",
" 8 \n",
" 0.308200 \n",
" 0.186449 \n",
" 91.500000 \n",
" 83.500000 \n",
" 86.710000 \n",
" 2.640000 \n",
" 2.920000 \n",
" 3.280000 \n",
" \n",
" \n",
" 9 \n",
" 0.156900 \n",
" 0.259272 \n",
" 91.430000 \n",
" 86.770000 \n",
" 88.500000 \n",
" 2.740000 \n",
" 3.040000 \n",
" 3.280000 \n",
" \n",
" \n",
" 10 \n",
" 0.130000 \n",
" 0.191015 \n",
" 89.430000 \n",
" 85.700000 \n",
" 87.120000 \n",
" 2.700000 \n",
" 3.080000 \n",
" 3.280000 \n",
" \n",
" \n",
" 11 \n",
" 0.131900 \n",
" 0.144332 \n",
" 89.930000 \n",
" 86.870000 \n",
" 88.090000 \n",
" 2.740000 \n",
" 3.100000 \n",
" 3.280000 \n",
" \n",
" \n",
" 12 \n",
" 0.151100 \n",
" 0.160923 \n",
" 92.270000 \n",
" 85.270000 \n",
" 88.070000 \n",
" 2.700000 \n",
" 2.960000 \n",
" 3.280000 \n",
" \n",
" \n",
" 13 \n",
" 0.104600 \n",
" 0.181388 \n",
" 91.770000 \n",
" 85.530000 \n",
" 88.080000 \n",
" 2.700000 \n",
" 2.980000 \n",
" 3.280000 \n",
" \n",
" \n",
" 14 \n",
" 0.110200 \n",
" 0.247938 \n",
" 91.430000 \n",
" 86.370000 \n",
" 88.310000 \n",
" 2.720000 \n",
" 3.020000 \n",
" 3.280000 \n",
" \n",
" \n",
" 15 \n",
" 0.157600 \n",
" 0.240022 \n",
" 92.270000 \n",
" 87.600000 \n",
" 89.380000 \n",
" 2.760000 \n",
" 3.040000 \n",
" 3.280000 \n",
" \n",
" \n",
" 16 \n",
" 0.126500 \n",
" 0.133782 \n",
" 93.270000 \n",
" 88.330000 \n",
" 90.320000 \n",
" 2.800000 \n",
" 3.040000 \n",
" 3.280000 \n",
" \n",
" \n",
" 17 \n",
" 0.053800 \n",
" 0.158040 \n",
" 91.430000 \n",
" 89.000000 \n",
" 89.930000 \n",
" 2.820000 \n",
" 3.140000 \n",
" 3.280000 \n",
" \n",
" \n",
" 18 \n",
" 0.056000 \n",
" 0.250004 \n",
" 92.030000 \n",
" 87.030000 \n",
" 88.970000 \n",
" 2.740000 \n",
" 3.020000 \n",
" 3.280000 \n",
" \n",
" \n",
" 19 \n",
" 0.083200 \n",
" 0.167435 \n",
" 93.370000 \n",
" 86.770000 \n",
" 89.330000 \n",
" 2.740000 \n",
" 2.980000 \n",
" 3.280000 \n",
" \n",
" \n",
" 20 \n",
" 0.099200 \n",
" 0.160180 \n",
" 90.200000 \n",
" 87.700000 \n",
" 88.680000 \n",
" 2.760000 \n",
" 3.120000 \n",
" 3.280000 \n",
" \n",
" \n",
" 21 \n",
" 0.031500 \n",
" 0.144671 \n",
" 91.370000 \n",
" 87.430000 \n",
" 88.940000 \n",
" 2.760000 \n",
" 3.080000 \n",
" 3.280000 \n",
" \n",
" \n",
" 22 \n",
" 0.062100 \n",
" 0.167008 \n",
" 90.370000 \n",
" 88.100000 \n",
" 89.010000 \n",
" 2.780000 \n",
" 3.140000 \n",
" 3.280000 \n",
" \n",
" \n",
" 23 \n",
" 0.054600 \n",
" 0.124082 \n",
" 91.270000 \n",
" 88.500000 \n",
" 89.580000 \n",
" 2.800000 \n",
" 3.120000 \n",
" 3.280000 \n",
" \n",
" \n",
" 24 \n",
" 0.033500 \n",
" 0.136747 \n",
" 92.270000 \n",
" 87.830000 \n",
" 89.510000 \n",
" 2.780000 \n",
" 3.060000 \n",
" 3.280000 \n",
" \n",
" \n",
" 25 \n",
" 0.076700 \n",
" 0.128727 \n",
" 93.270000 \n",
" 87.170000 \n",
" 89.570000 \n",
" 2.760000 \n",
" 3.000000 \n",
" 3.280000 \n",
" \n",
" \n",
" 26 \n",
" 0.092200 \n",
" 0.175439 \n",
" 90.270000 \n",
" 87.830000 \n",
" 88.770000 \n",
" 2.780000 \n",
" 3.140000 \n",
" 3.280000 \n",
" \n",
" \n",
" 27 \n",
" 0.089700 \n",
" 0.185708 \n",
" 90.930000 \n",
" 88.500000 \n",
" 89.430000 \n",
" 2.800000 \n",
" 3.140000 \n",
" 3.280000 \n",
" \n",
" \n",
" 28 \n",
" 0.022900 \n",
" 0.171717 \n",
" 91.270000 \n",
" 88.500000 \n",
" 89.580000 \n",
" 2.800000 \n",
" 3.120000 \n",
" 3.280000 \n",
" \n",
" \n",
" 29 \n",
" 0.038700 \n",
" 0.142286 \n",
" 91.270000 \n",
" 88.500000 \n",
" 89.580000 \n",
" 2.800000 \n",
" 3.120000 \n",
" 3.280000 \n",
" \n",
" \n",
" \n",
"30 \n",
" 0.070100 \n",
" 0.140700 \n",
" 91.270000 \n",
" 88.500000 \n",
" 89.580000 \n",
" 2.800000 \n",
" 3.120000 \n",
" 3.280000 \n",
"