|
|
|
"""Command-line entry point for D-nikud: predict, evaluate, or train a
TavBERT-based Hebrew diacritization model."""

import argparse
import logging
import os
import sys
from datetime import datetime
from logging.handlers import RotatingFileHandler
from pathlib import Path

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoTokenizer

from src.models import DNikudModel, ModelConfig
from src.models_utils import training, evaluate, predict
from src.plot_helpers import (
    generate_plot_by_nikud_dagesh_sin_dict,
    generate_word_and_letter_accuracy_plot,
)
from src.running_params import BATCH_SIZE, MAX_LENGTH_SEN
from src.utiles_data import (
    NikudDataset,
    Nikud,
    create_missing_folders,
    extract_text_to_compare_nakdimon,
)
|
|
|
# Use the GPU when available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
def get_logger(log_level, name_func, date_time=None):
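    """Build a logger that writes to both the console and a rotating log file.

    Logs go to a per-run folder, logging/log_model_<name_func>_<date_time>,
    created next to this file.
    """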
|
    if date_time is None:
        # Compute the timestamp at call time, not at import time.
        date_time = datetime.now().strftime("%d_%m_%y__%H_%M")

    log_location = os.path.join(
        os.path.join(Path(__file__).parent, "logging"),
        f"log_model_{name_func}_{date_time}",
    )
    create_missing_folders(log_location)

    log_format = "%(asctime)s %(levelname)-8s Thread_%(thread)-6d ::: %(funcName)s(%(lineno)d) ::: %(message)s"
    logger = logging.getLogger("algo")
    logger.setLevel(getattr(logging, log_level))
    cnsl_log_formatter = logging.Formatter(log_format)
    cnsl_handler = logging.StreamHandler()
    cnsl_handler.setFormatter(cnsl_log_formatter)
    cnsl_handler.setLevel(log_level)
    logger.addHandler(cnsl_handler)

    file_location = os.path.join(log_location, "Diacritization_Model_DEBUG.log")
    file_log_formatter = logging.Formatter(log_format)

    SINGLE_LOG_SIZE = 2 * 1024 * 1024
    MAX_LOG_FILES = 20
    file_handler = RotatingFileHandler(
        file_location, mode="a", maxBytes=SINGLE_LOG_SIZE, backupCount=MAX_LOG_FILES
    )
    file_handler.setFormatter(file_log_formatter)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)

    return logger
|
|
|
|
|
def evaluate_text(
    path,
    dnikud_model,
    tokenizer_tavbert,
    logger,
    plots_folder=None,
    batch_size=BATCH_SIZE,
):
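    """Evaluate the D-nikud model on a single file or a folder of files and
    log letter-level and word-level accuracy."""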
|
    path_name = os.path.basename(path)

    msg = f"evaluate text: {path_name} on D-nikud model"
    logger.debug(msg)

    if os.path.isfile(path):
        dataset = NikudDataset(
            tokenizer_tavbert, file=path, logger=logger, max_length=MAX_LENGTH_SEN
        )
    elif os.path.isdir(path):
        dataset = NikudDataset(
            tokenizer_tavbert, folder=path, logger=logger, max_length=MAX_LENGTH_SEN
        )
    else:
        raise FileNotFoundError(f"Input path does not exist: {path}")

    dataset.prepare_data(name="evaluate")
    mtb_dl = torch.utils.data.DataLoader(dataset.prepered_data, batch_size=batch_size)

    word_level_correct, letter_level_correct_dev = evaluate(
        dnikud_model, mtb_dl, plots_folder, device=DEVICE
    )

    msg = (
        f"D-nikud model\n{path_name} evaluate\nLetter level accuracy: {letter_level_correct_dev}\n"
        f"Word level accuracy: {word_level_correct}"
    )
    logger.debug(msg)
|
|
|
|
|
def predict_text(
    text_file,
    tokenizer_tavbert,
    output_file,
    logger,
    dnikud_model,
    compare_nakdimon=False,
):
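    """Diacritize a single text file and print the result, or write it to
    output_file (optionally reduced for comparison with Nakdimon)."""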
|
    dataset = NikudDataset(
        tokenizer_tavbert, file=text_file, logger=logger, max_length=MAX_LENGTH_SEN
    )

    dataset.prepare_data(name="prediction")
    mtb_prediction_dl = torch.utils.data.DataLoader(
        dataset.prepered_data, batch_size=BATCH_SIZE
    )
    all_labels = predict(dnikud_model, mtb_prediction_dl, DEVICE)
    text_data_with_labels = dataset.back_2_text(labels=all_labels)

    if output_file is None:
        # back_2_text returns the diacritized text as a single string; print it
        # whole rather than iterating it (which would print character by character).
        print(text_data_with_labels)
    else:
        with open(output_file, "w", encoding="utf-8") as f:
            if compare_nakdimon:
                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
            else:
                f.write(text_data_with_labels)
|
|
|
|
|
def predict_folder(
    folder,
    output_folder,
    logger,
    tokenizer_tavbert,
    dnikud_model,
    compare_nakdimon=False,
):
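    """Recursively diacritize every .txt file in a folder, mirroring the input
    folder structure under output_folder."""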
|
    create_missing_folders(output_folder)

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            output_file = os.path.join(output_folder, filename)
            predict_text(
                file_path,
                output_file=output_file,
                logger=logger,
                tokenizer_tavbert=tokenizer_tavbert,
                dnikud_model=dnikud_model,
                compare_nakdimon=compare_nakdimon,
            )
        elif os.path.isdir(file_path) and filename not in (".git", "README.md"):
            sub_folder = file_path
            sub_folder_output = os.path.join(output_folder, filename)
            predict_folder(
                sub_folder,
                sub_folder_output,
                logger,
                tokenizer_tavbert,
                dnikud_model,
                compare_nakdimon=compare_nakdimon,
            )
|
|
|
|
|
def update_compare_folder(folder, output_folder): |
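    """Recursively rewrite already-diacritized .txt files into the reduced
    format used for comparison with Nakdimon."""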
|
    create_missing_folders(output_folder)

    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            output_file = os.path.join(output_folder, filename)
            with open(file_path, "r", encoding="utf-8") as f:
                text_data_with_labels = f.read()
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(extract_text_to_compare_nakdimon(text_data_with_labels))
        elif os.path.isdir(file_path) and filename != ".git":
            sub_folder = file_path
            sub_folder_output = os.path.join(output_folder, filename)
            update_compare_folder(sub_folder, sub_folder_output)
|
|
|
|
|
def check_files_excepted(folder): |
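    """Recursively try to load every .txt file as a NikudDataset and report
    the files that raise an exception."""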
|
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)

        if filename.lower().endswith(".txt") and os.path.isfile(file_path):
            try:
                NikudDataset(None, file=file_path)
            except Exception as e:
                print(f"failed in file: {filename} ({e})")
        elif os.path.isdir(file_path) and filename != ".git":
            check_files_excepted(file_path)
|
|
|
|
|
def do_predict(
    input_path, output_path, tokenizer_tavbert, logger, dnikud_model, compare_nakdimon
):
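    """Dispatch prediction to predict_folder or predict_text depending on
    whether input_path is a folder or a file."""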
|
    if os.path.isdir(input_path):
        predict_folder(
            input_path,
            output_path,
            logger,
            tokenizer_tavbert,
            dnikud_model,
            compare_nakdimon=compare_nakdimon,
        )
    elif os.path.isfile(input_path):
        predict_text(
            input_path,
            output_file=output_path,
            logger=logger,
            tokenizer_tavbert=tokenizer_tavbert,
            dnikud_model=dnikud_model,
            compare_nakdimon=compare_nakdimon,
        )
    else:
        raise FileNotFoundError(f"Input path does not exist: {input_path}")
|
|
|
|
|
def evaluate_folder(folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder): |
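    """Evaluate a folder, then recurse into its sub-folders, skipping .git,
    'not_use', and 'NakdanResults' directories."""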
|
msg = f"evaluate sub folder: {folder_path}" |
|
logger.info(msg) |
|
|
|
evaluate_text( |
|
folder_path, |
|
dnikud_model=dnikud_model, |
|
tokenizer_tavbert=tokenizer_tavbert, |
|
logger=logger, |
|
plots_folder=plots_folder, |
|
batch_size=BATCH_SIZE, |
|
) |
|
|
|
msg = f"\n***************************************\n" |
|
logger.info(msg) |
|
|
|
for sub_folder_name in os.listdir(folder_path): |
|
sub_folder_path = os.path.join(folder_path, sub_folder_name) |
|
|
|
if ( |
|
not os.path.isdir(sub_folder_path) |
|
or sub_folder_path == ".git" |
|
or "not_use" in sub_folder_path |
|
or "NakdanResults" in sub_folder_path |
|
): |
|
continue |
|
|
|
evaluate_folder( |
|
sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder |
|
) |
|
|
|
|
|
def do_evaluate(
    input_path,
    logger,
    dnikud_model,
    tokenizer_tavbert,
    plots_folder,
    eval_sub_folders=False,
):
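    """Evaluate the model on input_path and, optionally, on each of its
    sub-folders independently."""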
|
msg = f"evaluate all_data: {input_path}" |
|
logger.info(msg) |
|
|
|
evaluate_text( |
|
input_path, |
|
dnikud_model=dnikud_model, |
|
tokenizer_tavbert=tokenizer_tavbert, |
|
logger=logger, |
|
plots_folder=plots_folder, |
|
batch_size=BATCH_SIZE, |
|
) |
|
|
|
msg = f"\n\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n" |
|
logger.info(msg) |
|
|
|
if eval_sub_folders: |
|
for sub_folder_name in os.listdir(input_path): |
|
sub_folder_path = os.path.join(input_path, sub_folder_name) |
|
|
|
if ( |
|
not os.path.isdir(sub_folder_path) |
|
or sub_folder_path == ".git" |
|
or "not_use" in sub_folder_path |
|
or "NakdanResults" in sub_folder_path |
|
): |
|
continue |
|
|
|
evaluate_folder( |
|
sub_folder_path, logger, dnikud_model, tokenizer_tavbert, plots_folder |
|
) |
|
|
|
|
|
def do_train(
    logger,
    plots_folder,
    dir_model_config,
    tokenizer_tavbert,
    dnikud_model,
    output_trained_model_dir,
    data_folder,
    n_epochs,
    checkpoints_frequency,
    learning_rate,
    batch_size,
):
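    """Train the D-nikud model: load the train/dev/test datasets, run the
    training loop, and save loss/accuracy plots."""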
|
msg = "Loading data..." |
|
logger.debug(msg) |
|
|
|
dataset_train = NikudDataset( |
|
tokenizer_tavbert, |
|
folder=os.path.join(data_folder, "train"), |
|
logger=logger, |
|
max_length=MAX_LENGTH_SEN, |
|
is_train=True, |
|
) |
|
dataset_dev = NikudDataset( |
|
tokenizer=tokenizer_tavbert, |
|
folder=os.path.join(data_folder, "dev"), |
|
logger=logger, |
|
max_length=dataset_train.max_length, |
|
is_train=True, |
|
) |
|
dataset_test = NikudDataset( |
|
tokenizer=tokenizer_tavbert, |
|
folder=os.path.join(data_folder, "test"), |
|
logger=logger, |
|
max_length=dataset_train.max_length, |
|
is_train=True, |
|
) |
|
|
|
dataset_train.show_data_labels(plots_folder=plots_folder) |
|
|
|
msg = f"Max length of data: {dataset_train.max_length}" |
|
logger.debug(msg) |
|
|
|
msg = ( |
|
f"Num rows in train data: {len(dataset_train.data)}, " |
|
f"Num rows in dev data: {len(dataset_dev.data)}, " |
|
f"Num rows in test data: {len(dataset_test.data)}" |
|
) |
|
logger.debug(msg) |
|
|
|
msg = "Loading tokenizer and prepare data..." |
|
logger.debug(msg) |
|
|
|
dataset_train.prepare_data(name="train") |
|
dataset_dev.prepare_data(name="dev") |
|
dataset_test.prepare_data(name="test") |
|
|
|
mtb_train_dl = torch.utils.data.DataLoader( |
|
dataset_train.prepered_data, batch_size=batch_size |
|
) |
|
mtb_dev_dl = torch.utils.data.DataLoader( |
|
dataset_dev.prepered_data, batch_size=batch_size |
|
) |
|
|
|
if not os.path.isfile(dir_model_config): |
|
our_model_config = ModelConfig(dataset_train.max_length) |
|
our_model_config.save_to_file(dir_model_config) |
|
|
|
optimizer = torch.optim.Adam(dnikud_model.parameters(), lr=learning_rate) |
|
|
|
msg = "training..." |
|
logger.debug(msg) |
|
|
|
criterion_nikud = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to( |
|
DEVICE |
|
) |
|
criterion_dagesh = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to( |
|
DEVICE |
|
) |
|
criterion_sin = nn.CrossEntropyLoss(ignore_index=Nikud.PAD_OR_IRRELEVANT).to(DEVICE) |
|
|
|
training_params = { |
|
"n_epochs": n_epochs, |
|
"checkpoints_frequency": checkpoints_frequency, |
|
} |
|
( |
|
best_model_details, |
|
best_accuracy, |
|
epochs_loss_train_values, |
|
steps_loss_train_values, |
|
loss_dev_values, |
|
accuracy_dev_values, |
|
) = training( |
|
dnikud_model, |
|
mtb_train_dl, |
|
mtb_dev_dl, |
|
criterion_nikud, |
|
criterion_dagesh, |
|
criterion_sin, |
|
training_params, |
|
logger, |
|
output_trained_model_dir, |
|
optimizer, |
|
device=DEVICE, |
|
) |
|
|
|
generate_plot_by_nikud_dagesh_sin_dict( |
|
epochs_loss_train_values, "Train epochs loss", "Loss", plots_folder |
|
) |
|
generate_plot_by_nikud_dagesh_sin_dict( |
|
steps_loss_train_values, "Train steps loss", "Loss", plots_folder |
|
) |
|
generate_plot_by_nikud_dagesh_sin_dict( |
|
loss_dev_values, "Dev epochs loss", "Loss", plots_folder |
|
) |
|
generate_plot_by_nikud_dagesh_sin_dict( |
|
accuracy_dev_values, "Dev accuracy", "Accuracy", plots_folder |
|
) |
|
generate_word_and_letter_accuracy_plot( |
|
accuracy_dev_values, "Accuracy", plots_folder |
|
) |
|
|
|
msg = "Done" |
|
logger.info(msg) |
|
|
|
|
|
if __name__ == "__main__": |
|
tokenizer_tavbert = AutoTokenizer.from_pretrained("tau/tavbert-he") |
|
|
|
parser = argparse.ArgumentParser( |
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
|
description="""Predict D-nikud""", |
|
) |
|
parser.add_argument( |
|
"-l", |
|
"--log", |
|
dest="log_level", |
|
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], |
|
default="DEBUG", |
|
help="Set the logging level", |
|
) |
|
parser.add_argument( |
|
"-m", |
|
"--output_model_dir", |
|
type=str, |
|
default="models", |
|
help="save directory for model", |
|
) |
|
subparsers = parser.add_subparsers( |
|
help="sub-command help", dest="command", required=True |
|
) |
|
|
|
    parser_predict = subparsers.add_parser(
        "predict", help="diacritize a text file or folder"
    )
    parser_predict.add_argument("input_path", help="input file or folder")
    parser_predict.add_argument("output_path", help="output file or folder")
    parser_predict.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_predict.add_argument(
        "-c",
        "--compare",
        dest="compare_nakdimon",
        action="store_true",
        default=False,
        help="predict text for comparison with Nakdimon",
    )
    parser_predict.set_defaults(func=do_predict)
|
|
|
    parser_evaluate = subparsers.add_parser("evaluate", help="evaluate D-nikud")
    parser_evaluate.add_argument("input_path", help="input file or folder")
    parser_evaluate.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=os.path.join(Path(__file__).parent, "models", "Dnikud_best_model.pth"),
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_evaluate.add_argument(
        "-df",
        "--plots_folder",
        dest="plots_folder",
        default=os.path.join(Path(__file__).parent, "plots"),
        help="set the plots folder",
    )
    parser_evaluate.add_argument(
        "-es",
        "--eval_sub_folders",
        dest="eval_sub_folders",
        action="store_true",
        default=False,
        help="also evaluate each sub-folder within input_path, "
        "providing an independent assessment for each sub-folder",
    )
    parser_evaluate.set_defaults(func=do_evaluate)
|
|
|
|
|
|
|
    parser_train = subparsers.add_parser("train", help="train D-nikud")
    parser_train.add_argument(
        "-ptmp",
        "--pretrain_model_path",
        type=str,
        default=None,
        help="pre-trained model path - use only if you want to use trained model weights",
    )
    parser_train.add_argument(
        "--learning_rate", type=float, default=0.001, help="Learning rate"
    )
    parser_train.add_argument("--batch_size", type=int, default=32, help="batch size")
    parser_train.add_argument(
        "--n_epochs", type=int, default=10, help="number of epochs"
    )
    parser_train.add_argument(
        "--data_folder",
        dest="data_folder",
        default=os.path.join(Path(__file__).parent, "data"),
        help="Set the data folder",
    )
    parser_train.add_argument(
        "--checkpoints_frequency",
        type=int,
        default=1,
        help="how often (in epochs) to save model checkpoints",
    )
    parser_train.add_argument(
        "-df",
        "--plots_folder",
        dest="plots_folder",
        default=os.path.join(Path(__file__).parent, "plots"),
        help="Set the plots folder",
    )
    parser_train.set_defaults(func=do_train)
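    # Example invocations (the script filename is assumed here, e.g. main.py;
    # adjust it to this repo's actual entry point):
    #   python main.py predict path/to/input.txt path/to/output.txt
    #   python main.py evaluate data/test -df plots
    #   python main.py train --n_epochs 10 --batch_size 32 --learning_rate 0.001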
|
|
|
    args = parser.parse_args()
    kwargs = vars(args).copy()
    date_time = datetime.now().strftime("%d_%m_%y__%H_%M")
    logger = get_logger(kwargs["log_level"], args.command, date_time)

    del kwargs["log_level"]

    kwargs["tokenizer_tavbert"] = tokenizer_tavbert
    kwargs["logger"] = logger

    msg = "Loading model..."
    logger.debug(msg)

    if args.command in ["evaluate", "predict"] or (
        args.command == "train" and args.pretrain_model_path is not None
    ):
        dir_model_config = os.path.join("models", "config.yml")
        config = ModelConfig.load_from_file(dir_model_config)

        dnikud_model = DNikudModel(
            config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            device=DEVICE,
        ).to(DEVICE)
        state_dict_model = dnikud_model.state_dict()
        # map_location keeps GPU-trained checkpoints loadable on CPU-only machines.
        state_dict_model.update(
            torch.load(args.pretrain_model_path, map_location=DEVICE)
        )
        dnikud_model.load_state_dict(state_dict_model)
|
    else:
        base_model_name = "tau/tavbert-he"
        config = AutoConfig.from_pretrained(base_model_name)
        dnikud_model = DNikudModel(
            config,
            len(Nikud.label_2_id["nikud"]),
            len(Nikud.label_2_id["dagesh"]),
            len(Nikud.label_2_id["sin"]),
            pretrain_model=base_model_name,
            device=DEVICE,
        ).to(DEVICE)
|
|
|
if args.command == "train": |
|
output_trained_model_dir = os.path.join( |
|
kwargs["output_model_dir"], "latest", f"output_models_{date_time}" |
|
) |
|
create_missing_folders(output_trained_model_dir) |
|
dir_model_config = os.path.join(kwargs["output_model_dir"], "config.yml") |
|
kwargs["dir_model_config"] = dir_model_config |
|
kwargs["output_trained_model_dir"] = output_trained_model_dir |
|
del kwargs["pretrain_model_path"] |
|
del kwargs["output_model_dir"] |
|
kwargs["dnikud_model"] = dnikud_model |
|
|
|
del kwargs["command"] |
|
del kwargs["func"] |
|
args.func(**kwargs) |
|
|
|
sys.exit(0) |
|
|