File size: 938 Bytes
d4459cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from typing import Any, Dict
from transformers import Pipeline, AutoModel, AutoTokenizer
from transformers.pipelines.base import GenericTensor, ModelOutput
class HiveTokenClassification(Pipeline):
    """Pipeline that delegates the whole prediction to ``self.model.predict``.

    The raw pipeline input is handed, together with ``self.tokenizer``, to the
    model's own ``predict`` method; the pipeline itself performs no
    tokenization. The single supported call-time option is ``output_style``,
    which is forwarded to ``predict`` when the caller provides it.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time keyword arguments between the pipeline stages.

        Only ``output_style`` is recognised; every other keyword is ignored.

        Returns:
            A 3-tuple of dicts ``(preprocess, forward, postprocess)``. Only
            the forward dict can be non-empty, carrying ``output_style`` when
            it was supplied.
        """
        forward_parameters = {}
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Any) -> Dict[str, GenericTensor]:
        """Return ``input_`` unchanged.

        No work happens here: the untouched input is what ``_forward`` later
        passes to ``self.model.predict`` along with the tokenizer.
        """
        return input_

    def _forward(self, input_tensors: Dict[str, GenericTensor], **forward_parameters: Any) -> ModelOutput:
        """Run ``self.model.predict`` on the preprocessed input.

        Args:
            input_tensors: Whatever ``preprocess`` returned.
            **forward_parameters: May contain ``output_style``; it is passed
                on to ``predict`` only when present.

        Returns:
            The value returned by ``self.model.predict``.
        """
        # ``output_style`` is optional at call time. Indexing it directly
        # raised KeyError whenever the pipeline was invoked without it, so the
        # keyword is forwarded only when supplied and ``predict`` otherwise
        # falls back to its own default.
        predict_kwargs = {}
        if "output_style" in forward_parameters:
            predict_kwargs["output_style"] = forward_parameters["output_style"]
        return self.model.predict(input_tensors, self.tokenizer, **predict_kwargs)

    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Any) -> Any:
        """Wrap the model output in a dict alongside its length.

        Returns:
            ``{"output": model_outputs, "length": len(model_outputs)}``.
        """
        return {"output": model_outputs, "length": len(model_outputs)}
|