|
import sys, os |
|
|
|
# Directory containing this script. It is appended to the module search path so
# that modules living next to the script (e.g. the local `vit_gpt2` package
# imported below) can be imported.
# NOTE(review): append() puts it *last* on sys.path, so an installed package
# with the same name would take precedence — confirm that is intended.
current_path = os.path.split(os.path.abspath(__file__))[0]
sys.path.append(current_path)
|
|
|
from transformers import FlaxGPT2LMHeadModel as Orig_FlaxGPT2LMHeadModel |
|
from vit_gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel |
|
|
|
|
|
from vit_gpt2.modeling_flax_vit_gpt2_lm import FlaxViTGPT2LMForConditionalGeneration |
|
|
|
|
|
from transformers import ViTFeatureExtractor |
|
from PIL import Image |
|
import requests |
|
import numpy as np |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
from transformers import GPT2Tokenizer |
|
|
|
# Maximum token length, used both for padding/truncating the labels and as the
# generation length further down.
max_length = 8

# Pretrained checkpoints: ViT image encoder and (French) GPT-2 text decoder.
vision_model_name = 'google/vit-base-patch16-224-in21k'
text_model_name = 'asi/gpt-fr-cased-small'

# Forwarded unchanged to `from_vision_text_pretrained`.
project_encoder = False

# Combined encoder-decoder model; `model` is the short alias used below.
model = flax_vit_gpt2_lm = FlaxViTGPT2LMForConditionalGeneration.from_vision_text_pretrained(
    vision_pretrained_model_name_or_path=vision_model_name,
    text_pretrained_model_name_or_path=text_model_name,
    project_encoder=project_encoder,
)
|
|
|
|
|
|
|
|
|
# Preprocessors matching the vision-encoder and text-decoder checkpoints.
feature_extractor = ViTFeatureExtractor.from_pretrained(vision_model_name)
tokenizer = GPT2Tokenizer.from_pretrained(text_model_name)

# Sample COCO validation image used as the encoder input.
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# A bounded timeout makes a stalled download fail instead of hanging forever.
# The status check keeps an HTTP error page from being handed to PIL. The
# `with` closes the connection afterwards.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw)
    # PIL decodes lazily; force the full read while the stream is still open.
    image.load()

encoder_inputs = feature_extractor(images=image, return_tensors="jax")
pixel_values = encoder_inputs.pixel_values

print('=' * 60)
print(f'pixel_values.shape = {pixel_values.shape}')
|
|
|
|
|
# Target caption for the teacher-forced check. The EOS token is appended
# explicitly, separated by a single space.
sentence = ' '.join(('mon chien est mignon', tokenizer.eos_token))

# Tokenize the target, padded/truncated to exactly `max_length` tokens, as NumPy arrays.
# NOTE(review): for a plain GPT-2 tokenizer `as_target_tokenizer` is presumably
# a no-op (and deprecated in newer transformers) — confirm before relying on it.
with tokenizer.as_target_tokenizer():
    labels = tokenizer(
        sentence,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
|
|
|
|
|
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right to build decoder inputs.

    Along the last axis, every token moves one position right, and the final
    token is dropped. It wraps to the front via ``jnp.roll`` and is then
    overwritten by ``decoder_start_token_id``. Afterwards, any ``-100`` entries
    are replaced by ``pad_token_id``. ``-100`` is presumably the ignore-index
    value used for label padding.

    Args:
        input_ids: Integer token ids; the shift is applied along the last axis.
        pad_token_id: Id substituted for ``-100`` entries after shifting.
        decoder_start_token_id: Id written at position 0 of the last axis.

    Returns:
        A new array with the same shape as ``input_ids``.

    Raises:
        ValueError: If ``decoder_start_token_id`` or ``pad_token_id`` is None.
    """
    if decoder_start_token_id is None:
        raise ValueError("decoder_start_token_id must be set to shift tokens right.")
    if pad_token_id is None:
        raise ValueError("pad_token_id must be set to replace -100 entries.")

    shifted_input_ids = jnp.roll(input_ids, 1, axis=-1)
    # `jax.ops.index_update` has been removed from JAX; the `.at[...].set`
    # indexed-update syntax is its out-of-place replacement.
    shifted_input_ids = shifted_input_ids.at[..., 0].set(decoder_start_token_id)

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)

    return shifted_input_ids
|
|
|
# Teacher-forcing decoder inputs: the label ids shifted one position right,
# converted back to a host-side NumPy array.
decoder_input_ids = np.asarray(
    shift_tokens_right(
        jnp.array(labels["input_ids"]),
        model.config.text_config.pad_token_id,
        model.config.decoder_start_token_id,
    )
)

# The labels' attention mask is reused unchanged for the decoder.
decoder_attention_mask = labels["attention_mask"]

print('=' * 60)
print(
    f'decoder_inputs = {decoder_input_ids}',
    f'decoder_input_ids.shape = {decoder_input_ids.shape}',
    f'decoder_attention_mask = {decoder_attention_mask}',
    f'decoder_attention_mask.shape = {decoder_attention_mask.shape}',
    sep='\n',
)
|
|
|
|
|
|
|
|
|
# Sanity check: the FlaxGPT2LMHeadModel from `vit_gpt2.modeling_flax_gpt2` must
# generate exactly the same tokens as the stock transformers implementation
# when both are loaded from the same checkpoint.
orig_gpt2_lm = Orig_FlaxGPT2LMHeadModel.from_pretrained(text_model_name)
gpt2_lm = FlaxGPT2LMHeadModel.from_pretrained(text_model_name)

num_beams = 1
gen_kwargs = dict(max_length=6, num_beams=num_beams)

# Both models are prompted with the first three decoder input ids.
prompt_ids = decoder_input_ids[:, :3]
orig_gpt2_generated = orig_gpt2_lm.generate(prompt_ids, **gen_kwargs)
gpt2_generated = gpt2_lm.generate(prompt_ids, **gen_kwargs)

# Keep only the first sequence of each batch, then decode to text.
orig_token_ids, token_ids = (
    np.array(result.sequences)[0] for result in (orig_gpt2_generated, gpt2_generated)
)
orig_caption, caption = map(tokenizer.decode, (orig_token_ids, token_ids))

print('=' * 60)
print(f'orig. GPT2 generated token ids: {orig_token_ids}')
print(f'GPT2 generated token ids: {token_ids}')

print('=' * 60)
print(f'orig. GPT2 generated caption: {orig_caption}')
print(f'GPT2 generated caption: {caption}')

assert orig_token_ids.tolist() == token_ids.tolist()
|
|
|
|
|
|
|
|
|
# Teacher-forced forward pass of the combined ViT -> GPT-2 model. The encoder
# attention mask and decoder position ids are passed as None (presumably so the
# model fills in its own defaults — confirm in modeling_flax_vit_gpt2_lm).
model_inputs = dict(
    pixel_values=pixel_values,
    attention_mask=None,
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=decoder_attention_mask,
    decoder_position_ids=None,
)

model_outputs = model(**model_inputs)

# Most likely token at every decoder position.
logits = model_outputs[0]
preds = np.argmax(logits, axis=-1)

print('=' * 60)
print('Flax: Vit-GPT2-LM', 'predicted token ids:', preds, sep='\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Caption generation conditioned only on the image pixels, single beam,
# up to `max_length` tokens.
num_beams = 1
gen_kwargs = dict(max_length=max_length, num_beams=num_beams)

batch = dict(pixel_values=pixel_values)

generated = model.generate(batch['pixel_values'], **gen_kwargs)
# First (and only) sequence of the batch.
token_ids = np.array(generated.sequences)[0]

print('=' * 60)
print(f'generated token ids: {token_ids}')

caption = tokenizer.decode(token_ids)

print('=' * 60)
print(f'generated caption: {caption}')
|
|
|
|
|
|
|
|
|
|
|
# Round-trip check: save the combined model, reload it from disk, and confirm
# that it generates the same tokens as the in-memory model did.
os.makedirs('./model/', exist_ok=True)
model.save_pretrained(save_directory='./model/')

_model = FlaxViTGPT2LMForConditionalGeneration.from_pretrained('./model/')

_generated = _model.generate(batch['pixel_values'], **gen_kwargs)
_token_ids = np.array(_generated.sequences)[0]

print('=' * 60)
print(f'new generated token ids: {_token_ids}')
# `==` on NumPy arrays compares element-wise and misbehaves on a length
# mismatch; np.array_equal yields one True/False for the whole sequence.
print(f'token_ids == new_token_ids: {np.array_equal(token_ids, _token_ids)}')
|
|
|
|
|
|
|
|
|
import torch |
|
from transformers import ViTModel, GPT2Config, GPT2LMHeadModel |
|
|
|
# PyTorch reference pipeline: ViT encoder features fed into a GPT-2 LM via
# cross-attention, run on the same image and decoder inputs as the Flax model.
vision_model_pt = ViTModel.from_pretrained(vision_model_name)

# GPT-2 config with cross-attention switched on, so the decoder accepts
# `encoder_hidden_states`.
config = GPT2Config.from_pretrained(text_model_name)
config.add_cross_attention = True
# NOTE(review): the cross-attention weights are presumably absent from the
# text-only checkpoint and therefore freshly initialized. Confirm; if so, these
# predictions are not expected to match the Flax model's.
text_model_pt = GPT2LMHeadModel.from_pretrained(text_model_name, config=config)

encoder_pt_inputs = feature_extractor(images=image, return_tensors="pt")

# Inference only: disable autograd so no computation graph is recorded.
with torch.no_grad():
    encoder_pt_outputs = vision_model_pt(**encoder_pt_inputs)
    encoder_hidden_states = encoder_pt_outputs.last_hidden_state

    text_model_pt_inputs = {
        'input_ids': torch.tensor(decoder_input_ids, dtype=torch.int32),
        'attention_mask': torch.tensor(decoder_attention_mask, dtype=torch.int32),
        'position_ids': None,
        'encoder_hidden_states': encoder_hidden_states,
    }
    text_model_pt_outputs = text_model_pt(**text_model_pt_inputs)

# Most likely token at every decoder position.
logits = text_model_pt_outputs[0]
preds = np.argmax(logits.detach().numpy(), axis=-1)

print('=' * 60)
print('PyTorch: Vit --> GPT2-LM')
print('predicted token ids:')
print(preds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|