import gc
import logging

import torch
from transformers import GPT2LMHeadModel

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
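
# Expected request body (illustrative; these field names are simply how this
# handler reads the payload below, not an official TorchServe schema):
#   {
#       "text": [15496, 995],  # prompt as a list of GPT-2 token ids
#       "num_samples": 1,      # number of sampled sequences to return
#       "length": 50           # number of new tokens to generate
#   }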


class SampleTransformerModel(BaseHandler):
    def __init__(self):
        super().__init__()
        self.model = None
        self.device = None
        self.initialized = False

    def load_model(self, model_dir):
        # Load the GPT-2 weights packaged into model_dir and move them to
        # the device chosen in initialize().
        self.model = GPT2LMHeadModel.from_pretrained(model_dir, return_dict=True)
        self.model.to(self.device)

    def initialize(self, ctx):
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        # TorchServe assigns each worker a gpu_id; fall back to CPU when no
        # GPU is visible.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else "cpu"
        )

        self.load_model(model_dir)
        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
        # NOTE: this handler assumes a TorchServe batch size of 1. If several
        # requests arrived in one batch, each iteration would overwrite the
        # previous entries and only the last request would be served.
        input_batch = {}
        for data in requests:
            body = data.get("body")
            # "text" is expected to be a prompt already encoded into GPT-2
            # token ids on the client side.
            input_ids = torch.tensor([body.get("text")]).to(self.device)
            input_batch["input_ids"] = input_ids
            input_batch["num_samples"] = body.get("num_samples")
            # generate() counts the prompt towards max_length, so add the
            # prompt length to the requested number of new tokens.
            input_batch["length"] = body.get("length") + len(body.get("text"))
        del requests
        gc.collect()
        return input_batch

    def inference(self, input_batch):
        input_ids = input_batch["input_ids"]
        length = input_batch["length"]

        # Top-k / nucleus sampling; no_repeat_ngram_size=2 prevents any
        # 2-gram from being generated twice.
        inference_output = self.model.generate(
            input_ids,
            bos_token_id=self.model.config.bos_token_id,
            eos_token_id=self.model.config.eos_token_id,
            pad_token_id=self.model.config.eos_token_id,
            do_sample=True,
            max_length=length,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            num_return_sequences=input_batch["num_samples"],
        )

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        del input_batch
        gc.collect()
        return inference_output

    def postprocess(self, inference_output):
        # Return raw token ids; decoding back to text is left to the client.
        # The extra list level gives TorchServe one response per request
        # (batch size 1, matching the assumption in preprocess).
        output = inference_output.cpu().numpy().tolist()
        del inference_output
        gc.collect()
        return [output]
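
    # Client-side decoding sketch (hypothetical, not part of the handler;
    # assumes the JSON response body is the list of generated sequences):
    #   from transformers import GPT2Tokenizer
    #   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    #   texts = [tokenizer.decode(ids) for ids in response]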

    def handle(self, data, context):
        # Standard preprocess -> inference -> postprocess pipeline.
        data = self.preprocess(data)
        data = self.inference(data)
        data = self.postprocess(data)
        return data
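

# Usage sketch (assumes this file is saved as sample_transformer_handler.py
# and the GPT-2 weights live in ./model; all names and paths are illustrative):
#
#   torch-model-archiver --model-name gpt2 --version 1.0 \
#       --serialized-file ./model/pytorch_model.bin \
#       --extra-files ./model/config.json \
#       --handler sample_transformer_handler.py
#   torchserve --start --model-store model_store --models gpt2=gpt2.mar
#
#   curl -X POST http://localhost:8080/predictions/gpt2 \
#       -H "Content-Type: application/json" \
#       -d '{"text": [15496, 995], "num_samples": 1, "length": 50}'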