bobfromjapan
commited on
Commit
•
3084243
1
Parent(s):
39ed441
Upload 2 files
Browse files- insert_punctuation.py +2 -2
- train.py +1 -1
insert_punctuation.py
CHANGED
@@ -36,7 +36,7 @@ class punctuation_predictor(torch.nn.Module):
|
|
36 |
|
37 |
|
38 |
model = punctuation_predictor(base_model)
|
39 |
-
model.load_state_dict(torch.load("punctuation_position_model.pth"))
|
40 |
model.eval()
|
41 |
|
42 |
|
@@ -82,6 +82,6 @@ def process_long_text(text, max_length=256, comma_thresh=0.1, period_thresh=0.1)
|
|
82 |
if __name__ == "__main__":
|
83 |
print(
|
84 |
process_long_text(
|
85 |
-
"
|
86 |
)
|
87 |
)
|
|
|
36 |
|
37 |
|
38 |
model = punctuation_predictor(base_model)
|
39 |
+
model.load_state_dict(torch.load("weight/punctuation_position_model.pth"))
|
40 |
model.eval()
|
41 |
|
42 |
|
|
|
82 |
if __name__ == "__main__":
|
83 |
print(
|
84 |
process_long_text(
|
85 |
+
"女は昨夕艶めかしい姿をして彼の浴室の戸を開けた人に違なかった風呂場で彼を驚ろかした大きな髷をいつの間にか崩して尋常の束髪に結い更えたので彼はつい同じ人と気がつかずにいた彼はさらに声を聴いただけで顔を知らなかった伴の男の方をよそながらの初対面といった風に女と眺め比べた",
|
86 |
)
|
87 |
)
|
train.py
CHANGED
@@ -175,4 +175,4 @@ for epoch in range(10):
|
|
175 |
epoch_loss += loss.item()
|
176 |
progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)})
|
177 |
# %%
|
178 |
-
torch.save(model.state_dict(), "punctuation_position_model.pth")
|
|
|
175 |
epoch_loss += loss.item()
|
176 |
progress_bar.set_postfix({"loss": epoch_loss / len(data_loader)})
|
177 |
# %%
|
178 |
+
torch.save(model.state_dict(), "weight/punctuation_position_model.pth")
|