# vit-gpt2-coco-en / pipeline.py
# Uploaded by ydshieh ("upload pipeline.py", commit bfb3cee, 1.41 kB).
import os
from typing import Dict, List, Any
from PIL import Image
import jax
from transformers import ViTFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel
class PreTrainedPipeline:
    """Image-captioning pipeline around a Flax vision-encoder-decoder model.

    Construction loads the model, feature extractor and tokenizer from a
    directory, wraps generation in ``jax.jit`` and runs one warm-up caption
    on a sample image stored in that same directory.
    """

    def __init__(self, path: str = ""):
        """Load the model artifacts from *path* and warm up generation.

        Args:
            path: Directory passed to every ``from_pretrained`` call. It must
                also contain the sample image ``val_000000039769.jpg`` used
                for the warm-up call.
        """
        model_dir = path
        self.model = FlaxVisionEncoderDecoderModel.from_pretrained(model_dir)
        self.feature_extractor = ViTFeatureExtractor.from_pretrained(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Keyword arguments forwarded to ``model.generate`` on every call.
        max_length = 16
        num_beams = 4
        self.gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

        @jax.jit
        def _generate(pixel_values):
            # Only the generated token ids are needed downstream.
            return self.model.generate(pixel_values, **self.gen_kwargs).sequences

        self.generate = _generate

        # Compile the model: caption the sample image once so the jitted
        # function is compiled here rather than on the first real request.
        image_path = os.path.join(path, 'val_000000039769.jpg')
        # The context manager closes the file even if the warm-up call raises.
        with Image.open(image_path) as image:
            self(image)

    def __call__(self, inputs: "Image.Image") -> List[str]:
        """Generate caption text for an image.

        Args:
            inputs: The image to caption (a PIL image, per the annotation).
                It is forwarded unchanged as ``images=`` to the feature
                extractor.

        Return:
            The decoded captions, one string per generated sequence, with
            special tokens skipped and surrounding whitespace stripped.
        """
        pixel_values = self.feature_extractor(images=inputs, return_tensors="np").pixel_values
        output_ids = self.generate(pixel_values)
        preds = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [pred.strip() for pred in preds]