Mirror / src /inference.py
Spico's picture
update
5953ef9
raw
history blame contribute delete
No virus
503 Bytes
import os
from rex.utils.logging import logger
from src.task import MrcTaggingTask
if __name__ == "__main__":
    # Command-line entry point: load a trained MrcTaggingTask from its task
    # directory and log its predictions for a list of input texts.
    # Run with no arguments to reproduce the original hard-coded behaviour.
    import argparse

    # Hide every GPU so inference runs on CPU.
    # NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if CUDA has not been
    # initialised yet (e.g. by the module-level imports) — confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    parser = argparse.ArgumentParser(
        description="Run MrcTaggingTask.predict on the given input texts (CPU only)."
    )
    parser.add_argument(
        "--task-dir",
        default="outputs/bert_mrc_ner",
        help="task directory passed to MrcTaggingTask.from_taskdir "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "cases",
        nargs="*",
        default=["123123", "123123"],
        help="input texts to predict on (default: two placeholder strings)",
    )
    args = parser.parse_args()

    task = MrcTaggingTask.from_taskdir(
        args.task_dir,
        load_best_model=True,
        # Inference only: these keys override the saved task config.
        update_config={
            "skip_train": True,
            "debug_mode": False,
        },
    )
    cases = args.cases
    logger.info(f"Cases: {cases}")
    ents = task.predict(cases)
    logger.info(f"Results: {ents}")