"""Generate image captions with a fine-tuned Flax ViT-GPT2 checkpoint.

When run as a script, this captions one hard-coded COCO validation image,
then captions the first 100 examples of the validation split and writes the
examples (with an added ``pred`` field) to ``ckpt_<ckpt_no>_preds.json``.
"""
import json
import os
import sys
import time

import datasets
# jax
import jax
import numpy as np
import requests  # noqa: F401  (unused in this file; kept from the original)
from PIL import Image
# Vit - as encoder / GPT2 - as decoder
from transformers import GPT2Tokenizer, ViTFeatureExtractor

# Make the directory containing this file importable so that the local
# ``vit_gpt2`` package below can be resolved.
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)

# Main model - ViTGPT2LM (must come after the sys.path change above)
from vit_gpt2.modeling_flax_vit_gpt2_lm import (  # noqa: E402
    FlaxViTGPT2LMForConditionalGeneration,
)

ckpt_no = 5
model_name_or_path = f'./outputs/ckpt_{ckpt_no}/'
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_name_or_path)

vit_model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(vit_model_name)

gpt2_model_name = 'asi/gpt-fr-cased-small'
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

max_length = 32
num_beams = 8
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}


@jax.jit
def predict_fn(pixel_values):
    """Run generation on ``pixel_values`` with the module-level ``gen_kwargs``."""
    return flax_vit_gpt2_lm.generate(pixel_values, **gen_kwargs)


def predict(image, skip_special_tokens=False):
    """Caption a single image.

    Args:
        image: Image passed to the feature extractor as ``images=``.
        skip_special_tokens: Forwarded to ``tokenizer.decode``. The default
            (``False``) keeps the original behaviour of returning the raw
            decoded text.

    Returns:
        A ``(caption, token_ids)`` tuple: the decoded text and the NumPy array
        of generated token ids for the first returned sequence.
    """
    # batch dim is added automatically
    encoder_inputs = feature_extractor(images=image, return_tensors="jax")
    pixel_values = encoder_inputs.pixel_values

    # generation
    generation = predict_fn(pixel_values)

    token_ids = np.array(generation.sequences)[0]
    caption = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    return caption, token_ids


def main():
    """Caption one sample image, then the first 100 validation examples."""
    split = 'val'
    image_id = 322141

    p = f'/home/33611/caption/{split}2014/COCO_{split}2014_{image_id:012d}.jpg'
    # The context manager closes the file even if prediction raises.
    with Image.open(p) as image:
        caption, token_ids = predict(image)
    print(f'token_ids: {token_ids}')
    print(f'caption: {caption}')

    ds = datasets.load_dataset('./coco_dataset_script.py', data_dir='/home/33611/caption/')
    ds = ds['validation']
    ds = ds.select(range(100))

    predictions = []
    for ex in ds:
        with Image.open(ex['image_file']) as image:
            # Monotonic clock: unaffected by system clock adjustments.
            start = time.perf_counter()
            # NOTE(review): the original post-processed the caption with three
            # chained ``.replace('', '')`` calls, which replace the empty
            # string with itself and therefore changed nothing. Presumably
            # they were meant to strip special tokens; that is done here via
            # ``skip_special_tokens=True`` -- confirm this matches the intent.
            caption, _ = predict(image, skip_special_tokens=True)
            caption = caption.strip()
        elapsed = time.perf_counter() - start

        print(f' timing: {elapsed}')
        print(f' caption: {ex["fr"]}')
        print(f'prediction: {caption}')
        print('-' * 20)

        ex['pred'] = caption
        predictions.append(ex)

    with open(f'ckpt_{ckpt_no}_preds.json', 'w', encoding='UTF-8') as fp:
        json.dump(predictions, fp, ensure_ascii=False, indent=4)


if __name__ == '__main__':
    main()