surdan commited on
Commit
54a3a0a
1 Parent(s): 570410a

Upload Inference.ipynb

Browse files
Files changed (1) hide show
  1. Inference.ipynb +208 -0
Inference.ipynb ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "73f81039",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import pipeline\n",
11
+ "from termcolor import colored\n",
12
+ "import torch"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "b8a8891e",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# !pip install termcolor==1.1.0"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "44668ca1",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "class Ner_Extractor:\n",
33
+ " \n",
34
+ " def __init__(self, model_checkpoint):\n",
35
+ " \n",
36
+ " self.token_pred_pipeline = pipeline(\"token-classification\", \n",
37
+ " model=model_checkpoint, \n",
38
+ " aggregation_strategy=\"average\")\n",
39
+ " \n",
40
+ " @staticmethod\n",
41
+ " def text_color(txt, txt_c=\"blue\", txt_hglt=\"on_yellow\"):\n",
42
+ " return colored(txt, txt_c, txt_hglt)\n",
43
+ " \n",
44
+ " @staticmethod\n",
45
+ " def concat_entities(ner_result):\n",
46
+ " \n",
47
+ " entities = []\n",
48
+ " prev_entity = None\n",
49
+ " prev_end = 0\n",
50
+ " for i in range(len(ner_result)):\n",
51
+ " if (ner_result[i][\"entity_group\"] == prev_entity) &\\\n",
52
+ " (ner_result[i][\"start\"] == prev_end):\n",
53
+ " entities[i-1][2] = ner_result[i][\"end\"]\n",
54
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
55
+ " prev_end = ner_result[i][\"end\"]\n",
56
+ " else:\n",
57
+ " entities.append([ner_result[i][\"entity_group\"], \n",
58
+ " ner_result[i][\"start\"], \n",
59
+ " ner_result[i][\"end\"]])\n",
60
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
61
+ " prev_end = ner_result[i][\"end\"]\n",
62
+ " \n",
63
+ " return entities\n",
64
+ " \n",
65
+ " \n",
66
+ " def colored_text(self, text, entities):\n",
67
+ " \n",
68
+ " colored_text = \"\"\n",
69
+ " init_pos = 0\n",
70
+ " for ent in entities:\n",
71
+ " if ent[1] > init_pos:\n",
72
+ " colored_text += text[init_pos: ent[1]]\n",
73
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
74
+ " init_pos = ent[2]\n",
75
+ " else:\n",
76
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
77
+ " init_pos = ent[2]\n",
78
+ " \n",
79
+ " return colored_text\n",
80
+ " \n",
81
+ " \n",
82
+ " def get_entities(self, text: str):\n",
83
+ " \n",
84
+ " entities = self.token_pred_pipeline(text)\n",
85
+ " concat_ent = self.concat_entities(entities)\n",
86
+ " \n",
87
+ " return concat_ent\n",
88
+ " \n",
89
+ " \n",
90
+ " def show_ents_on_text(self, text: str):\n",
91
+ " \n",
92
+ " entities = self.get_entities(text)\n",
93
+ " \n",
94
+ " return self.colored_text(text, entities)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "aaa0a5bd",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "seqs_example = [\"Из Дзюбы вышел бы отличный бразилец». Интервью Клаудиньо\",\n",
105
+ "\"Самый яркий бразилец «Зенита» рассказал о встрече с Пеле, страшном морозе в Самаре и любимых финтах Роналдиньо\",\n",
106
+ "\"Стали известны подробности нового иска РФС к УЕФА и ФИФА\",\n",
107
+ "\"Реванш «Баварии», голы от «Реала» с «Челси»: ставим на ЛЧ\",\n",
108
+ "\"Кварацхелия не вернется в «Рубин» и станет игроком «Наполи»\",\n",
109
+ "\"«Манчестер Сити» сделал грандиозное предложение по Холанду\",\n",
110
+ "\"В России хотят возродить Кубок лиги. Он проводился в 2003 году\",\n",
111
+ "\"Экс-игрок «Реала» находится в критическом состоянии после ДТП\",\n",
112
+ "\"Аршавин посмеялся над показателями Глушакова в игре с ЦСКА\",\n",
113
+ "\"Арьен Роббен пробежал 42-километровый марафон\"]"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "id": "380d9824",
120
+ "metadata": {},
121
+ "outputs": [],
122
+ "source": [
123
+ "%%time\n",
124
+ "extractor = Ner_Extractor(model_checkpoint = \"surdan/LaBSE_ner_nerel\")"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "id": "37ebcf51",
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "%%time\n",
135
+ "show_entities_in_text = (extractor.show_ents_on_text(i) for i in seqs_example)"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "id": "e03b28c7",
142
+ "metadata": {},
143
+ "outputs": [],
144
+ "source": [
145
+ "%%time\n",
146
+ "l_entities = [extractor.get_entities(i) for i in seqs_example]\n",
147
+ "len(l_entities), len(seqs_example)"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "id": "a2d4ae84",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "for i in range(len(seqs_example)):\n",
158
+ " print(next(show_entities_in_text, \"End of generator\"))\n",
159
+ " print(\"-*-\"*25)"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "47fbcff9",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": []
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "41c32b90",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": []
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "07bb735e",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": []
185
+ }
186
+ ],
187
+ "metadata": {
188
+ "kernelspec": {
189
+ "display_name": "Python 3 (ipykernel)",
190
+ "language": "python",
191
+ "name": "python3"
192
+ },
193
+ "language_info": {
194
+ "codemirror_mode": {
195
+ "name": "ipython",
196
+ "version": 3
197
+ },
198
+ "file_extension": ".py",
199
+ "mimetype": "text/x-python",
200
+ "name": "python",
201
+ "nbconvert_exporter": "python",
202
+ "pygments_lexer": "ipython3",
203
+ "version": "3.8.10"
204
+ }
205
+ },
206
+ "nbformat": 4,
207
+ "nbformat_minor": 5
208
+ }