File size: 938 Bytes
d4459cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from typing import Any, Dict
from transformers import Pipeline, AutoModel, AutoTokenizer
from transformers.pipelines.base import GenericTensor, ModelOutput


class HiveTokenClassification(Pipeline):
    """Pipeline that delegates token classification to the model's ``predict`` method.

    The raw input is handed to ``self.model.predict`` unchanged, together with
    the pipeline's tokenizer. No tokenization happens in ``preprocess``. An
    optional call-time ``output_style`` argument is forwarded to ``predict``.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) dicts.

        Only ``output_style`` is recognized; it is routed to ``_forward``.
        Every other keyword argument is ignored.
        """
        forward_parameters = {}
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Any) -> Dict[str, GenericTensor]:
        """Return the input unchanged; encoding is left to ``model.predict``."""
        return input_

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Any) -> ModelOutput:
        """Run ``self.model.predict`` on the (raw) input with the pipeline's tokenizer.

        ``output_style`` is passed through only when the caller supplied it.
        ``_sanitize_parameters`` omits the key otherwise, and indexing it
        unconditionally raised ``KeyError``. When absent, ``predict``'s own
        default (if it defines one) applies.
        """
        predict_kwargs = {}
        if "output_style" in forward_parameters:
            predict_kwargs["output_style"] = forward_parameters["output_style"]
        return self.model.predict(input_tensors, self.tokenizer, **predict_kwargs)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Any) -> Any:
        """Wrap the model output together with its length.

        Returns a dict with ``"output"`` (the raw ``predict`` result) and
        ``"length"`` (``len`` of that result). This presumes ``predict`` returns
        a sized object; confirm against the model implementation.
        """
        return {"output": model_outputs, "length": len(model_outputs)}