import os
from typing import Dict, List, Any

from allennlp.common.params import Params

from long_coref.coref.prediction import CorefPredictor
from long_coref.coref.utils import ArchiveContent


# Name of the model checkpoint directory. PreTrainedPipeline joins it onto its
# base path to locate the archive directory, "weights.th" and "config.json".
CHECKPOINT = "coref-spanbert-large-2021.03.10"


class PreTrainedPipeline:
    """Callable wrapper around a ``CorefPredictor`` loaded from a local checkpoint.

    Construction loads the predictor from the ``CHECKPOINT`` directory found
    under ``path`` and places it on the CPU. Calling the instance runs the
    predictor over the paragraphs of a text and returns the result as a dict.
    """

    def __init__(self, path: str = ""):
        """Load the predictor from ``<path>/<CHECKPOINT>``.

        Args:
            path: Base directory containing the checkpoint directory. With the
                default empty string the checkpoint path is just ``CHECKPOINT``
                (i.e. relative to the current working directory).
        """
        # Build the checkpoint directory once; weights and config live inside it.
        checkpoint_dir = os.path.join(path, CHECKPOINT)
        archive_content = ArchiveContent(
            archive_dir=checkpoint_dir,
            weight_path=os.path.join(checkpoint_dir, "weights.th"),
            config=Params.from_file(os.path.join(checkpoint_dir, "config.json")),
        )
        self.predictor = CorefPredictor.from_extracted_archive(archive_content)
        self.predictor.set_device("cpu")

    def __call__(self, data: str) -> Dict[str, Any]:
        """Run the predictor over the paragraphs of ``data``.

        Args:
            data: Input text. Paragraphs are separated by a blank line, i.e.
                two consecutive newline characters.

        Returns:
            The value of ``to_dict()`` called on the object returned by
            ``self.predictor.resolve_paragraphs``.
        """
        # No filtering is applied: an empty string yields a single empty
        # paragraph, and a trailing blank line yields a trailing empty one.
        # NOTE(review): how resolve_paragraphs treats empty paragraphs is not
        # visible here - confirm before relying on it.
        paragraphs = data.split("\n\n")
        prediction = self.predictor.resolve_paragraphs(paragraphs)
        return prediction.to_dict()

|