Can't generate decent text out of it
Am I missing something, or does this upload have some issues? It's not generating any text that makes sense.
You have to use the implementation from https://github.com/mustafaaljadery/gemma-2B-10M
Out of curiosity, have you been able to run the current repo implementation under ./src?
If so, did you modify it?
Currently, on consumer hardware, a few people (including me) are getting TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'.
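In case it helps while a proper fix lands: the error appears because newer transformers releases pass a cache_position keyword into GemmaModel.forward(), and the custom forward in this repo doesn't declare it. A minimal sketch of a workaround that just drops the keyword before it reaches the custom forward (the import path below is a placeholder, adjust it to wherever the repo's GemmaModel actually lives):

```python
# Minimal sketch: swallow the cache_position kwarg that newer transformers versions
# inject, so generation stops crashing. The import path is a placeholder.
from src.gemma import GemmaModel  # hypothetical path, adjust to your checkout

_orig_forward = GemmaModel.forward

def _forward_without_cache_position(self, *args, cache_position=None, **kwargs):
    # Drop cache_position, which the custom forward does not accept.
    return _orig_forward(self, *args, **kwargs)

GemmaModel.forward = _forward_without_cache_position
```

Alternatively, pinning transformers to the version the repo was developed against should avoid the mismatch.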
Same here, using the model from the repo gives the cache_position error.
Any chance of the cache_position error getting fixed?
TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'. I run the code and get the same error. Does anyone else have this problem?
I added some code with the help of Cursor. It runs now, but the performance is bad: it only generates meaningless text. I don't know whether that's caused by faulty code from the AI or by the infini-transformer code itself. I left # comments where I made changes.
```python
# Imports assumed by the snippets below; InfiniBaseModelOutputWithPast is defined in the repo's own modeling code.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GemmaConfig
from transformers.cache_utils import Cache, StaticCache
from transformers.models.gemma.modeling_gemma import (
    GemmaAttention, GemmaDecoderLayer, GemmaPreTrainedModel, GemmaRMSNorm, repeat_kv,
)


class GemmaInfiniAttention(GemmaAttention):
    def __init__(
        self,
        config: GemmaConfig,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(config, layer_idx)
        self.gate = nn.Parameter(torch.full((1, self.num_heads, 1, 1), -100.0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: bool = False,
        past_key_value: Optional[Cache] = None,  # Add this line
        output_attentions: Optional[bool] = False,  # Add this line
        use_cache: Optional[bool] = False,  # Add this line
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        ...  # full forward body shown further down
```
```python
class GemmaModel(GemmaPreTrainedModel):
    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Union[Tuple, InfiniBaseModelOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache and isinstance(past_key_values, StaticCache):
            past_seen_tokens = past_key_values.get_seq_length()

        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0) if position_ids is None else position_ids
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1])

        hidden_states = inputs_embeds * torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None  # Initialize next_decoder_cache
        # ...
```
```python
class GemmaInfiniAttention(GemmaAttention):
    def __init__(
        self,
        config: GemmaConfig,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(config, layer_idx)
        self.gate = nn.Parameter(torch.full((1, self.num_heads, 1, 1), -100.0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: bool = False,
        past_key_value: Optional[Cache] = None,  # Add this line
        output_attentions: Optional[bool] = False,  # Add this line
        use_cache: Optional[bool] = False,  # Add this line
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        bsz, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Added to repeat key and value states (these two lines can be removed and it still works)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Adjust attention_mask shape if necessary
        if attention_mask is not None and attention_mask.shape[-1] != key_states.shape[-2]:
            attention_mask = attention_mask[:, :, :, :key_states.shape[-2]]

        # Debugging: print shapes
        print(f"query_states shape: {query_states.shape}")
        print(f"key_states shape: {key_states.shape}")
        print(f"value_states shape: {value_states.shape}")
        if attention_mask is not None:
            print(f"attention_mask shape: {attention_mask.shape}")

        if no_memory_update:
            memory_output = None
        else:
            memory_output = self._retrieve_from_memory(query_states, memory, norm_term)

        if not no_memory_update:
            updated_memory, updated_norm_term = self._update_memory(key_states, value_states, memory, norm_term)
            memory = updated_memory.detach()
            norm_term = updated_norm_term.detach()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        if memory_output is None:
            combined_output = attn_output
        else:
            # sigmoid(gate) starts near 0 (gate is initialized to -100), so the memory
            # branch contributes almost nothing until the gate is trained.
            combined_output = F.sigmoid(self.gate) * memory_output + (1 - F.sigmoid(self.gate)) * attn_output

        combined_output = combined_output.transpose(1, 2).contiguous()
        combined_output = combined_output.view(bsz, seq_len, self.hidden_size)
        final_output = self.o_proj(combined_output)

        if no_memory_update:
            memory = None
            norm_term = None

        # Ensure the return statement provides five values
        return final_output, None, None, memory, norm_term
```
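For reference, here is a minimal sketch of the segment-by-segment loop this kind of infini-attention forward is meant to be driven with. Everything in it is an assumption on my side, not code from the repo: in particular, `outputs.memory` and `outputs.norm_term` assume the custom output class exposes the updated state under those names, and `segment_len` is just an illustrative chunk size.

```python
import torch

def run_in_segments(model, tokenizer, text, segment_len=2048):
    """Feed `text` through the model chunk by chunk, carrying the compressive
    memory and norm term across segments instead of a growing KV cache.
    `outputs.memory` / `outputs.norm_term` are assumed attribute names."""
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    memory, norm_term = None, None
    outputs = None
    for start in range(0, input_ids.shape[1], segment_len):
        chunk = input_ids[:, start : start + segment_len]
        with torch.no_grad():
            outputs = model(
                input_ids=chunk,
                memory=memory,
                norm_term=norm_term,
                use_cache=False,
            )
        # Hand the updated compressive memory to the next segment.
        memory, norm_term = outputs.memory, outputs.norm_term
    return outputs, memory, norm_term
```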
The following are my outcomes: