image-captioning / model.py
ydshieh
try import model related packages
686f21e
raw
history blame
447 Bytes
import os, sys
import numpy as np
from PIL import Image
import jax
from transformers import ViTFeatureExtractor
from transformers import GPT2Tokenizer
# Absolute path of the directory containing this file.
current_path = os.path.dirname(os.path.abspath(__file__))
# Make modules/packages that live next to this file importable as top-level
# names regardless of the process's working directory. This must run before
# any such import (presumably the `vit_gpt2` package referenced below).
sys.path.append(current_path)
# Main model - ViTGPT2LM
# TODO: the model import is disabled for now; `predict` below returns a
# placeholder result until the model is wired in.
# from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
def predict(image):
    """Return a placeholder caption for ``image``.

    This is a stub: the ViTGPT2 captioning model import is currently
    commented out in this module, so ``image`` is ignored and the same
    fixed result is returned on every call.

    Args:
        image: The image to caption. Presumably a PIL ``Image`` (this module
            imports ``PIL.Image``) -- TODO confirm with callers. Unused for now.

    Returns:
        A 3-tuple ``(caption, tokens, values)``:
        ``caption`` is the caption string, ``tokens`` is the caption split
        into a list of token strings, and ``values`` is a placeholder list
        of integers (its intended meaning is not determined by this stub --
        TODO confirm). New list objects are built on each call, so callers
        may mutate them safely.
    """
    # Placeholder output until the real model replaces this stub.
    return 'dummy caption!', ['dummy', 'caption', '!'], [1, 2, 3]