import sys, os |
current_path = os.path.dirname(os.path.abspath(__file__)) |
sys.path.append(current_path) |
from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel |
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel |
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration |
from transformers import ViTFeatureExtractor |
from PIL import Image |
import requests |
import numpy as np |
import jax |
import jax.numpy as jnp |
from transformers import GPT2Tokenizer |
# Run configuration: the ViT encoder checkpoint, the GPT-2 decoder checkpoint,
# and the maximum target length used when tokenizing and generating.
max_length = 8
vision_model_name, text_model_name = (
    'google/vit-base-patch16-224-in21k',
    'asi/gpt-fr-cased-small',
)
project_encoder = False

# Build the combined Flax ViT -> GPT-2 model from the two pretrained checkpoints;
# `project_encoder` is forwarded unchanged to the constructor.
flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name,
    project_encoder=project_encoder,
)
model = flax_vit_gpt2_lm

# Preprocessors matching each checkpoint: image features for the encoder,
# BPE tokens for the decoder.
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)
# Download a sample COCO val2017 image and convert it to ViT pixel values.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Without a timeout, requests can block indefinitely on a stalled connection.
# raise_for_status() reports HTTP errors instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values
print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')
# Target sentence for the teacher-forced forward pass, with the tokenizer's
# EOS token appended.
sentence = 'mon chien est mignon'
sentence += ' ' + tokenizer.eos_token
# GPT-2 uses one vocabulary for inputs and targets, so a plain tokenizer call is
# enough. `as_target_tokenizer()` is deprecated in recent transformers releases
# and is only a mode switch that GPT2Tokenizer does not specialize.
labels = tokenizer(
    sentence,
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_tensors="np",
)
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.

    Along the last axis, every token moves one position to the right, the
    final token is dropped, and position 0 is filled with
    ``decoder_start_token_id``. Any ``-100`` value in the result (presumably
    the ignore-index used for masked labels) is replaced by ``pad_token_id``.

    Args:
        input_ids: integer token ids; the shift is applied along the last axis.
        pad_token_id: id substituted for ``-100`` entries.
        decoder_start_token_id: id written into the first position.

    Returns:
        A new JAX array with the same shape as ``input_ids``.
    """
    # jnp.roll wraps the last token around to position 0; that slot is
    # overwritten right below, so the net effect is a right shift.
    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # `jax.ops.index_update` was removed from JAX; the functional
    # `.at[...].set(...)` update is its documented replacement.
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)
    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids
# Decoder inputs are the label ids shifted one step right, with the decoder
# start token in the first slot; the result is pulled back into numpy.
decoder_input_ids = np.asarray(
    shift_tokens_right(
        jnp.array(labels["input_ids"]),
        model.config.text_config.pad_token_id,
        model.config.decoder_start_token_id,
    )
)
decoder_attention_mask = labels["attention_mask"]

# One print call per separator block; sep='\n' reproduces the line-per-print output.
print('=' * 60)
print(
    f'decoder_inputs = {decoder_input_ids}',
    f'decoder_input_ids.shape = {decoder_input_ids.shape}',
    f'decoder_attention_mask = {decoder_attention_mask}',
    f'decoder_attention_mask.shape = {decoder_attention_mask.shape}',
    sep='\n',
)
# Sanity check: the project's FlaxGPT2LMHeadModel must generate exactly the same
# tokens as the stock transformers implementation (num_beams=1) when both are
# prompted with the first three decoder input ids.
orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
num_beams = 1
gen_kwargs = dict(max_length=6, num_beams=num_beams)

# Generators unpack left to right, so the stock model still runs first.
orig_gpt2_generated, gpt2_generated = (
    lm.generate(decoder_input_ids[:, 0:3], **gen_kwargs)
    for lm in (orig_gpt2_lm, gpt2_lm)
)
orig_token_ids, token_ids = (
    np.array(result.sequences)[0]
    for result in (orig_gpt2_generated, gpt2_generated)
)
orig_caption, caption = map(tokenizer.decode, (orig_token_ids, token_ids))

print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')
print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')
assert list(orig_token_ids) == list(token_ids)
# Teacher-forced forward pass of the combined Flax model: the full decoder input
# sequence is fed at once, and the argmax over the last axis of the first output
# gives one predicted token id per decoder position.
model_inputs = dict(
    pixel_values=pixel_values,
    attention_mask=None,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    decoder_position_ids=None,
)
model_outputs = model(**model_inputs)
logits = model_outputs[0]
preds = np.argmax(logits, axis=-1)

print('=' * 60, 'Flax: Vit-GPT2-LM', 'predicted token ids:', sep='\n')
print(preds)
# Generate a caption from the image alone (no decoder prompt), then decode
# the first sequence of the batch back to text.
num_beams = 1
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)
batch = dict(pixel_values=pixel_values)
generated = model.generate(batch['pixel_values'], **gen_kwargs)
token_ids = np.array(generated.sequences)[0]
print('=' * 60, f'generated token ids: {token_ids}', sep='\n')

caption = tokenizer.decode(token_ids)
print('=' * 60, f'generated caption: {caption}', sep='\n')
# Round-trip check: save the model, reload it from disk, and verify that
# generating from the same image with the same settings yields the same ids.
model_dir = './model/'
os.makedirs(model_dir, exist_ok=True)
model.save_pretrained(save_directory=model_dir)
_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained(model_dir)
_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]
print('=' * 60)
print(f'new generated token ids: {_token_ids}')
# `==` on numpy arrays compares element-wise (and does not give a proper boolean
# when the lengths differ), so use array_equal for a single True/False answer.
print(f'token_ids == new_token_ids: {np.array_equal(token_ids, _token_ids)}')
# PyTorch reference run: a ViT encoder plus a GPT-2 LM head with cross-attention
# enabled are fed the same image and decoder inputs, and the argmax predictions
# are printed for comparison with the Flax output above.
import torch
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel

vision_model_pt = ViTModel.from_pretrained(vision_model_name)
config = GPT2Config.from_pretrained(text_model_name)
# Cross-attention must be switched on in the config before loading so the decoder
# accepts `encoder_hidden_states`.
config.add_cross_attention = True
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)

# Inference only: no_grad avoids recording an autograd graph that is never used.
with torch.no_grad():
    encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")
    encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
    encoder_hidden_states = encoder_pt_outputs.last_hidden_state
    text_model_pt_inputs = {
        # Token ids index an embedding table; int64 is accepted by nn.Embedding
        # in every torch version (older versions reject int32 indices).
        'input_ids': torch.tensor(decoder_input_ids, dtype=torch.long),
        'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.int32),
        'position_ids': None,
        'encoder_hidden_states': encoder_hidden_states,
    }
    text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)

logits = text_model_pt_outputs[0]
preds = np.argmax(logits.detach().numpy(), axis=-1)
print('=' * 60)
print('PyTorch: Vit --> GPT2-LM')
print('predicted token ids:')
print(preds)