# %%
import torch
from transformers import (
    BertTokenizer,
    BertForMaskedLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertModel,
)
import numpy as np
import random
from itertools import islice
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
import os

model_name = "tohoku-nlp/bert-base-japanese-char-v3"
tokenizer = BertTokenizer.from_pretrained(model_name)
base_model = BertModel.from_pretrained(model_name)


class punctuation_predictor(torch.nn.Module):
    """Per-token classifier that emits two logits for every input token.

    The two output channels are consumed by ``process_long_text`` as
    (comma, period) scores, in that order.
    """

    def __init__(self, base_model):
        """Wrap *base_model* with dropout and a 2-way linear head."""
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(0.2)
        # Read the width from the encoder config instead of hard-coding 768,
        # so the head always matches the wrapped encoder.
        self.linear = torch.nn.Linear(base_model.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for every position of the encoder output."""
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Apply dropout + linear layer to the hidden state token by token.
        return self.linear(self.dropout(last_hidden_state))


model = punctuation_predictor(base_model)
# The model is never moved off the CPU in this script, so map the checkpoint
# to CPU explicitly; otherwise a checkpoint saved from a CUDA device fails to
# load on a CPU-only machine.
model.load_state_dict(
    torch.load("punctuation_position_model.pth", map_location="cpu")
)
model.eval()


def insert_punctuation(input, comma_pos, period_pos):
    """Rebuild text from token ids, appending "、" / "。" where flagged.

    Args:
        input: 1-D tensor of token ids for a single sequence.
        comma_pos: Per-position truthy flags; a comma follows flagged tokens.
        period_pos: Per-position truthy flags; a period follows flagged
            tokens. A period takes precedence when both flags are set.

    Returns:
        The concatenated token strings with punctuation inserted.
    """
    last_index = len(input) - 1
    pieces = []
    for i, (c, p) in enumerate(zip(comma_pos, period_pos)):
        token_id = input[i].item()
        # Ids 0-5 are skipped entirely.
        # NOTE(review): presumably these are the special/unused vocabulary
        # entries ([PAD], [CLS], [SEP], ...). If an unknown-token id falls in
        # this range, out-of-vocabulary characters are silently dropped from
        # the output -- confirm against the vocab.
        if token_id <= 5:
            continue
        # A regular token sitting in the very last slot is never emitted.
        if i >= last_index:
            break
        token = tokenizer.ids_to_tokens[token_id]
        if p:
            pieces.append(token + "。")
        elif c:
            pieces.append(token + "、")
        else:
            pieces.append(token)
    return "".join(pieces)


def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1):
    """Strip existing "、"/"。" from *text* and re-insert predicted ones.

    The text is processed in consecutive chunks of *max_length* characters.
    Each chunk is tokenized character by character, scored by ``model``, and
    punctuated wherever the sigmoid score exceeds the given threshold.

    Args:
        text: Input string; any existing "、" and "。" are removed first.
        max_length: Number of characters per chunk.
        comma_thresh: Sigmoid score above which a comma is inserted.
        period_thresh: Sigmoid score above which a period is inserted.

    Returns:
        The re-punctuated text (empty string for empty input).
    """
    text = text.replace("、", "").replace("。", "")
    chunks = []
    for start in range(0, len(text), max_length):
        no_punctuation_text = text[start : start + max_length]
        # Space-separate the characters so each one becomes its own token.
        inputs = tokenizer(
            " ".join(no_punctuation_text),
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(inputs.input_ids, inputs.attention_mask)
            probs = torch.sigmoid(logits)[0].numpy()
        comma_pos = probs[:, 0] > comma_thresh
        period_pos = probs[:, 1] > period_thresh
        chunks.append(
            insert_punctuation(inputs.input_ids[0], comma_pos, period_pos)
        )
    return "".join(chunks)


# %%
if __name__ == "__main__":
    print(
        process_long_text(
            "句読点ありバージョンを書きました句読点があることで僕は逆に読みづらく感じるので句読点無しで書きたいと思います",
        )
    )