{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "73f81039", "metadata": {}, "outputs": [], "source": [ "from transformers import pipeline\n", "from termcolor import colored\n", "import torch" ] }, { "cell_type": "code", "execution_count": null, "id": "b8a8891e", "metadata": {}, "outputs": [], "source": [ "# !pip install termcolor==1.1.0" ] }, { "cell_type": "code", "execution_count": null, "id": "44668ca1", "metadata": {}, "outputs": [], "source": [ "class Ner_Extractor:\n", " \n", " def __init__(self, model_checkpoint):\n", " \n", " self.token_pred_pipeline = pipeline(\"token-classification\", \n", " model=model_checkpoint, \n", " aggregation_strategy=\"average\")\n", " \n", " @staticmethod\n", " def text_color(txt, txt_c=\"blue\", txt_hglt=\"on_yellow\"):\n", " return colored(txt, txt_c, txt_hglt)\n", " \n", " @staticmethod\n", " def concat_entities(ner_result):\n", " \n", " entities = []\n", " prev_entity = None\n", " prev_end = 0\n", " for i in range(len(ner_result)):\n", " if (ner_result[i][\"entity_group\"] == prev_entity) &\\\n", " (ner_result[i][\"start\"] == prev_end):\n", " entities[i-1][2] = ner_result[i][\"end\"]\n", " prev_entity = ner_result[i][\"entity_group\"]\n", " prev_end = ner_result[i][\"end\"]\n", " else:\n", " entities.append([ner_result[i][\"entity_group\"], \n", " ner_result[i][\"start\"], \n", " ner_result[i][\"end\"]])\n", " prev_entity = ner_result[i][\"entity_group\"]\n", " prev_end = ner_result[i][\"end\"]\n", " \n", " return entities\n", " \n", " \n", " def colored_text(self, text, entities):\n", " \n", " colored_text = \"\"\n", " init_pos = 0\n", " for ent in entities:\n", " if ent[1] > init_pos:\n", " colored_text += text[init_pos: ent[1]]\n", " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n", " init_pos = ent[2]\n", " else:\n", " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n", " init_pos = ent[2]\n", " \n", " return colored_text\n", " \n", " \n", " def get_entities(self, text: str):\n", " \n", " entities = self.token_pred_pipeline(text)\n", " concat_ent = self.concat_entities(entities)\n", " \n", " return concat_ent\n", " \n", " \n", " def show_ents_on_text(self, text: str):\n", " \n", " entities = self.get_entities(text)\n", " \n", " return self.colored_text(text, entities)" ] }, { "cell_type": "code", "execution_count": null, "id": "aaa0a5bd", "metadata": {}, "outputs": [], "source": [ "seqs_example = [\"Из Дзюбы вышел бы отличный бразилец». Интервью Клаудиньо\",\n", "\"Самый яркий бразилец «Зенита» рассказал о встрече с Пеле, страшном морозе в Самаре и любимых финтах Роналдиньо\",\n", "\"Стали известны подробности нового иска РФС к УЕФА и ФИФА\",\n", "\"Реванш «Баварии», голы от «Реала» с «Челси»: ставим на ЛЧ\",\n", "\"Кварацхелия не вернется в «Рубин» и станет игроком «Наполи»\",\n", "\"«Манчестер Сити» сделал грандиозное предложение по Холанду\",\n", "\"В России хотят возродить Кубок лиги. Он проводился в 2003 году\",\n", "\"Экс-футболиста сборной Украины уволили с ТВ за слова о россиянах\",\n", "\"Экс-игрок «Реала» находится в критическом состоянии после ДТП\",\n", "\"Аршавин посмеялся над показателями Глушакова в игре с ЦСКА\",\n", "\"Арьен Роббен пробежал 42-километровый марафон\",\n", "\"Бывший игрок «Спартака» предложил бить футболистов палками\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "380d9824", "metadata": {}, "outputs": [], "source": [ "%%time\n", "extractor = Ner_Extractor(model_checkpoint = \"surdan/LaBSE_ner_nerel\")" ] }, { "cell_type": "code", "execution_count": null, "id": "37ebcf51", "metadata": {}, "outputs": [], "source": [ "%%time\n", "show_entities_in_text = (extractor.show_ents_on_text(i) for i in seqs_example)" ] }, { "cell_type": "code", "execution_count": null, "id": "e03b28c7", "metadata": {}, "outputs": [], "source": [ "%%time\n", "l_entities = [extractor.get_entities(i) for i in seqs_example]\n", "len(l_entities), len(seqs_example)" ] }, { "cell_type": "code", "execution_count": null, "id": "a2d4ae84", "metadata": {}, "outputs": [], "source": [ "for i in range(len(seqs_example)):\n", " print(next(show_entities_in_text, \"End of generator\"))\n", " print(\"-*-\"*25)" ] }, { "cell_type": "code", "execution_count": null, "id": "47fbcff9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "41c32b90", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "07bb735e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }