surdan commited on
Commit
c4c0ace
1 Parent(s): 055be6c

Upload Inference.ipynb

Browse files
Files changed (1) hide show
  1. Inference.ipynb +221 -0
Inference.ipynb ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "73f81039",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from transformers import pipeline\n",
11
+ "from termcolor import colored\n",
12
+ "import torch"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "b8a8891e",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "# !pip install termcolor==1.1.0"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": null,
28
+ "id": "44668ca1",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "class Ner_Extractor:\n",
33
+ " \n",
34
+ " def __init__(self, model_checkpoint):\n",
35
+ " \n",
36
+ " self.token_pred_pipeline = pipeline(\"token-classification\", \n",
37
+ " model=model_checkpoint, \n",
38
+ " aggregation_strategy=\"average\")\n",
39
+ " \n",
40
+ " @staticmethod\n",
41
+ " def text_color(txt, txt_c=\"blue\", txt_hglt=\"on_yellow\"):\n",
42
+ " return colored(txt, txt_c, txt_hglt)\n",
43
+ " \n",
44
+ " @staticmethod\n",
45
+ " def concat_entities(ner_result):\n",
46
+ " \n",
47
+ " entities = []\n",
48
+ " prev_entity = None\n",
49
+ " prev_end = 0\n",
50
+ " for i in range(len(ner_result)):\n",
51
+ " if (ner_result[i][\"entity_group\"] == prev_entity) &\\\n",
52
+ " (ner_result[i][\"start\"] == prev_end):\n",
53
+ " entities[i-1][2] = ner_result[i][\"end\"]\n",
54
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
55
+ " prev_end = ner_result[i][\"end\"]\n",
56
+ " else:\n",
57
+ " entities.append([ner_result[i][\"entity_group\"], \n",
58
+ " ner_result[i][\"start\"], \n",
59
+ " ner_result[i][\"end\"]])\n",
60
+ " prev_entity = ner_result[i][\"entity_group\"]\n",
61
+ " prev_end = ner_result[i][\"end\"]\n",
62
+ " \n",
63
+ " return entities\n",
64
+ " \n",
65
+ " \n",
66
+ " def colored_text(self, text, entities):\n",
67
+ " \n",
68
+ " colored_text = \"\"\n",
69
+ " init_pos = 0\n",
70
+ " for ent in entities:\n",
71
+ " if ent[1] > init_pos:\n",
72
+ " colored_text += text[init_pos: ent[1]]\n",
73
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
74
+ " init_pos = ent[2]\n",
75
+ " else:\n",
76
+ " colored_text += self.text_color(text[ent[1]: ent[2]]) + f\"({ent[0]})\"\n",
77
+ " init_pos = ent[2]\n",
78
+ " \n",
79
+ " return colored_text\n",
80
+ " \n",
81
+ " \n",
82
+ " def get_entities(self, text):\n",
83
+ " \n",
84
+ " entities = self.token_pred_pipeline(text)\n",
85
+ " concat_ent = self.concat_entities(entities)\n",
86
+ " \n",
87
+ " return concat_ent\n",
88
+ " \n",
89
+ " \n",
90
+ " def show_ents_on_text(self, text: str):\n",
91
+ " \n",
92
+ " entities = self.get_entities(text)\n",
93
+ " \n",
94
+ " return self.colored_text(text, entities)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "30e9efd9",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "seqs_example = [\"Минобороны: ракетами «Калибр» уничтожена техника дивизиона С-300, поставленная из Европы\",\n",
105
+ "\"Боррель подтвердил стремление ЕС к военному сценарию урегулирования конфликта на Украине\",\n",
106
+ "\"Ericsson приостановит бизнес в России на неопределенный срок\",\n",
107
+ "\"Минобороны заявило о захвате ВС России новых танков ВСУ под Изюмом\",\n",
108
+ "\"Макрон вышел в лидеры в первом туре выборов во Франции после обработки 97% бюллетеней\",\n",
109
+ "\"Глава МИД Литвы: страны ЕС начали работу над шестым пакетом санкций против России\",\n",
110
+ "\"«Интеррос» Потанина объявил о покупке Росбанка у Societe Generale\",\n",
111
+ "\"Доллар и евро на Мосбирже подорожали на открытии торгов\",\n",
112
+ "\"Басурин заявил, что порт в Мариуполе освобожден на 80%\",\n",
113
+ "\"Путин поручил Промсвязьбанку открыть отделения в Крыму до октября\",\n",
114
+ "\"Milliyet: Турция рассматривает покупку российских Су-57 в случае отказа США продавать F-16\",\n",
115
+ "\"Швеция и Финляндия подадут заявку на вступление в НАТО летом текущего года\",\n",
116
+ "\"Сенатор Косачев предупредил об использовании оружия массового поражения на Украине\",\n",
117
+ "\"Встреча президента Путина и канцлера Нехаммера пройдет без журналистов и пресс-конференции\",\n",
118
+ "\"Кадыров заявил о широкомасштабном наступлении на города и сёла Украины\"]"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "code",
123
+ "execution_count": null,
124
+ "id": "380d9824",
125
+ "metadata": {},
126
+ "outputs": [],
127
+ "source": [
128
+ "extractor = Ner_Extractor(model_checkpoint = \"surdan/LaBSE_ner_nerel\")"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "37ebcf51",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "show_entities_in_text = (extractor.show_ents_on_text(i) for i in seqs_example)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "14807823",
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": []
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "e03b28c7",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": []
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": null,
160
+ "id": "a2d4ae84",
161
+ "metadata": {},
162
+ "outputs": [],
163
+ "source": [
164
+ "print(next(show_entities_in_text, \"Конец\"))"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "e8ab57d1",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": []
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": null,
178
+ "id": "47fbcff9",
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": []
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "id": "41c32b90",
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": []
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "07bb735e",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": []
198
+ }
199
+ ],
200
+ "metadata": {
201
+ "kernelspec": {
202
+ "display_name": "Python 3 (ipykernel)",
203
+ "language": "python",
204
+ "name": "python3"
205
+ },
206
+ "language_info": {
207
+ "codemirror_mode": {
208
+ "name": "ipython",
209
+ "version": 3
210
+ },
211
+ "file_extension": ".py",
212
+ "mimetype": "text/x-python",
213
+ "name": "python",
214
+ "nbconvert_exporter": "python",
215
+ "pygments_lexer": "ipython3",
216
+ "version": "3.8.10"
217
+ }
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 5
221
+ }