bobfromjapan commited on
Commit
3084243
1 Parent(s): 39ed441

Upload 2 files

Browse files
Files changed (2) hide show
  1. insert_punctuation.py +2 -2
  2. train.py +1 -1
insert_punctuation.py CHANGED
@@ -36,7 +36,7 @@ class punctuation_predictor(torch.nn.Module):
36
 
37
 
38
  model = punctuation_predictor(base_model)
39
- model.load_state_dict(torch.load("punctuation_position_model.pth"))
40
  model.eval()
41
 
42
 
@@ -82,6 +82,6 @@ def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1)
82
  if __name__ == "__main__":
83
  print(
84
  process_long_text(
85
- "句読点ありバージョンを書きました句読点があることで僕は逆に読みづらく感じるので句読点無しで書きたいと思います",
86
  )
87
  )
 
36
 
37
 
38
  model = punctuation_predictor(base_model)
39
+ model.load_state_dict(torch.load("weight/punctuation_position_model.pth"))
40
  model.eval()
41
 
42
 
 
82
  if __name__ == "__main__":
83
  print(
84
  process_long_text(
85
+ "女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた",
86
  )
87
  )
train.py CHANGED
@@ -175,4 +175,4 @@ for epoch in range(10):
175
  epoch_loss += loss.item()
176
  progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)})
177
  # %%
178
- torch.save(model.state_dict(), "punctuation_position_model.pth")
 
175
  epoch_loss += loss.item()
176
  progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)})
177
  # %%
178
+ torch.save(model.state_dict(), "weight/punctuation_position_model.pth")