RuTextSegModel
Model for Russian text segmentation, trained on wiki and news corpora
Model description
This model is a top-level part of HierBERT model and solves the problem of text segmentation as a token classification at the sentence level. The ai-forever/sbert_large_nlu_ru with max pooling is used as a low-level model (sentence embedding generator). It's recommended to use this model only with specified low-level model with defined pooling for embeddings.
Intended uses & limitations
How to use
Here is how to use this model in PyTorch:
import torch
import torch.nn as nn
from transformers import BertForTokenClassification, AutoModel, AutoTokenizer
from razdel import sentenize
class BertForTextSegmentationEmbeddings(nn.Module):
def __init__(self, config, embeddings_dim=768):
super(BertForTextSegmentationEmbeddings, self).__init__()
self.config = config
self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
def forward(self, inputs_embeds, position_ids=None, input_ids=None, token_type_ids=None, past_key_values_length=None):
input_shape = inputs_embeds.size()[:-1]
seq_length = input_shape[1]
device = inputs_embeds.device
assert seq_length <= self.config.max_position_embeddings, \
f"Too long sequence is passed {seq_length}. Maximum allowed sequence length is {self.config.max_position_embeddings}"
if position_ids is None:
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).expand(input_shape)
position_embeddings = self.position_embeddings(position_ids)
embeddings = inputs_embeds + position_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class BertForTextSegmentation(BertForTokenClassification):
def __init__(self, config):
super(BertForTextSegmentation, self).__init__(config)
self.bert.base_model.embeddings = BertForTextSegmentationEmbeddings(config)
self.init_weights()
def max_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
token_embeddings[input_mask_expanded == 0] = -1e9 # Set padding tokens to large negative value
return torch.max(token_embeddings, 1)[0]
def create_embeddings(sentences, tokenizer, model):
# Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input.to(device))
# Perform pooling. In this case, max pooling.
sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask'])
return sentence_embeddings
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
emb_tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_nlu_ru")
emb_model = AutoModel.from_pretrained("ai-forever/sbert_large_nlu_ru")
model = BertForTextSegmentation.from_pretrained("mlenjoyneer/RuTextSegModel")
emb_model.to(device)
model.to(device)
text = """В Норильске за годы работы телефона доверия консультанты приняли в общей сложности порядка 75 тысяч обращений, сообщает «Заполярная Правда». Служба психологической помощи появилась в 2000 году. Руководитель службы профилактики наркомании Елена Слатвицкая рассказала журналистам, что в Заполярье настал период, когда ухудшается психо– эмоциональное состояние населения. Это происходит на входе в полярную ночь и на выходе из нее. Осень является кризисным моментом. Сейчас на телефоне доверия работают 15 специалистов. Каждый — под своим псевдонимом. Тему беседы определяет звонящий. Это могут быть наркомания и алкоголизм, ВИЧ–инфекция и прочие заболевания и зависимости, кризисы семейных отношений и многое другое. Сотрудники службы отметили, что больше стало звонков по поводу суицидальных намерений. Наибольшее количество обращений по суицидам пришлось на октябрь — ноябрь. Много звонков как от мужчин, так и от женщин с вопросами об одиночестве. Лидерами по количеству обращений пока остаются женщины. В сентябре в Норильске обнаружили тело девятиклассницы. По версии следствия, девочка сбросилась с крыши. В январе подросток нанес себе порезы стеклом от разбитой бутылки, пытаясь покончить с собой. Мальчик поссорился с матерью и в ходе ссоры нанес себе несколько порезов. Проводится расследование."""
input_embeds = create_embeddings([s.text for s in sentenize(text)], emb_tokenizer, emb_model).unsqueeze(0)
outputs = model(inputs_embeds=input_embeds)
logits = outputs.logits.cpu()
preds = logits.argmax(axis=2).tolist()[0] # [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]
# true_labels = [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]
Training data
Model trained on mlenjoyneer/RuTextSegNews and mlenjoyneer/RuTextSegWiki datasets.
Evaluation results
Train Dataset | Test Dataset | F1_total | F1_1 | Pk | Pk_5 | WinDiff | WinDiff_5 |
---|---|---|---|---|---|---|---|
News+Wiki | News | 0.88 | 0.80 | 0.16 | 0.11 | 0.20 | 0.35 |
News+Wiki | Wiki | 0.89 | 0.80 | 0.18 | 0.16 | 0.09 | 0.19 |
Citation info
In progress
- Downloads last month
- 9