DanilO0o commited on
Commit
3160d55
1 Parent(s): 466d937

added third model

Browse files
images/tg_metrics.png ADDED
models/lstm_weights.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:71be720865fe2c50ad11fce1f0a2cea5327ca757b1dfcd6dd8fccae0c88e1e8a
3
+ size 565373
models/vocab_to_int.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2eb33c6ed312fd0720d7b36bbc251bf8e6845c128c8dd7d4e9583f50bfbcb130
3
+ size 401345
pages/model.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image, ImageFilter, ImageDraw
2
+ import streamlit as st
3
+
4
+ import pickle
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch import Tensor
9
+ from dataclasses import dataclass
10
+ from typing import Union
11
+ import re
12
+ import string
13
+ import pymorphy3
14
+ from nltk.corpus import stopwords
15
+ stop_words = set(stopwords.words("english"))
16
+
17
+
18
+ # ------------------------------------------------------------#
19
+ # Упрощенный метод создания класса
20
+
21
+ @dataclass
22
+ class ConfigRNN:
23
+ vocab_size: int # сколько слов - столько embedding-ов; для инициализации embedding параметров
24
+ device: str
25
+ n_layers: int
26
+ embedding_dim: int # чем больше, тем сложнее можно закодировать слово
27
+ hidden_size: int
28
+ seq_len: int
29
+ bidirectional: Union[bool, int]
30
+
31
+
32
+ net_config = ConfigRNN(
33
+ vocab_size=17259 + 1, # -> hand
34
+ device="cpu",
35
+ n_layers=1,
36
+ embedding_dim=8, # не лучшее значение, но в рамках задачи сойдет
37
+ hidden_size=16,
38
+ seq_len=30, # -> hand
39
+ bidirectional=False,
40
+ )
41
+ # ------------------------------------------------------------#
42
+
43
+
44
+ class LSTMClassifier(nn.Module):
45
+ def __init__(self, rnn_conf=net_config) -> None:
46
+ super().__init__()
47
+
48
+ self.embedding_dim = rnn_conf.embedding_dim
49
+ self.hidden_size = rnn_conf.hidden_size
50
+ self.bidirectional = rnn_conf.bidirectional
51
+ self.n_layers = rnn_conf.n_layers
52
+
53
+ self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
54
+ self.lstm = nn.LSTM(
55
+ input_size=self.embedding_dim,
56
+ hidden_size=self.hidden_size,
57
+ bidirectional=self.bidirectional,
58
+ batch_first=True,
59
+ num_layers=self.n_layers,
60
+ dropout=0.5
61
+ )
62
+ self.bidirect_factor = 2 if self.bidirectional else 1
63
+ self.clf = nn.Sequential(
64
+ nn.Linear(self.hidden_size * self.bidirect_factor, 32),
65
+ nn.Dropout(),
66
+ nn.Tanh(),
67
+ nn.Dropout(),
68
+ nn.Linear(32, 5) # len(df['label'].unique())
69
+ )
70
+
71
+ def model_description(self):
72
+ direction = "bidirect" if self.bidirectional else "onedirect"
73
+ return f"lstm_{direction}_{self.n_layers}"
74
+
75
+ def forward(self, x: torch.Tensor):
76
+ embeddings = self.embedding(x)
77
+ out, _ = self.lstm(embeddings)
78
+ # print(out.shape)
79
+ # [все элементы батча, последний h_n, все элементы последнего h_n]
80
+ out = out[:, -1, :]
81
+ # print(out.shape)
82
+ out = self.clf(out)
83
+ return out
84
+ # ------------------------------------------------------------#
85
+ # Загрузка модели
86
+
87
+
88
+ @st.cache_resource
89
+ def load_model():
90
+ model = LSTMClassifier(net_config)
91
+ model.load_state_dict(torch.load(
92
+ "models/lstm_weights.pth", map_location=torch.device("cpu")))
93
+ model.eval()
94
+ return model
95
+
96
+
97
+ model_lstm = load_model()
98
+ # ------------------------------------------------------------#
99
+
100
+
101
+ def padding(text_int: list, seq_len: int) -> np.ndarray:
102
+ """Make left-sided padding for input list of tokens
103
+
104
+ Args:
105
+ review_int (list): input list of tokens
106
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
107
+
108
+ Returns:
109
+ np.array: padded sequences
110
+ """
111
+ features = np.zeros((len(text_int), seq_len), dtype=int)
112
+ for i, review in enumerate(text_int):
113
+ if len(review) <= seq_len:
114
+ zeros = list(np.zeros(seq_len - len(review)))
115
+ new = zeros + review
116
+ else:
117
+ new = review[:seq_len]
118
+ features[i, :] = np.array(new)
119
+ return features
120
+
121
+
122
+ morph = pymorphy3.MorphAnalyzer()
123
+
124
+
125
+ def lemmatize(text):
126
+ # Разбиваем текст на слова
127
+ words = text.split()
128
+
129
+ # Лемматизируем каждое слово и убираем стоп-слова
130
+ lemmatized_words = [morph.parse(word)[0].normal_form for word in words]
131
+
132
+ # Собираем текст из лемматизированных слов
133
+ lemmatized_text = ' '.join(lemmatized_words)
134
+ return lemmatized_text
135
+
136
+
137
+ def data_preprocessing(text):
138
+ # From Phase 1
139
+ text = re.sub(r':[a-zA-Z]+:', '', text) # Убираем смайлики
140
+ text = text.lower() # Переводим текст в нижний регистр
141
+ text = re.sub(r'@[\w_-]+', '', text) # Убираем упоминания пользователей
142
+ text = re.sub(r'#(\w+)', '', text) # Убираем хэштеги
143
+ text = re.sub(r'\d+', '', text) # Убираем цифры
144
+ # Убираем ссылки
145
+ text = re.sub(
146
+ r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
147
+ text = re.sub(r'\s+', ' ', text) # Убираем лишние пробелы
148
+ # Удаление английских слов
149
+ text = ' '.join(re.findall(r'\b[а-яА-ЯёЁ]+\b', text))
150
+ # From Phase 2
151
+ text = re.sub("<.*?>", "", text) # html tags
152
+ text = "".join([c for c in text if c not in string.punctuation])
153
+ splitted_text = [word for word in text.split() if word not in stop_words]
154
+ text = " ".join(splitted_text)
155
+ return text.strip()
156
+
157
+
158
+ def preprocess_single_string(
159
+ input_string: str,
160
+ seq_len: int,
161
+ vocab_to_int: dict,
162
+ verbose: bool = False
163
+ ) -> Tensor:
164
+ """Function for all preprocessing steps on a single string
165
+
166
+ Args:
167
+ input_string (str): input single string for preprocessing
168
+ seq_len (int): max length of sequence, it len(review_int[i]) > seq_len it will be trimmed, else it will be padded by zeros
169
+ vocab_to_int (dict, optional): word corpus {'word' : int index}. Defaults to vocab_to_int.
170
+
171
+ Returns:
172
+ list: preprocessed string
173
+ """
174
+ preprocessed_string = lemmatize(input_string)
175
+ preprocessed_string = data_preprocessing(input_string)
176
+ result_list = []
177
+ for word in preprocessed_string.split():
178
+ try:
179
+ result_list.append(vocab_to_int[word])
180
+ except KeyError as e:
181
+ if verbose:
182
+ print(f'{e}: not in dictionary!')
183
+ pass
184
+ result_padded = padding([result_list], seq_len)[0]
185
+
186
+ return Tensor(result_padded)
187
+ # ------------------------------------------------------------#
188
+
189
+
190
+ st.title("Классификация тематики новостей из телеграм каналов")
191
+ # st.write('Model summary:')
192
+ text = st.text_input('Input some news')
193
+ text_4_test = text
194
+
195
+ # Загрузка словаря из файла
196
+ with open('model/vocab_to_int.pkl', 'rb') as f:
197
+ vocab_to_int = pickle.load(f)
198
+
199
+ if text != '':
200
+ test_review = preprocess_single_string(
201
+ text_4_test, net_config.seq_len, vocab_to_int)
202
+ test_review = torch.tensor(test_review, dtype=torch.int64)
203
+ result = torch.sigmoid(model_lstm(test_review.unsqueeze(0)))
204
+ num = result.argmax().item()
205
+
206
+ st.write('---')
207
+ st.write('Initial text:')
208
+ st.write(text)
209
+ st.write('---')
210
+ st.write('Preprocessing:')
211
+ st.write(data_preprocessing(text))
212
+ st.write('---')
213
+ st.write('Classes:')
214
+ classes = ['крипта', 'мода', 'спорт', 'технологии', 'финансы']
215
+ st.write('крипта *', 'мода *', 'спорт *', 'технологии *', 'финансы')
216
+ st.write('---')
217
+
218
+ st.write('Predict:')
219
+ if text != '':
220
+ st.write('Classification: ', classes[num])
221
+ st.write('Label num: ', num)
222
+
223
+ # Загружаем изображение через PIL
224
+ image = Image.open("images/tg_metrics.png")
225
+
226
+ # Отображение
227
+ st.image(image, caption="Кошмареус переобучения", use_column_width=True)