Spaces:
Running
Running
import argparse | |
import torch | |
from data import load_tokenizer | |
from model import load_model | |
from demo import run_gradio | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--data_dir', default='/data/mohamed/data') | |
parser.add_argument('--aim_repo', default='/data/mohamed/') | |
parser.add_argument('--ckpt', default='electra-base.pt') | |
parser.add_argument('--aim_exp', default='mimic-decisions-1215') | |
parser.add_argument('--label_encoding', default='multiclass') | |
parser.add_argument('--multiclass', action='store_true') | |
parser.add_argument('--debug', action='store_true') | |
parser.add_argument('--save_losses', action='store_true') | |
parser.add_argument('--task', default='token', choices=['seq', 'token']) | |
parser.add_argument('--max_len', type=int, default=512) | |
parser.add_argument('--num_layers', type=int, default=3) | |
parser.add_argument('--kernels', nargs=3, type=int, default=[1,2,3]) | |
parser.add_argument('--model', default='roberta-base',) | |
parser.add_argument('--model_name', default='google/electra-base-discriminator',) | |
parser.add_argument('--gpu', default='0') | |
parser.add_argument('--grad_accumulation', default=2, type=int) | |
parser.add_argument('--pheno_id', type=int) | |
parser.add_argument('--unseen_pheno', type=int) | |
parser.add_argument('--text_subset') | |
parser.add_argument('--pheno_n', type=int, default=500) | |
parser.add_argument('--hidden_size', type=int, default=100) | |
parser.add_argument('--emb_size', type=int, default=400) | |
parser.add_argument('--total_steps', type=int, default=5000) | |
parser.add_argument('--train_log', type=int, default=500) | |
parser.add_argument('--val_log', type=int, default=1000) | |
parser.add_argument('--seed', default = '0') | |
parser.add_argument('--num_phenos', type=int, default=10) | |
parser.add_argument('--num_decs', type=int, default=9) | |
parser.add_argument('--num_umls_tags', type=int, default=33) | |
parser.add_argument('--batch_size', type=int, default=8) | |
parser.add_argument('--pos_weight', type=float, default=1.25) | |
parser.add_argument('--alpha_distil', type=float, default=1) | |
parser.add_argument('--distil', action='store_true') | |
parser.add_argument('--distil_att', action='store_true') | |
parser.add_argument('--distil_ckpt') | |
parser.add_argument('--use_umls', action='store_true') | |
parser.add_argument('--include_nolabel', action='store_true') | |
parser.add_argument('--truncate_train', action='store_true') | |
parser.add_argument('--truncate_eval', action='store_true') | |
parser.add_argument('--load_ckpt', action='store_true') | |
parser.add_argument('--gradio', action='store_true') | |
parser.add_argument('--optuna', action='store_true') | |
parser.add_argument('--mimic_data', action='store_true') | |
parser.add_argument('--eval_only', action='store_true') | |
parser.add_argument('--lr', type=float, default=4e-5) | |
parser.add_argument('--resample', default='') | |
parser.add_argument('--verbose', type=bool, default=True) | |
parser.add_argument('--use_crf', type=bool) | |
parser.add_argument('--print_spans', action='store_true') | |
args = parser.parse_args() | |
if args.task == 'seq' and args.pheno_id is not None: | |
args.num_labels = 1 | |
elif args.task == 'seq': | |
args.num_labels = args.num_phenos | |
elif args.task == 'token': | |
if args.use_umls: | |
args.num_labels = args.num_umls_tags | |
else: | |
args.num_labels = args.num_decs | |
if args.label_encoding == 'multiclass': | |
args.num_labels = args.num_labels * 2 + 1 | |
elif args.label_encoding == 'bo': | |
args.num_labels *= 2 | |
elif args.label_encoding == 'boe': | |
args.num_labels *= 3 | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
tokenizer = load_tokenizer(args.model_name) | |
model = load_model(args, device)[0] | |
model.eval() | |
torch.set_grad_enabled(False) | |
run_gradio(model, tokenizer) | |