from typing import Any, Dict

from transformers import Pipeline, AutoModel, AutoTokenizer
from transformers.pipelines.base import GenericTensor, ModelOutput


class HiveTokenClassification(Pipeline):
    """Token-classification pipeline that delegates inference to ``model.predict``.

    The raw input is passed through ``preprocess`` unchanged and handed to
    ``self.model.predict`` together with ``self.tokenizer``. No tensors are
    built in this class.

    NOTE(review): this presumes ``self.model`` is a custom model class that
    defines ``predict(inputs, tokenizer, output_style=...)``. Stock
    ``transformers`` models have no ``predict`` method, so confirm against
    the model implementation.
    """

    def _sanitize_parameters(self, **kwargs: Any):
        """Split pipeline keyword arguments into per-stage parameter dicts.

        Returns:
            A ``(preprocess, forward, postprocess)`` tuple of dicts. Only
            ``output_style`` is recognised, and it is routed to ``_forward``.
            Other keywords are not forwarded by this method.
        """
        forward_parameters = {}
        if "output_style" in kwargs:
            forward_parameters["output_style"] = kwargs["output_style"]
        return {}, forward_parameters, {}

    def preprocess(self, input_: Any, **preprocess_parameters: Any) -> Any:
        """Return ``input_`` unchanged.

        Tokenization is left to ``model.predict``, which receives the tokenizer.
        """
        return input_

    def _forward(self, input_tensors: Any, **forward_parameters: Any) -> ModelOutput:
        """Run ``self.model.predict`` on the (unprocessed) input.

        Args:
            input_tensors: Whatever ``preprocess`` returned, i.e. the raw
                pipeline input.
            **forward_parameters: Stage parameters from
                ``_sanitize_parameters``. Only ``output_style`` is recognised.

        Returns:
            The value returned by ``self.model.predict``.
        """
        # Forward ``output_style`` only when the caller actually supplied it.
        # Indexing it unconditionally raised KeyError whenever the option was
        # omitted. Leaving it out lets ``predict`` apply its own default.
        predict_kwargs = {}
        if "output_style" in forward_parameters:
            predict_kwargs["output_style"] = forward_parameters["output_style"]
        return self.model.predict(input_tensors, self.tokenizer, **predict_kwargs)

    def postprocess(
        self, model_outputs: ModelOutput, **postprocess_parameters: Any
    ) -> Dict[str, Any]:
        """Wrap the prediction as ``{"output": model_outputs, "length": len(model_outputs)}``.

        NOTE(review): assumes ``model_outputs`` supports ``len()`` (presumably
        a sequence of per-token predictions). Confirm against ``model.predict``.
        """
        return {"output": model_outputs, "length": len(model_outputs)}