# %%
import torch
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm
def index_to_onehot(l, length):
    # e.g. l=[1, 5], length=6 -> [0, 1, 0, 0, 0, 1]
    return [1 if i in l else 0 for i in range(length)]
def get_punctuation_position(tokenized_text, tokenizer):
    # Strip "、" and "。" from the token ids and record, for each removed
    # mark, its index in the punctuation-free sequence.
    count = 0
    comma_pos = []
    period_pos = []
    punctuation_removed_text = []
    comma_id = tokenizer.convert_tokens_to_ids("、")
    period_id = tokenizer.convert_tokens_to_ids("。")
    for i, c in enumerate(tokenized_text):
        if c == comma_id:
            comma_pos.append(i - count - 1)
            count += 1
        elif c == period_id:
            period_pos.append(i - count - 1)
            count += 1
        else:
            punctuation_removed_text.append(c)
    # re-pad to the fixed length of 512 after removing the punctuation tokens
    if len(punctuation_removed_text) < 512:
        punctuation_removed_text += [tokenizer.pad_token_id] * (
            512 - len(punctuation_removed_text)
        )
    return (
        torch.tensor(punctuation_removed_text),
        [
            index_to_onehot(comma_pos, 512),
            index_to_onehot(period_pos, 512),
        ],
    )
# %%
# Scratch checks for the helpers above:
# index_to_onehot([1, 2, 3, 4, 5], 7)
# tokenizer = BertTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char")
# tokenized_text = tokenizer(
#     "今 日 は 、 い い 天 気 で す 。",
#     max_length=512,
#     padding="max_length",
#     truncation=True,
#     return_tensors="pt",
# )
# inputs, label = get_punctuation_position(tokenized_text["input_ids"][0], tokenizer)
# print(inputs)  # -> tensor([ 2, 732, 48, 12, 19, 19, 411, 343, 17, 46, 3, 0, 0, 0, ...])
# print(label)   # -> [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...],  # comma positions (shifted by one: [CLS] leads the sequence)
#                #     [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...]]  # period positions
# %%
class PunctuationPositionDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]
        # space-separate the characters for the char-level tokenizer
        text = " ".join(list(text))
        inputs = self.tokenizer(
            text,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids, label = get_punctuation_position(
            inputs["input_ids"][0], self.tokenizer
        )
        # (2, 512) -> (512, 2): per-token targets for the comma/period channels
        label = torch.tensor(label, dtype=torch.float32).transpose(0, 1)
        # note: attention_mask still counts the removed punctuation tokens,
        # so a few trailing [PAD] positions remain attended
        return (input_ids, inputs.attention_mask.squeeze(), label.squeeze(), text)
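# %%
# Dataset sanity check (a sketch, not in the original); run it after the
# tokenizer below is loaded. Expected shapes: (512,), (512,), (512, 2).
# ds = PunctuationPositionDataset(["今日は、いい天気です。"], tokenizer)
# ids, mask, label, text = ds[0]
# print(ids.shape, mask.shape, label.shape)
# print(label.nonzero())  # -> tensor([[3, 0], [9, 1]]): comma/period target positions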
# %%
model_name = "tohoku-nlp/bert-base-japanese-char-v3"
tokenizer = BertTokenizer.from_pretrained(model_name)
base_model = BertModel.from_pretrained(model_name)
# %%
class PunctuationPredictor(torch.nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.dropout = torch.nn.Dropout(0.2)
        # two logits per token: one for "、", one for "。"
        self.linear = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask):
        last_hidden_state = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # apply the linear head to every token's hidden state
        return self.linear(self.dropout(last_hidden_state))


model = PunctuationPredictor(base_model)
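# %%
# Quick shape check (a sketch, not in the original): one forward pass should
# yield (batch, 512, 2) logits, one channel each for "、" and "。".
# sample = tokenizer(
#     "今 日 は い い 天 気 で す",
#     max_length=512,
#     padding="max_length",
#     truncation=True,
#     return_tensors="pt",
# )
# print(model(sample["input_ids"], sample["attention_mask"]).shape)  # -> torch.Size([1, 512, 2])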
# %%
# a = tokenizer(
#     "今 日 は い い 天 気 で す 。",
#     max_length=512,
#     padding="max_length",
#     truncation=True,
#     return_tensors="pt",
# )
# %%
with open("data/train.txt", "r") as f:
    texts = f.readlines()
dataset = PunctuationPositionDataset(texts, tokenizer)
# %%
data_loader = DataLoader(
dataset,
batch_size=16,
shuffle=True,
num_workers=8,
)
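# %%
# Batch shape check (a sketch, not in the original):
# ids, masks, labels, texts = next(iter(data_loader))
# print(ids.shape, masks.shape, labels.shape)  # -> (16, 512), (16, 512), (16, 512, 2)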
# %%
# use a smaller learning rate (5e-5) for the pretrained encoder and a
# larger one (1e-3) for the randomly initialized linear head
optimizer = AdamW(
[
{"params": model.base_model.parameters(), "lr": 5e-5},
{"params": model.linear.parameters(), "lr": 1e-3},
],
)
criterion = torch.nn.BCEWithLogitsLoss()
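# %%
# Sanity checks (a sketch, not in the original). The two parameter groups keep
# their own learning rates, and BCEWithLogitsLoss treats the comma and period
# channels as independent per-token binary targets.
# print([g["lr"] for g in optimizer.param_groups])  # -> [5e-05, 0.001]
# print(criterion(torch.zeros(1, 512, 2), torch.zeros(1, 512, 2)))  # -> tensor(0.6931), i.e. log(2)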
# %%
model.train()
model.to("cuda")
for epoch in range(10):
    epoch_loss = 0.0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}")
    for batch in progress_bar:
        input_ids, attention_masks, labels, text = batch
        input_ids = input_ids.to("cuda")
        attention_masks = attention_masks.to("cuda")
        labels = labels.to("cuda")
        outputs = model(input_ids=input_ids, attention_mask=attention_masks)
        # note: the loss also covers [PAD] positions (target 0)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_loss += loss.item()
        # sum of batch losses so far, divided by the total number of batches
        progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)})
# %%
torch.save(model.state_dict(), "weight/punctuation_position_model.pth")
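# %%
# Inference sketch (an assumption, not part of the original script): re-insert
# "、"/"。" into a raw, punctuation-free string by thresholding sigmoid(logits)
# at 0.5. Label index i corresponds to character i - 1 because the sequence
# starts with [CLS]; inputs longer than ~510 characters would be truncated.
def restore_punctuation(raw_text, model, tokenizer, threshold=0.5):
    model.eval()
    spaced = " ".join(list(raw_text))
    enc = tokenizer(
        spaced,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids, _ = get_punctuation_position(enc["input_ids"][0], tokenizer)
    with torch.no_grad():
        logits = model(
            input_ids=input_ids.unsqueeze(0).to("cuda"),
            attention_mask=enc["attention_mask"].to("cuda"),
        )
    probs = torch.sigmoid(logits)[0]  # (512, 2): comma channel, period channel
    out = []
    for i, ch in enumerate(raw_text):
        out.append(ch)
        if probs[i + 1, 1] > threshold:  # +1 for the leading [CLS] token
            out.append("。")
        elif probs[i + 1, 0] > threshold:
            out.append("、")
    return "".join(out)


# print(restore_punctuation("今日はいい天気です", model, tokenizer))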