|
from typing import Callable, Optional, Tuple |
|
|
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
from flax.core.frozen_dict import FrozenDict, unfreeze |
|
from jax import lax |
|
from jax.random import PRNGKey |
|
from transformers import GPT2Config, FlaxViTModel, ViTConfig |
|
from transformers.modeling_flax_outputs import ( |
|
FlaxCausalLMOutputWithCrossAttentions, |
|
FlaxSeq2SeqLMOutput, |
|
FlaxSeq2SeqModelOutput, |
|
) |
|
|
from .modeling_flax_gpt2 import ( |
|
FlaxGPT2Module, |
|
FlaxGPT2Model, |
|
FlaxGPT2LMHeadModule, |
|
FlaxGPT2LMHeadModel, |
|
FlaxPreTrainedModel |
|
) |
|
from transformers.models.vit.modeling_flax_vit import FlaxViTModule |
|
|
|
from .configuration_vit_gpt2 import ViTGPT2Config |
|
|
|
|
|
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray: |
|
""" |
|
Shift input ids one token to the right. |
|
""" |
|
shifted_input_ids = jnp.roll(input_ids, 1, axis=-1) |
|
shifted_input_ids = jax.ops.index_update(shifted_input_ids, (..., 0), decoder_start_token_id) |
|
|
|
shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids) |
|
|
|
return shifted_input_ids |
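
# Worked example (illustrative token ids only):
#   shift_tokens_right(jnp.array([[5, 6, 7]]), pad_token_id=0, decoder_start_token_id=50256)
#   -> [[50256, 5, 6]]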
|
|
|
class FlaxViTGPT2LMModule(nn.Module): |
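    """
    ViT encoder wired to a GPT-2 LM-head decoder: the encoder's last hidden state is handed to the
    decoder as `encoder_hidden_states`, so the GPT-2 blocks can cross-attend over the image patches.
    """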
|
config: ViTGPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def setup(self): |
|
|
|
self.encoder = FlaxViTModule(self.config.vit_config, dtype=self.dtype) |
|
self.decoder = FlaxGPT2LMHeadModule(self.config.gpt2_config, dtype=self.dtype) |
|
|
|
def _get_encoder_module(self): |
|
return self.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.decoder |
|
|
|
def __call__( |
|
self, |
|
pixel_values, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
encoder_outputs = self.encoder( |
|
pixel_values=pixel_values, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
encoder_attention_mask=encoder_attention_mask, |
|
deterministic=deterministic, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict |
|
) |
|
|
|
if not return_dict: |
|
return decoder_outputs + encoder_outputs |
|
|
|
return FlaxSeq2SeqLMOutput( |
|
logits=decoder_outputs.logits, |
|
decoder_hidden_states=decoder_outputs.decoder_hidden_states, |
|
decoder_attentions=decoder_outputs.decoder_attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
class FlaxViTGPT2LMForConditionalGenerationModule(nn.Module): |
|
config: ViTGPT2Config |
|
dtype: jnp.dtype = jnp.float32 |
|
bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros |
|
|
|
def setup(self): |
|
self.model = FlaxViTGPT2LMModule(config=self.config, dtype=self.dtype) |
|
|
|
def _get_encoder_module(self): |
|
return self.model.encoder |
|
|
|
def _get_decoder_module(self): |
|
return self.model.decoder |
|
|
|
def __call__( |
|
self, |
|
pixel_values, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
output_attentions: bool = False, |
|
output_hidden_states: bool = False, |
|
return_dict: bool = True, |
|
deterministic: bool = True, |
|
): |
|
outputs = self.model( |
|
pixel_values=pixel_values, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_attention_mask=encoder_attention_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
) |
|
|
|
return outputs |
|
|
|
|
|
class FlaxViTGPT2LMPreTrainedModel(FlaxPreTrainedModel): |
|
config_class = ViTGPT2Config |
|
base_model_prefix: str = "model" |
|
module_class: nn.Module = None |
|
|
|
def __init__( |
|
self, |
|
config: ViTGPT2Config, |
|
input_shape: Tuple = None, |
|
seed: int = 0, |
|
dtype: jnp.dtype = jnp.float32, |
|
**kwargs, |
|
): |
|
if input_shape is None: |
|
input_shape = ( |
|
(1, config.vit_config.image_size, config.vit_config.image_size, 3), |
|
(1, 1), |
|
) |
|
|
|
module = self.module_class(config=config, dtype=dtype, **kwargs) |
|
super().__init__( |
|
config, module, input_shape=input_shape, seed=seed, dtype=dtype |
|
) |
|
|
|
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: |
|
|
|
        pixel_values = jax.random.normal(rng, input_shape[0])

        input_ids = jnp.zeros(input_shape[1], dtype="i4")
|
attention_mask = jnp.ones_like(input_ids) |
|
|
|
batch_size, sequence_length = input_ids.shape |
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
params_rng, dropout_rng = jax.random.split(rng) |
|
rngs = {"params": params_rng, "dropout": dropout_rng} |
|
|
|
return self.module.init( |
|
rngs, |
|
pixel_values, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
)["params"] |
|
|
|
def init_cache(self, batch_size, max_length, encoder_outputs): |
|
|
|
input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
|
attention_mask = jnp.ones_like(input_ids) |
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), |
|
input_ids.shape, |
|
) |
|
|
|
def _decoder_forward( |
|
module, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
) |
|
|
|
init_variables = self.module.init( |
|
jax.random.PRNGKey(0), |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
encoder_hidden_states=encoder_outputs[0], |
|
init_cache=True, |
|
method=_decoder_forward, |
|
) |
|
return unfreeze(init_variables["cache"]) |
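
    # Usage sketch (hedged): the returned cache is what `decode` expects as `past_key_values`, e.g.
    #   past = model.init_cache(batch_size, max_length, model.encode(pixel_values))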
|
|
|
def encode( |
|
self, |
|
pixel_values: jnp.ndarray, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.return_dict |
|
) |
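
        # Feature extractors typically hand over NCHW (channels-first) arrays, while FlaxViTModule
        # expects NHWC (channels-last), hence the transpose below.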
|
|
|
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
def _encoder_forward(module, pixel_values, **kwargs): |
|
encode_module = module._get_encoder_module() |
|
return encode_module(pixel_values, **kwargs) |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
            # Pixel values are continuous inputs; cast to float32 rather than int32.
            pixel_values=jnp.array(pixel_values, dtype=jnp.float32),
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
method=_encoder_forward, |
|
) |
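
    # Generation-time flow (hedged sketch; `decoder_input_ids` is a placeholder name):
    #   encoder_outputs = model.encode(pixel_values)
    #   past = model.init_cache(batch_size, max_length, encoder_outputs)
    #   step = model.decode(decoder_input_ids, encoder_outputs, past_key_values=past)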
|
|
|
def decode( |
|
self, |
|
input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
|
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.return_dict |
|
) |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = input_ids.shape |
|
if attention_mask is None: |
|
attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError( |
|
"Make sure to provide `position_ids` when passing `past_key_values`." |
|
) |
|
|
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward( |
|
module, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
return decoder_module( |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
) |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
input_ids=jnp.array(input_ids, dtype="i4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
position_ids=jnp.array(position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs, past = outputs |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs, past = outputs |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def __call__( |
|
self, |
|
pixel_values: jnp.ndarray, |
|
input_ids: Optional[jnp.ndarray] = None, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
train: bool = False, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.return_dict |
|
) |
|
|
|
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        if attention_mask is None:
|
attention_mask = jnp.ones_like(input_ids) |
|
if position_ids is None: |
|
batch_size, sequence_length = input_ids.shape |
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
|
|
|
return self.module.apply( |
|
{"params": params or self.params}, |
|
pixel_values=jnp.array(pixel_values, dtype=jnp.float32), |
|
input_ids=jnp.array(input_ids, dtype="i4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
position_ids=jnp.array(position_ids, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=not train, |
|
rngs=rngs, |
|
) |
|
|
|
|
|
class FlaxViTGPT2LMForConditionalGeneration(FlaxViTGPT2LMPreTrainedModel): |
|
module_class = FlaxViTGPT2LMForConditionalGenerationModule |
|
dtype: jnp.dtype = jnp.float32 |
|
|
|
def decode( |
|
self, |
|
input_ids, |
|
encoder_outputs, |
|
encoder_attention_mask: Optional[jnp.ndarray] = None, |
|
attention_mask: Optional[jnp.ndarray] = None, |
|
position_ids: Optional[jnp.ndarray] = None, |
|
past_key_values: dict = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
deterministic: bool = True, |
|
params: dict = None, |
|
dropout_rng: PRNGKey = None, |
|
): |
|
output_attentions = ( |
|
output_attentions |
|
if output_attentions is not None |
|
else self.config.output_attentions |
|
) |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.return_dict |
|
) |
|
|
|
encoder_hidden_states = encoder_outputs[0] |
|
if encoder_attention_mask is None: |
|
batch_size, sequence_length = encoder_hidden_states.shape[:2] |
|
encoder_attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
batch_size, sequence_length = input_ids.shape |
|
if attention_mask is None: |
|
attention_mask = jnp.ones((batch_size, sequence_length)) |
|
|
|
if position_ids is None: |
|
if past_key_values is not None: |
|
raise ValueError( |
|
"Make sure to provide `position_ids` when passing `past_key_values`." |
|
) |
|
|
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
|
) |
|
|
|
|
|
rngs = {} |
|
if dropout_rng is not None: |
|
rngs["dropout"] = dropout_rng |
|
|
|
inputs = {"params": params or self.params} |
|
|
|
|
|
|
|
|
|
if past_key_values: |
|
inputs["cache"] = past_key_values |
|
mutable = ["cache"] |
|
else: |
|
mutable = False |
|
|
|
def _decoder_forward( |
|
module, |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
): |
|
decoder_module = module._get_decoder_module() |
|
outputs = decoder_module( |
|
input_ids, |
|
attention_mask, |
|
position_ids, |
|
**kwargs, |
|
) |
|
lm_logits = outputs[0] |
|
|
|
return lm_logits, outputs |
|
|
|
outputs = self.module.apply( |
|
inputs, |
|
input_ids=jnp.array(input_ids, dtype="i4"), |
|
attention_mask=jnp.array(attention_mask, dtype="i4"), |
|
position_ids=jnp.array(position_ids, dtype="i4"), |
|
encoder_hidden_states=encoder_hidden_states, |
|
encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"), |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
deterministic=deterministic, |
|
rngs=rngs, |
|
mutable=mutable, |
|
method=_decoder_forward, |
|
) |
|
|
|
if past_key_values is None: |
|
lm_logits, outputs = outputs |
|
else: |
|
(lm_logits, outputs), past = outputs |
|
|
|
if return_dict: |
|
outputs = FlaxCausalLMOutputWithCrossAttentions( |
|
logits=lm_logits, |
|
hidden_states=outputs.decoder_hidden_states, |
|
attentions=outputs.decoder_attentions, |
|
cross_attentions=outputs.cross_attentions, |
|
) |
|
else: |
|
outputs = (lm_logits,) + outputs[1:] |
|
|
|
|
|
if past_key_values is not None and return_dict: |
|
outputs["past_key_values"] = unfreeze(past["cache"]) |
|
return outputs |
|
elif past_key_values is not None and not return_dict: |
|
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] |
|
|
|
return outputs |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids, |
|
max_length, |
|
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
|
encoder_outputs=None, |
|
**kwargs, |
|
): |
|
|
|
batch_size, seq_length = input_ids.shape |
|
|
|
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
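
        # `init_cache` allocates `max_length` cache slots, so the attention mask is extended to that
        # length; position ids are derived from the cumulative sum of the caller-provided mask.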
|
|
|
|
|
|
|
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
|
if attention_mask is not None: |
|
position_ids = attention_mask.cumsum(axis=-1) - 1 |
|
extended_attention_mask = lax.dynamic_update_slice( |
|
extended_attention_mask, attention_mask, (0, 0) |
|
) |
|
else: |
|
position_ids = jnp.broadcast_to( |
|
jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) |
|
) |
|
|
|
return { |
|
"past_key_values": past_key_values, |
|
"encoder_outputs": encoder_outputs, |
|
"encoder_attention_mask": encoder_attention_mask, |
|
"attention_mask": extended_attention_mask, |
|
"position_ids": position_ids, |
|
} |
|
|
|
def update_inputs_for_generation(self, model_outputs, model_kwargs): |
|
model_kwargs["past_key_values"] = model_outputs.past_key_values |
|
model_kwargs["position_ids"] = ( |
|
model_kwargs["position_ids"][:, -1:] + 1 |
|
) |
|
return model_kwargs |
|
|
|
@classmethod |
|
def from_vit_gpt2_pretrained( |
|
cls, |
|
vit_model_name_or_path: str = None, |
|
gpt2_model_name_or_path: str = None, |
|
*model_args, |
|
**kwargs, |
|
) -> FlaxViTGPT2LMPreTrainedModel: |
|
|
|
kwargs_gpt2 = { |
|
argument[len("gpt2_") :]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("gpt2_") |
|
} |
|
|
|
kwargs_vit = { |
|
argument[len("vit_") :]: value |
|
for argument, value in kwargs.items() |
|
if argument.startswith("vit_") |
|
} |
|
|
|
|
|
for key in kwargs_gpt2.keys(): |
|
del kwargs["gpt2_" + key] |
|
for key in kwargs_vit.keys(): |
|
del kwargs["vit_" + key] |
|
|
|
|
|
gpt2_model = kwargs_gpt2.pop("model", None) |
|
if gpt2_model is None: |
|
assert ( |
|
gpt2_model_name_or_path is not None |
|
), "If `model` is not defined as an argument, a `gpt2_model_name_or_path` has to be defined" |
|
|
|
if "config" not in kwargs_gpt2: |
|
gpt2_config = GPT2Config.from_pretrained(gpt2_model_name_or_path) |
|
kwargs_gpt2["config"] = gpt2_config |
|
|
|
kwargs_gpt2["config"].add_cross_attention = True |
|
gpt2_model = FlaxGPT2LMHeadModel.from_pretrained( |
|
gpt2_model_name_or_path, *model_args, **kwargs_gpt2 |
|
) |
|
|
|
vit_model = kwargs_vit.pop("model", None) |
|
if vit_model is None: |
|
assert ( |
|
vit_model_name_or_path is not None |
|
), "If `model` is not defined as an argument, a `vit_model_name_or_path` has to be defined" |
|
|
|
if "config" not in kwargs_vit: |
|
vit_config = ViTConfig.from_pretrained(vit_model_name_or_path) |
|
kwargs_vit["config"] = vit_config |
|
|
|
vit_model = FlaxViTModel.from_pretrained( |
|
vit_model_name_or_path, *model_args, **kwargs_vit |
|
) |
|
|
|
|
|
dtype = kwargs.pop("dtype", jnp.float32) |
|
config = ViTGPT2Config.from_vit_gpt2_configs( |
|
vit_model.config, gpt2_model.config, **kwargs |
|
) |
|
|
|
|
|
model = cls(config, *model_args, dtype=dtype, **kwargs) |
|
model.params["model"]["encoder"] = vit_model.params |
|
model.params["model"]["decoder"] = gpt2_model.params |
|
|
|
return model |
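

# Construction sketch (hedged; checkpoint names are illustrative, not pinned by this module):
#   model = FlaxViTGPT2LMForConditionalGeneration.from_vit_gpt2_pretrained(
#       "google/vit-base-patch16-224-in21k", "gpt2"
#   )
#   pixel_values = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)  # NCHW; transposed to NHWC internally
#   outputs = model(pixel_values, input_ids=jnp.array([[0]], dtype="i4"))
#   logits = outputs.logits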
|
|