# coding=utf-8
# Copyright 2023 The EleutherAI and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax GPT NeoX model."""

from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from transformers.models.gpt_neox.configuration_gpt_neox import GPTNeoXConfig
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "EleutherAI/gpt-neox-20b"
_CONFIG_FOR_DOC = "GPTNeoXConfig"


GPT_NEOX_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax nn
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`GPTNeoXConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

GPT_NEOX_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


def rotate_half(hidden_states):
    """Rotate the last dimension by half: `[x1, x2] -> [-x2, x1]`.

    Helper for rotary position embeddings; the last dimension is expected to be even.
    """
    first_half = hidden_states[..., : hidden_states.shape[-1] // 2]
    second_half = hidden_states[..., hidden_states.shape[-1] // 2 :]
    return jnp.concatenate((-second_half, first_half), axis=-1)


class FlaxGPTNeoXRotaryEmbedding(nn.Module):
    """Precomputes rotary position embedding (RoPE) cos/sin tables.

    Attributes:
        dim: number of rotary channels per attention head.
        max_position_embeddings: number of positions covered by the precomputed tables.
        base: RoPE frequency base.
        dtype: dtype used to compute the inverse frequencies.
    """

    dim: int
    max_position_embeddings: int
    base: int = 10000
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Inverse frequencies 1 / base^(2i / dim) for i in [0, dim / 2).
        self.inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2).astype(self.dtype) / self.dim))
        self.cos_cached, self.sin_cached = self._compute_cos_sin(self.max_position_embeddings)

    def _get_cos_sin_cache(self, seq_len):
        """Return the cached tables, recomputing longer ones if `seq_len` exceeds the cached length."""
        if seq_len > self.max_position_embeddings:
            return self._compute_cos_sin(seq_len)
        else:
            return self.cos_cached, self.sin_cached

    def _compute_cos_sin(self, seq_len):
        """Compute cos/sin tables of shape `(1, 1, seq_len, dim)` (for even `dim`) for positions `0..seq_len-1`."""
        t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
        freqs = jnp.outer(t, self.inv_freq)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
        sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
        return cos, sin

    def __call__(self, seq_len=None):
        """Return the full `(cos, sin)` tables, each of shape `(1, 1, num_positions, dim)`.

        The tables are not trimmed to `seq_len`: `apply_rotary_pos_emb` gathers rows by absolute
        `position_ids`, which during cached decoding are larger than the length of the current input.
        `seq_len` only decides whether longer tables must be recomputed.
        """
        cos_cached, sin_cached = self._get_cos_sin_cache(seq_len)
        return cos_cached, sin_cached


class FlaxGPTNeoXLinearScalingRotaryEmbedding(FlaxGPTNeoXRotaryEmbedding):
    """FlaxGPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    scaling_factor: float = 1.0

    def _compute_cos_sin(self, seq_len):
        """Like the base implementation, but positions are divided by `scaling_factor`."""
        t = jnp.arange(seq_len, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor
        freqs = jnp.outer(t, self.inv_freq)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
        sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
        return cos, sin


class FlaxGPTNeoXDynamicNTKScalingRotaryEmbedding(FlaxGPTNeoXRotaryEmbedding):
    """FlaxGPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    scaling_factor: float = 1.0

    def _compute_cos_sin(self, seq_len):
        """Compute cos/sin tables, enlarging the frequency base when `seq_len` exceeds `max_position_embeddings`."""
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (jnp.arange(0, self.dim, 2, dtype=self.dtype) / self.dim))
        else:
            inv_freq = self.inv_freq
        t = jnp.arange(seq_len, dtype=self.dtype)
        freqs = jnp.outer(t, inv_freq)
        emb = jnp.concatenate((freqs, freqs), axis=-1)
        cos = jnp.expand_dims(jnp.expand_dims(jnp.cos(emb), 0), 0)
        sin = jnp.expand_dims(jnp.expand_dims(jnp.sin(emb), 0), 0)
        return cos, sin


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to queries and keys.

    Args:
        q, k: arrays of shape `(batch, seq_len, num_heads, rot_dim)`.
        cos, sin: tables of shape `(1, 1, num_positions, rot_dim)` from a rotary embedding module.
        position_ids: integer array of shape `(batch, seq_len)` holding each token's absolute position.

    Returns:
        `(q_embed, k_embed)` with the same shapes as `q` and `k`.
    """
    gather_indices = position_ids[:, :, None, None]  # [bs, seq_len, 1, 1]
    gather_indices = jnp.repeat(gather_indices, cos.shape[1], axis=1)
    gather_indices = jnp.repeat(gather_indices, cos.shape[3], axis=3)
    # Pick the table row for every token position -> [bs, seq_len, 1, rot_dim], broadcast over heads.
    cos = jnp.take_along_axis(cos.repeat(gather_indices.shape[0], axis=0), gather_indices, axis=2)
    sin = jnp.take_along_axis(sin.repeat(gather_indices.shape[0], axis=0), gather_indices, axis=2)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class FlaxGPTNeoXAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection and partial rotary embeddings."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        config = self.config
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.rotary_ndims = int(self.head_size * config.rotary_pct)
        # Not used in `__call__`: `dot_product_attention_weights` already scales queries by 1/sqrt(head_size).
        self.norm_factor = jnp.sqrt(self.head_size)
        self.query_key_value = nn.Dense(
            3 * config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.dense = nn.Dense(
            config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        if config.rope_scaling is None:
            max_seq_length = config.max_position_embeddings
        else:
            max_seq_length = int(config.max_position_embeddings * config.rope_scaling["factor"])
        self.causal_mask = make_causal_mask(jnp.ones((1, max_seq_length), dtype="bool"), dtype="bool")
        self._init_rope()

    def _init_rope(self):
        """Create `self.rotary_emb` according to `config.rope_scaling`.

        Raises:
            ValueError: if `config.rope_scaling["type"]` is neither `"linear"` nor `"dynamic"`.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = FlaxGPTNeoXRotaryEmbedding(
                self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                # Linear scaling targets positions beyond `max_position_embeddings` (the causal mask in
                # `setup` is sized `max_position_embeddings * factor`). The precomputed table must cover that
                # whole range, otherwise `apply_rotary_pos_emb` gathers out-of-bounds rows for those positions.
                # Rows below `max_position_embeddings` are identical to a shorter table.
                num_positions = max(
                    self.config.max_position_embeddings,
                    int(self.config.max_position_embeddings * scaling_factor),
                )
                self.rotary_emb = FlaxGPTNeoXLinearScalingRotaryEmbedding(
                    self.rotary_ndims,
                    num_positions,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "dynamic":
                # NOTE(review): this table covers `max_position_embeddings` positions and is only enlarged when
                # the *current input* is longer. During cached decoding (one token per step), positions past
                # that limit fall outside the table. TODO confirm the intended behaviour for long generation.
                self.rotary_emb = FlaxGPTNeoXDynamicNTKScalingRotaryEmbedding(
                    self.rotary_ndims,
                    self.config.max_position_embeddings,
                    base=self.config.rotary_emb_base,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    @nn.compact
    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoSelfAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slighly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def _split_heads(self, hidden_states):
        """Reshape `[..., 3 * hidden_size]` to `[..., num_heads, 3 * head_size]` (q, k, v interleaved per head)."""
        return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_attention_heads, self.head_size * 3))

    def _merge_heads(self, hidden_states):
        """Reshape `[batch, seq, num_heads, head_size]` back to `[batch, seq, hidden_size]`."""
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    def __call__(
        self,
        hidden_states,
        attention_mask,
        position_ids,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run self-attention.

        Args:
            hidden_states: input activations `[batch, seq_len, hidden_size]`.
            attention_mask: padding mask `[batch, key_len]` (1 = attend, 0 = masked).
            position_ids: absolute positions `[batch, seq_len]` used for the rotary embeddings.
            deterministic: disables attention dropout when True.
            init_cache: create the "cache" variables for autoregressive decoding.
            output_attentions: also return the attention weights.

        Returns:
            `(attn_output,)` or `(attn_output, attn_weights)` when `output_attentions` is True.
        """
        # Fused q, k, v projection (computed once).
        fused_qkv = self.query_key_value(hidden_states)
        batch, seq_len, _ = fused_qkv.shape
        fused_qkv = self._split_heads(fused_qkv)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

        cos, sin = self.rotary_emb(seq_len)
        # Only the first `rotary_ndims` channels of each head are rotated (`rotary_pct` of the head);
        # the remaining channels pass through unchanged.
        q_rot = query[..., : self.rotary_ndims]
        q_pass = query[..., self.rotary_ndims :]
        k_rot = key[..., : self.rotary_ndims]
        k_pass = key[..., self.rotary_ndims :]
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, position_ids)
        query = jnp.concatenate([q_rot, q_pass], axis=-1)
        key = jnp.concatenate([k_rot, k_pass], axis=-1)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            # Decoding with a cache: take the causal-mask rows for the current positions over the full cache.
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
            )
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        causal_mask = jnp.broadcast_to(causal_mask, (batch,) + causal_mask.shape[1:])
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
        )

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=jnp.promote_types(self.dtype, jnp.float32),
            precision=None,
        )
        attn_output = jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs


class FlaxGPTNeoXMLP(nn.Module):
    """Position-wise feed-forward network: hidden -> intermediate -> activation -> hidden."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.hidden_size
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense_h_to_4h = nn.Dense(self.config.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
        self.dense_4h_to_h = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
        self.act = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dense_4h_to_h(hidden_states)
        return hidden_states


class FlaxGPTNeoXBlock(nn.Module):
    """One transformer layer: attention + MLP with either parallel or sequential residual connections."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.use_parallel_residual = self.config.use_parallel_residual
        self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.attention = FlaxGPTNeoXAttention(self.config, dtype=self.dtype)
        self.post_attention_dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.mlp = FlaxGPTNeoXMLP(self.config, dtype=self.dtype)
        self.post_mlp_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Return `(hidden_states,)`, followed by the attention weights when `output_attentions` is True."""
        attn_outputs = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        attn_output = self.post_attention_dropout(attn_output, deterministic=deterministic)

        if self.use_parallel_residual:
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
            mlp_output = self.post_mlp_dropout(mlp_output, deterministic=deterministic)
            hidden_states = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
            mlp_output = self.post_mlp_dropout(mlp_output, deterministic=deterministic)
            hidden_states = mlp_output + attn_output

        return (hidden_states,) + attn_outputs[1:]


class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GPTNeoXConfig
    base_model_prefix = "gpt_neox"
    module_class: nn.Module = None

    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.__init__ with GPTNeo->GPTNeoX
    def __init__(
        self,
        config: GPTNeoXConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_weights with GPTNeo->GPTNeoX
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length))
        attention_mask = jnp.ones_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # When `past_key_values` are passed, the cache is already initialized: hand it to the module as the
        # "cache" collection and mark that collection mutable so FlaxGPTNeoXAttention can update it in place.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,  # deterministic
            False,  # init_cache
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs


class FlaxGPTNeoXBlockCollection(nn.Module):
    """Stack of `config.num_hidden_layers` transformer blocks."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.blocks = [
            FlaxGPTNeoXBlock(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Return `(hidden_states, all_hidden_states, all_attentions)`; the last two may be None."""
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = block(
                hidden_states,
                attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # this contains possible `None` values - `FlaxGPTNeoXModule` will filter them out
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs


class FlaxGPTNeoXModule(nn.Module):
    """Embeddings + transformer blocks + final layer norm."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size

        self.embed_in = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.emb_dropout = nn.Dropout(self.config.hidden_dropout)
        self.layers = FlaxGPTNeoXBlockCollection(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        input_embeds = self.embed_in(input_ids.astype("i4"))

        hidden_states = self.emb_dropout(input_embeds, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.final_layer_norm(hidden_states)

        if output_hidden_states:
            # The collection records inputs to each block; append the final (normalized) output.
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )


@add_start_docstrings(
    "The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.",
    GPT_NEOX_START_DOCSTRING,
)
class FlaxGPTNeoXModel(FlaxGPTNeoXPreTrainedModel):
    module_class = FlaxGPTNeoXModule


# The base module returns `FlaxBaseModelOutput` (not a causal-LM output), so document that type.
append_call_sample_docstring(
    FlaxGPTNeoXModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutput,
    _CONFIG_FOR_DOC,
)


class FlaxGPTNeoXForCausalLMModule(nn.Module):
    """GPT-NeoX backbone followed by an untied, bias-free vocabulary projection."""

    config: GPTNeoXConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.gpt_neox = FlaxGPTNeoXModule(self.config, dtype=self.dtype)
        self.embed_out = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.gpt_neox(
            input_ids,
            attention_mask,
            position_ids,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        lm_logits = self.embed_out(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)


@add_start_docstrings(
    """
    The GPTNeoX Model transformer with a language modeling head on top.
    """,
    GPT_NEOX_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoForCausalLM with GPTNeo->GPTNeoX
class FlaxGPTNeoXForCausalLM(FlaxGPTNeoXPreTrainedModel):
    module_class = FlaxGPTNeoXForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since GPTNeoX uses a causal mask, those positions are masked anyways.
        # Thus we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


append_call_sample_docstring(
    FlaxGPTNeoXForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutput,
    _CONFIG_FOR_DOC,
)