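# Sanity-check script for the Flax ViT + GPT-2 captioning model
# (`FlaxViTGPT2LMForConditionalGeneration`): it builds the combined model from
# pretrained ViT and GPT-2 checkpoints, runs a forward pass and generation,
# round-trips save/load, and compares against a PyTorch ViT -> GPT-2 pipeline.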
import sys, os
current_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_path)
from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel
# Main model - ViTGPT2LM
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration
# ViT - as encoder
from transformers import ViTFeatureExtractor
from PIL import Image
import requests
import numpy as np
import jax
import jax.numpy as jnp
# GPT2+LM - as decoder
from transformers import GPT2Tokenizer
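# maximum target length, used both for padding the tokenized caption and for generation below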
max_length = 8
# ================================================================================
# Models preparation
vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'
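# `project_encoder` presumably toggles an extra projection of the ViT hidden states
# before they are fed to the GPT-2 cross-attention; it is left disabled here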
project_encoder = False
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name,
    project_encoder=project_encoder,
)
model = flax_vit_gpt2_lm
# ================================================================================
# Inputs preparation
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)
# encoder data
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
# batch dim is added automatically
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
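# for this checkpoint the feature extractor resizes to 224x224, so pixel_values should have shape (1, 3, 224, 224)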
print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')
# decoder data
sentence = 'mon chien est mignon'
# IMPORTANT: append the EOS token so that training/evaluation, the attention mask and the loss treat the sentence as a complete sequence
sentence += ' ' + tokenizer.eos_token
# batch dim is added automatically
# Setup the tokenizer for targets
with tokenizer.as_target_tokenizer():
    labels = tokenizer(sentence, max_length=max_length, padding="max_length", truncation=True, return_tensors="np")
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # `jax.ops.index_update` was removed from recent JAX releases; `.at[...].set()` is the current equivalent
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
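# Teacher forcing: the decoder inputs are the labels shifted one position to the right,
# with `decoder_start_token_id` prepended in front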
decoder_input_ids = shift_tokens_right(
    jnp.array(labels["input_ids"]),
    model.config.text_config.pad_token_id,
    model.config.decoder_start_token_id,
)
decoder_input_ids = np.asarray(decoder_input_ids)
# We need decoder_attention_mask so we can ignore pad tokens in the loss
decoder_attention_mask = labels["attention_mask"]
print('=' * 60)
print(f'decoder_inputs = {decoder_input_ids}')
print(f'decoder_input_ids.shape = {decoder_input_ids.shape}')
print(f'decoder_attention_mask = {decoder_attention_mask}')
print(f'decoder_attention_mask.shape = {decoder_attention_mask.shape}')
# ================================================================================
# Check that the modified `FlaxGPT2LMHeadModel` produces the same results as the original (when no `encoder_outputs` is provided).
orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
# Generation!
num_beams = 1
gen_kwargs = {"max_length": 6, "num_beams": num_beams}
orig_gpt2_generated = orig_gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
gpt2_generated = gpt2_lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
orig_token_ids = np.array(orig_gpt2_generated.sequences)[0]
token_ids = np.array(gpt2_generated.sequences)[0]
orig_caption = tokenizer.decode(orig_token_ids)
caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')
print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')
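# the modified GPT-2 must generate exactly the same token ids as the original when used stand-alone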
assert list(orig_token_ids) == list(token_ids)
# ================================================================================
# model data
model_inputs = {
    'pixel_values': pixel_values,
    'attention_mask': None,
    'decoder_input_ids': decoder_input_ids,
    'decoder_attention_mask': decoder_attention_mask,
    'decoder_position_ids': None,
}
# ================================================================================
# Check `model.__call__()`
# Model call
model_outputs = model(**model_inputs)
logits = model_outputs[0]
preds = np.argmax(logits, axis=-1)
print('=' * 60)
print('Flax: ViT-GPT2-LM')
print('predicted token ids:')
print(preds)
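# (Sketch, not part of the original script) With these logits and `decoder_attention_mask`,
# a masked cross-entropy loss could be computed roughly as follows; this assumes the
# separate `optax` package is installed and is only meant as an illustration:
#
# import optax
# per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
#     logits, jnp.array(labels["input_ids"])
# )
# mask = jnp.array(decoder_attention_mask)
# loss = (per_token_loss * mask).sum() / mask.sum()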
# encoder_last_hidden_state = model_outputs['encoder_last_hidden_state']
# print(encoder_last_hidden_state)
# encoder_kwargs = {}
# encoder_outputs = flax_vit_gpt2_lm.encode(pixel_values, return_dict=True, **encoder_kwargs)
# print(encoder_outputs['last_hidden_state'])
# ================================================================================
# Check generation
# Generation!
num_beams = 1
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
batch = {'pixel_values': pixel_values}
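# `generate` consumes the image features as the encoder input and autoregressively decodes a caption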
generated = model.generate(batch['pixel_values'], **gen_kwargs)
token_ids = np.array(generated.sequences)[0]
print('=' * 60)
print(f'generated token ids: {token_ids}')
caption = tokenizer.decode(token_ids)
print('=' * 60)
print(f'generated caption: {caption}')
# ================================================================================
# Check save & load
# save
os.makedirs('./model/', exist_ok=True)
model.save_pretrained(save_directory='./model/')
# load
_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./model/')
# check if the result is the same as before
_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]
print('=' * 60)
print(f'new generated token ids: {_token_ids}')
print(f'token_ids == new_token_ids: {token_ids == _token_ids}')
# ================================================================================
# Check PyTorch version's output - it should be the same as above
import torch
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel
vision_model_pt = ViTModel.from_pretrained(vision_model_name)
config = GPT2Config.from_pretrained(text_model_name)
# config.is_encoder_decoder = True
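# enable cross-attention layers in GPT-2 so the decoder can attend to the ViT hidden states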
config.add_cross_attention = True
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)
encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")
encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
encoder_hidden_states = encoder_pt_outputs.last_hidden_state
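# for ViT-base this is a (1, 197, 768) tensor: 196 patch embeddings plus the [CLS] token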
# model data
text_model_pt_inputs = {
    # token ids and masks are passed as int64 (long) tensors, the dtype PyTorch embeddings expect
    'input_ids': torch.tensor(decoder_input_ids, dtype=torch.long),
    'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.long),
    'position_ids': None,
    'encoder_hidden_states': encoder_hidden_states,
}
# Model call
text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)
logits = text_model_pt_outputs[0]
preds = np.argmax(logits.detach().numpy(), axis=-1)
print('=' * 60)
print('PyTorch: ViT --> GPT2-LM')
print('predicted token ids:')
print(preds)
#generated = text_model_pt.generate(encoder_outputs=vision_model_pt_outputs, **gen_kwargs)
#token_ids = np.array(generated.sequences)[0]
#print('=' * 60)
#print(f'Pytorch\'s GPT2 LM generated token ids: {token_ids}')
#caption = tokenizer.decode(token_ids)
#print('=' * 60)
#print(f'Pytorch\'s GPT2 LM generated caption: {caption}')