Can't generate decent text out of it

#7
by useless-ai - opened

Am I missing something, or does this upload have some issues? It's not generating any text that makes sense.

You have to use the implementation from https://github.com/mustafaaljadery/gemma-2B-10M
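
In case it helps, here is a rough sketch of how I'd expect code from that repo to be wired up, assuming its src/ exposes a patched GemmaForCausalLM. The import path, local model directory, and dtype below are assumptions, not taken from the repo:

import torch
from transformers import AutoTokenizer

# Assumption: the repo's src/ provides a modified GemmaForCausalLM with infini-attention.
# The import path is hypothetical and depends on how you check out the repo.
from src.gemma import GemmaForCausalLM

model_path = "./models/gemma-2b-10m"  # hypothetical local path to the downloaded weights

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
model.eval()

prompt = "Summarize the following document:"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))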

You have to use the implementation from https://github.com/mustafaaljadery/gemma-2B-10M

Out of curiosity, have you been able to run the current repo implementation under ./src?
If so, did you modify it?
Currently, on consumer hardware, a few people including me are getting TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'

Same here, using the model from the repo gives the cache_position error

Any chance of the cache_position error getting fixed?

TypeError: GemmaModel.forward() got an unexpected keyword argument 'cache_position'. I run the code and get the same error. Does anyone else have this problem?
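
For what it's worth, the error just means that something in transformers' generation path is passing a cache_position keyword that the custom GemmaModel.forward in ./src does not declare. A tiny self-contained illustration of the failure mode and the usual signature-level workaround (toy modules, not the actual Gemma code):

from typing import Optional
import torch
import torch.nn as nn

class OldForward(nn.Module):
    # Mimics a forward() that does not know about cache_position.
    def forward(self, input_ids: torch.LongTensor):
        return input_ids

class PatchedForward(nn.Module):
    # Same forward(), but the extra keyword is accepted and ignored.
    def forward(
        self,
        input_ids: torch.LongTensor,
        cache_position: Optional[torch.LongTensor] = None,  # accepted and ignored
        **kwargs,  # swallow any other keywords a newer caller may pass
    ):
        return input_ids

x = torch.arange(4).unsqueeze(0)
try:
    OldForward()(x, cache_position=torch.arange(4))
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'cache_position'
print(PatchedForward()(x, cache_position=torch.arange(4)))  # works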

I added some code with the help of Cursor. It runs now, but the performance is bad: it only generates meaningless text. I don't know whether that is caused by faulty AI-generated code or by the infini-transformer code itself. I left # comments where I made changes.

class GemmaModel(GemmaPreTrainedModel):
    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                GemmaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Union[Tuple, InfiniBaseModelOutputWithPast]:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache and isinstance(past_key_values, StaticCache):
            past_seen_tokens = past_key_values.get_seq_length()

        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0) if position_ids is None else position_ids
        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens + inputs_embeds.shape[1])

        hidden_states = inputs_embeds * torch.tensor(self.config.hidden_size**0.5, dtype=inputs_embeds.dtype)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        next_decoder_cache = None  # Initialize next_decoder_cache
class GemmaInfiniAttention(GemmaAttention):
    def __init__(
        self,
        config: GemmaConfig,
        layer_idx: Optional[int] = None,
    ):
        super().__init__(config, layer_idx)
        self.gate = nn.Parameter(torch.full((1, self.num_heads, 1, 1), -100.0))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        memory: Optional[torch.Tensor] = None,
        norm_term: Optional[torch.Tensor] = None,
        no_memory_update: bool = False,
        past_key_value: Optional[Cache] = None,  # Add this line
        output_attentions: Optional[bool] = False,  # Add this line
        use_cache: Optional[bool] = False,  # Add this line
        cache_position: Optional[torch.LongTensor] = None,  # Add this line
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:  # output, attn weights, past_key_value, memory, norm_term
        bsz, seq_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Added: repeat the key/value states so they match the number of query heads. These lines can be removed and it still works.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Adjust attention_mask shape if necessary
        if attention_mask is not None and attention_mask.shape[-1] != key_states.shape[-2]:
            attention_mask = attention_mask[:, :, :, :key_states.shape[-2]]

        # Debugging: Print shapes
        print(f"query_states shape: {query_states.shape}")
        print(f"key_states shape: {key_states.shape}")
        print(f"value_states shape: {value_states.shape}")
        if attention_mask is not None:
            print(f"attention_mask shape: {attention_mask.shape}")

        if no_memory_update:
            memory_output = None
        else:
            memory_output = self._retrieve_from_memory(query_states, memory, norm_term)

        if not no_memory_update:
            updated_memory, updated_norm_term = self._update_memory(key_states, value_states, memory, norm_term)
            memory = updated_memory.detach()
            norm_term = updated_norm_term.detach()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
        )

        if memory_output is None:
            combined_output = attn_output
        else:
            combined_output = F.sigmoid(self.gate) * memory_output + (1 - F.sigmoid(self.gate)) * attn_output

        combined_output = combined_output.transpose(1, 2).contiguous()
        combined_output = combined_output.view(bsz, seq_len, self.hidden_size)

        final_output = self.o_proj(combined_output)

        if no_memory_update:
            memory = None
            norm_term = None

        # Ensure the return statement provides five values
        return final_output, None, None, memory, norm_term     
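
If it helps anyone debugging the meaningless output, here is roughly how a forward like this would be driven over a long input: feed it in segments and carry memory / norm_term across calls. This is only a sketch; it assumes the modified GemmaModel returns the updated memory and norm_term on its InfiniBaseModelOutputWithPast, and the segment length is an arbitrary choice:

import torch

SEGMENT_LEN = 2048  # assumption: chunk size used to stream a long context through the model

@torch.no_grad()
def run_in_segments(model, input_ids):
    """Feed a long token sequence through the modified GemmaModel segment by segment,
    threading the compressive memory state between calls."""
    memory, norm_term = None, None
    last_hidden = None
    for start in range(0, input_ids.shape[1], SEGMENT_LEN):
        segment = input_ids[:, start:start + SEGMENT_LEN]
        out = model(
            input_ids=segment,
            memory=memory,        # memory retrieved inside GemmaInfiniAttention
            norm_term=norm_term,  # running normalization term for the memory
            use_cache=False,
        )
        # Assumption: InfiniBaseModelOutputWithPast exposes the updated state like this.
        memory, norm_term = out.memory, out.norm_term
        last_hidden = out.last_hidden_state
    return last_hidden, memory, norm_term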

The following screenshots show my outputs:
[attached screenshots of the generated text]
