import argparse


def get_parser():
    """Build the command-line argument parser for the ELECT inference script.

    Every option has a default, so ``get_parser().parse_args([])`` returns a
    fully populated namespace; callers override only what they need.

    Returns:
        argparse.ArgumentParser: parser defining the data-location,
        checkpoint, preprocessor, device, batching, epoch, top-k and
        output-file options.
    """
    parser = argparse.ArgumentParser()

    # Input data location.
    parser.add_argument("--data_dir", default="telecom_data/", type=str,
                        help="Directory containing the input data files.")
    parser.add_argument("--data_file", default="data_filter.pkl", type=str,
                        help="Input data file (default: %(default)s).")

    # Model checkpoint and preprocessing.
    parser.add_argument("--ckpt_dir", default="./PLMs/chinese-roberta-wwm-ext", type=str,
                        help="The checkpoints dir. Should contain the pretrained model.")
    parser.add_argument("--preprocessor", default="BasePreprocessor", type=str,
                        help="Name of preprocessor.")

    # Runtime settings.
    parser.add_argument("--device", default="cuda:0", type=str,
                        help="Device to run on (default: %(default)s).")
    parser.add_argument("--batch_size", default=128, type=int,
                        help="Batch size (default: %(default)s).")
    parser.add_argument("--max_epoch", default=100, type=int,
                        help="Maximum number of epochs (default: %(default)s).")
    parser.add_argument("--top_k", default=5, type=int,
                        help="Value of k for top-k selection (default: %(default)s).")

    # Output.
    parser.add_argument("--output_name", default="ELECT_test_output.json", type=str,
                        help="Output file name (default: %(default)s).")

    return parser


'''
Example usage:

python main_elect_inference.py \
    --data_file jicheng_questions.json \
    --output_name jicheng_questions_output.json
'''