import os
from typing import Any, Dict, List

from allennlp.common.params import Params

from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent

# Name of the checkpoint folder expected inside the pipeline's ``path``.
CHECKPOINT = "coref-spanbert-large-2021.03.10"


class PreTrainedPipeline:
    """Coreference-resolution pipeline wrapping a ``CorefPredictor``.

    The predictor is built from an already-extracted archive directory
    (weights + config) and runs on the CPU.
    """

    def __init__(self, path: str = ""):
        """Load the predictor from ``<path>/<CHECKPOINT>``.

        Args:
            path: Directory that contains the ``CHECKPOINT`` folder. The
                default ``""`` makes the checkpoint path relative, i.e. it is
                resolved against the current working directory.

        The checkpoint folder must provide ``weights.th`` (model weights) and
        ``config.json`` (parsed with ``Params.from_file``).
        """
        checkpoint_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=checkpoint_dir,
            weight_path=os.path.join(checkpoint_dir, "weights.th"),
            config=Params.from_file(os.path.join(checkpoint_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)
        # Inference is pinned to the CPU regardless of available hardware.
        self.predictor.set_device("cpu")

    def __call__(self, data: str) -> Dict[str, Any]:
        """Resolve coreference across the paragraphs of ``data``.

        Args:
            data: Input text. It is split on every occurrence of two
                consecutive newline characters, and each resulting piece is
                passed to the predictor as a separate paragraph. Runs of more
                than two newlines, or a leading/trailing separator, yield empty
                or newline-prefixed pieces; these are passed through unchanged.

        Returns:
            The dictionary produced by the prediction's ``to_dict()``. It is
            intended to be serialized by the caller.
        """
        paragraphs = data.split("\n\n")
        prediction = self.predictor.resolve_paragraphs(paragraphs)
        return prediction.to_dict()