# Listing metadata captured with this file: size 1,061 bytes;
# revisions 8f6a7e9, 71a9ba7, 2e98537.
from typing import Any, Dict, List

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
# Keyword arguments unpacked into ``model.generate`` on every request:
# sampling-based decoding (top-k together with top-p), a length cap of 512,
# a single returned sequence, and an n-gram repetition block of size 2.
# NOTE(review): these were described as matching the model's default
# generation config — confirm against the checkpoint's generation_config.
CONFIG = dict(
    max_length=512,
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    top_k=50,
    top_p=0.95,
    do_sample=True,
)
class EndpointHandler:
    """Text-generation handler around a sequence-to-sequence model.

    The tokenizer and model are loaded once in ``__init__``; each call to the
    instance runs tokenize -> generate -> decode on the request payload.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model with ``from_pretrained(path)``.

        Args:
            path: Location of the model files, passed unchanged to both
                ``from_pretrained`` calls.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Generate text for one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the text to encode.
                An optional ``data["parameters"]`` mapping overrides entries
                of ``CONFIG`` for this request only. ``data`` itself is left
                unmodified.

        Returns:
            One ``{"generated_text": ...}`` dict per generated sequence (a
            single dict with the default ``CONFIG`` and a single string
            input). If ``inputs`` is missing or empty, returns
            ``[{"generated_text": "No input provided"}]``.
        """
        # .get() instead of .pop(): never mutate the caller's payload.
        inputs = data.get('inputs')
        # Covers None, "" and an empty list (which would crash the tokenizer).
        if not inputs:
            return [{'generated_text': 'No input provided'}]

        # Per-request overrides layered over the module defaults; building a
        # new dict keeps the shared CONFIG untouched between requests.
        generation_kwargs = {**CONFIG, **(data.get('parameters') or {})}

        # preprocess
        encoded = self.tokenizer(inputs, return_tensors="pt")

        # inference: hand generate() the attention mask explicitly instead of
        # leaving it to infer one from input_ids.
        output_ids = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded.get("attention_mask"),
            **generation_kwargs,
        )

        # postprocess: decode every returned sequence, not just the first, so a
        # num_return_sequences override is honoured.
        texts = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [{'generated_text': text} for text in texts]