"""Flax implementation of the AIMv2 vision transformer image encoder."""
from typing import Any, Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core import frozen_dict
from transformers import FlaxPreTrainedModel
from transformers.modeling_flax_outputs import FlaxBaseModelOutput

from .configuration_aimv2 import AIMv2Config

__all__ = ["FlaxAIMv2Model"]


def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: jax.Array) -> jax.Array:
    # 1D sine-cosine positional embedding; returns shape (pos.size, embed_dim).
    omega = jnp.arange(embed_dim // 2, dtype=pos.dtype)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega
    pos = pos.reshape(-1)
    out = pos[:, None] * omega[None, :]
    emb_sin, emb_cos = jnp.sin(out), jnp.cos(out)
    emb = jnp.concatenate([emb_sin, emb_cos], axis=1)
    return emb


def get_sincos_pos_embed(
    h: int,
    w: int,
    embed_dim: int,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    # 2D sine-cosine positional embedding for an (h, w) grid of patches;
    # returns shape (h * w, embed_dim).
    assert embed_dim % 2 == 0, embed_dim
    grid_h = jnp.arange(h, dtype=dtype)
    grid_w = jnp.arange(w, dtype=dtype)
    grid = jnp.meshgrid(grid_w, grid_h, indexing="xy")
    grid = jnp.stack(grid, axis=0)
    grid = grid.reshape([2, 1, h, w])
    emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    pos_embed = jnp.concatenate([emb_h, emb_w], axis=1)
    return pos_embed


class FlaxRMSNorm(nn.Module):
    eps: float = 1e-6

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        dim = x.shape[-1]
        scale = self.param("scale", nn.initializers.ones_init(), (dim,))
        # Normalize in float32 for numerical stability, then cast back to the input dtype.
        output = self._norm(x.astype(jnp.float32)).astype(x.dtype)
        output = output * scale.astype(x.dtype)
        return output

    def _norm(self, x: jax.Array) -> jax.Array:
        return x * jax.lax.rsqrt(jnp.power(x, 2).mean(-1, keepdims=True) + self.eps)


class FlaxAIMv2SwiGLUFFN(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        hidden_features = self.config.intermediate_size
        in_features = self.config.hidden_size
        bias = self.config.use_bias

        # SwiGLU: SiLU-gated branch (fc1) times linear branch (fc3), projected back by fc2.
        x1 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc1")(x)
        x2 = nn.Dense(hidden_features, use_bias=bias, dtype=self.dtype, name="fc3")(x)
        x = nn.silu(x1) * x2
        x = nn.Dense(in_features, use_bias=bias, dtype=self.dtype, name="fc2")(x)
        return x


class FlaxAIMv2PatchEmbed(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        patch_size = (self.config.patch_size, self.config.patch_size)
        x = x.transpose(0, 2, 3, 1)  # NCHW -> NHWC, as expected by flax Conv
        x = nn.Conv(
            self.config.hidden_size,
            kernel_size=patch_size,
            strides=patch_size,
            padding=(0, 0),
            dtype=self.dtype,
            name="proj",
        )(x)
        x = jax.lax.collapse(x, 1, 3)  # (B, H', W', C) -> (B, H' * W', C)
        x = FlaxRMSNorm(self.config.rms_norm_eps, name="norm")(x)
        return x


class FlaxAIMv2ViTPreprocessor(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        _, _, H, W = x.shape
        patch_h = self.config.patch_size
        patch_w = self.config.patch_size

        tokens = FlaxAIMv2PatchEmbed(self.config, dtype=self.dtype, name="patchifier")(
            x
        )
        # Fixed (non-learned) sine-cosine positional embedding over the patch grid.
        pos_embed = get_sincos_pos_embed(
            H // patch_h,
            W // patch_w,
            embed_dim=self.config.hidden_size,
            dtype=self.dtype,
        )
        tokens = tokens + pos_embed
        return tokens


class FlaxAIMv2Attention(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ) -> Tuple[jax.Array, Optional[jax.Array]]:
        B, N, C = x.shape
        dim, num_heads = self.config.hidden_size, self.config.num_attention_heads

        # Fused QKV projection: (B, N, 3 * C) -> (3, B, num_heads, N, head_dim).
        qkv = nn.Dense(
            dim * 3, use_bias=self.config.qkv_bias, dtype=self.dtype, name="qkv"
        )(x)
        qkv = qkv.reshape(B, N, 3, num_heads, C // num_heads).transpose(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # dot_product_attention_weights expects (..., seq_len, num_heads, head_dim).
        attn_weights = nn.dot_product_attention_weights(
            q.swapaxes(-3, -2),
            k.swapaxes(-3, -2),
            mask=mask,
            deterministic=deterministic,
            dtype=self.dtype,
        )
        attn_weights = nn.Dropout(
            self.config.attention_dropout, deterministic=deterministic, name="attn_drop"
        )(attn_weights)

        x = (attn_weights @ v).swapaxes(1, 2).reshape(B, N, C)
        x = nn.Dense(dim, use_bias=self.config.use_bias, dtype=self.dtype, name="proj")(
            x
        )
        x = nn.Dropout(
            self.config.projection_dropout,
            deterministic=deterministic,
            name="proj_drop",
        )(x)
        return (x, attn_weights) if output_attentions else (x, None)


class FlaxAIMv2Block(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attn = FlaxAIMv2Attention(self.config, dtype=self.dtype, name="attn")
        self.norm_1 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_1")
        self.mlp = FlaxAIMv2SwiGLUFFN(self.config, dtype=self.dtype, name="mlp")
        self.norm_2 = FlaxRMSNorm(self.config.rms_norm_eps, name="norm_2")

    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ) -> Tuple[jax.Array, Optional[jax.Array]]:
        # Pre-norm residual block: attention, then SwiGLU MLP.
        features, attention = self.attn(
            self.norm_1(x),
            mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        x = x + features
        x = x + self.mlp(self.norm_2(x))
        return x, attention


class FlaxAIMv2Transformer(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        tokens: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[
        jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
    ]:
        hidden_states = () if output_hidden_states else None
        attentions = () if output_attentions else None
        for blk_id in range(self.config.num_hidden_layers):
            tokens, attention = FlaxAIMv2Block(
                self.config, dtype=self.dtype, name=f"layers_{blk_id}"
            )(
                tokens,
                mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            if output_hidden_states:
                hidden_states += (tokens,)
            if output_attentions:
                attentions += (attention,)
        tokens = FlaxRMSNorm(self.config.rms_norm_eps, name="post_trunk_norm")(tokens)
        return tokens, hidden_states, attentions


class FlaxAIMv2Module(nn.Module):
    config: AIMv2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        x: jax.Array,
        mask: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ) -> Tuple[
        jax.Array, Optional[Tuple[jax.Array, ...]], Optional[Tuple[jax.Array, ...]]
    ]:
        x = FlaxAIMv2ViTPreprocessor(
            self.config, dtype=self.dtype, name="preprocessor"
        )(x)
        x, hidden_states, attentions = FlaxAIMv2Transformer(
            self.config, dtype=self.dtype, name="trunk"
        )(
            x,
            mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        return x, hidden_states, attentions


class FlaxAIMv2PretrainedModel(FlaxPreTrainedModel):
    config_class = AIMv2Config
    base_model_prefix = "aimv2"
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: AIMv2Config,
        input_shape: Optional[Tuple[int, int, int, int]] = None,
        dtype: jnp.dtype = jnp.float32,
        **kwargs: Any,
    ):
        if input_shape is None:
            input_shape = (1, 3, 224, 224)
        super().__init__(
            config,
            module=FlaxAIMv2Module(config, dtype=dtype),
            input_shape=input_shape,
            dtype=dtype,
            **kwargs,
        )

    def init_weights(
        self,
        rng: jax.Array,
        input_shape: Tuple[int, ...],
        params: Optional[frozen_dict.FrozenDict] = None,
    ) -> frozen_dict.FrozenDict:
        del params
        input_pixels = jnp.empty(input_shape)
        params = self.module.init(rng, input_pixels, deterministic=True)
        return params["params"]


class FlaxAIMv2Model(FlaxAIMv2PretrainedModel):
    def __call__(
        self,
        pixel_values: jax.Array,
        params: Optional[frozen_dict.FrozenDict] = None,
        mask: Optional[jax.Array] = None,
        dropout_rng: Optional[jax.Array] = None,
        deterministic: bool = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[
        Tuple[jax.Array],
        Tuple[jax.Array, Tuple[jax.Array, ...]],
        Tuple[jax.Array, Tuple[jax.Array, ...], Tuple[jax.Array, ...]],
        FlaxBaseModelOutput,
    ]:
        if params is None:
            params = self.params
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        rngs = None if deterministic else {"dropout": dropout_rng}

        x, hidden_states, attentions = self.module.apply(
            {"params": params},
            pixel_values,
            mask,
            rngs=rngs,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            res = (x,)
            res += (hidden_states,) if output_hidden_states else ()
            res += (attentions,) if output_attentions else ()
            return res

        return FlaxBaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states,
            attentions=attentions,
        )
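

# Minimal usage sketch (not part of the reference implementation above): builds a
# model from a default config and runs one forward pass on dummy pixels. It assumes
# `AIMv2Config()` can be constructed with its default values; in practice you may
# instead load weights with `FlaxAIMv2Model.from_pretrained(...)`.
if __name__ == "__main__":
    config = AIMv2Config()
    model = FlaxAIMv2Model(config)  # random params are created via init_weights

    # NCHW pixel input, matching the default input_shape of (1, 3, 224, 224).
    pixels = jnp.zeros((1, 3, 224, 224), dtype=jnp.float32)
    outputs = model(pixels, return_dict=True)  # FlaxBaseModelOutput
    print(outputs.last_hidden_state.shape)  # (1, num_patches, hidden_size)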