surdan commited on
Commit
b93bebe
1 Parent(s): 114df45

Upload Inference.ipynb

Browse files
Files changed (1) hide show
  1. Inference.ipynb +210 -0
Inference.ipynb ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "73f81039",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import pipeline\n",
11
+ "from termcolor import colored\n",
12
+ "import torch"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "b8a8891e",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# !pip install termcolor==1.1.0"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "44668ca1",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "class Ner_Extractor:\n",
33
+ " \n",
34
+ " def __init__(self, model_checkpoint):\n",
35
+ " \n",
36
+ " self.token_pred_pipeline = pipeline(\"token-classification\", \n",
37
+ " model=model_checkpoint, \n",
38
+ " aggregation_strategy=\"average\")\n",
39
+ " \n",
40
+ " @staticmethod\n",
41
+ " def text_color(txt, txt_c=\"blue\", txt_hglt=\"on_yellow\"):\n",
42
+ " return colored(txt, txt_c, txt_hglt)\n",
43
+ " \n",
44
+ " @staticmethod\n",
45
+ " def concat_entities(ner_result):\n",
46
+ " \n",
47
+ " entities = []\n",
48
+ " prev_entity = None\n",
49
+ " prev_end = 0\n",
50
+ " for i in range(len(ner_result)):\n",
51
+ " if (ner_result[i][\"entity_group\"] == prev_entity) &\\\n",
52
+ " (ner_result[i][\"start\"] == prev_end):\n",
53
+ " entities[i-1][2] = ner_result[i][\"end\"]\n",
54
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
55
+ " prev_end = ner_result[i][\"end\"]\n",
56
+ " else:\n",
57
+ " entities.append([ner_result[i][\"entity_group\"], \n",
58
+ " ner_result[i][\"start\"], \n",
59
+ " ner_result[i][\"end\"]])\n",
60
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
61
+ " prev_end = ner_result[i][\"end\"]\n",
62
+ " \n",
63
+ " return entities\n",
64
+ " \n",
65
+ " \n",
66
+ " def colored_text(self, text, entities):\n",
67
+ " \n",
68
+ " colored_text = \"\"\n",
69
+ " init_pos = 0\n",
70
+ " for ent in entities:\n",
71
+ " if ent[1] > init_pos:\n",
72
+ " colored_text += text[init_pos: ent[1]]\n",
73
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
74
+ " init_pos = ent[2]\n",
75
+ " else:\n",
76
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
77
+ " init_pos = ent[2]\n",
78
+ " \n",
79
+ " return colored_text\n",
80
+ " \n",
81
+ " \n",
82
+ " def get_entities(self, text: str):\n",
83
+ " \n",
84
+ " entities = self.token_pred_pipeline(text)\n",
85
+ " concat_ent = self.concat_entities(entities)\n",
86
+ " \n",
87
+ " return concat_ent\n",
88
+ " \n",
89
+ " \n",
90
+ " def show_ents_on_text(self, text: str):\n",
91
+ " \n",
92
+ " entities = self.get_entities(text)\n",
93
+ " \n",
94
+ " return self.colored_text(text, entities)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "aaa0a5bd",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "seqs_example = [\"Из Дзюбы вышел бы отличный бразилец». Интервью Клаудиньо\",\n",
105
+ "\"Самый яркий бразилец «Зенита» рассказал о встрече с Пеле, страшном морозе в Самаре и любимых финтах Роналдиньо\",\n",
106
+ "\"Стали известны подробности нового иска РФС к УЕФА и ФИФА\",\n",
107
+ "\"Реванш «Баварии», голы от «Реала» с «Челси»: ставим на ЛЧ\",\n",
108
+ "\"Кварацхелия не вернется в «Рубин» и станет игроком «Наполи»\",\n",
109
+ "\"«Манчестер Сити» сделал грандиозное предложение по Холанду\",\n",
110
+ "\"В России хотят возродить Кубок лиги. Он проводился в 2003 году\",\n",
111
+ "\"Экс-футболиста сборной Украины уволили с ТВ за слова о россиянах\",\n",
112
+ "\"Экс-игрок «Реала» находится в критическом состоянии после ДТП\",\n",
113
+ "\"Аршавин посмеялся над показателями Глушакова в игре с ЦСКА\",\n",
114
+ "\"Арьен Роббен пробежал 42-километровый марафон\",\n",
115
+ "\"Бывший игрок «Спартака» предложил бить футболистов палками\"]"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "id": "380d9824",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "%%time\n",
126
+ "extractor = Ner_Extractor(model_checkpoint = \"surdan/LaBSE_ner_nerel\")"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "37ebcf51",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "%%time\n",
137
+ "show_entities_in_text = (extractor.show_ents_on_text(i) for i in seqs_example)"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": null,
143
+ "id": "e03b28c7",
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "%%time\n",
148
+ "l_entities = [extractor.get_entities(i) for i in seqs_example]\n",
149
+ "len(l_entities), len(seqs_example)"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "id": "a2d4ae84",
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "for i in range(len(seqs_example)):\n",
160
+ " print(next(show_entities_in_text, \"End of generator\"))\n",
161
+ " print(\"-*-\"*25)"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "id": "47fbcff9",
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": []
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "id": "41c32b90",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": []
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "id": "07bb735e",
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": []
187
+ }
188
+ ],
189
+ "metadata": {
190
+ "kernelspec": {
191
+ "display_name": "Python 3 (ipykernel)",
192
+ "language": "python",
193
+ "name": "python3"
194
+ },
195
+ "language_info": {
196
+ "codemirror_mode": {
197
+ "name": "ipython",
198
+ "version": 3
199
+ },
200
+ "file_extension": ".py",
201
+ "mimetype": "text/x-python",
202
+ "name": "python",
203
+ "nbconvert_exporter": "python",
204
+ "pygments_lexer": "ipython3",
205
+ "version": "3.8.10"
206
+ }
207
+ },
208
+ "nbformat": 4,
209
+ "nbformat_minor": 5
210
+ }