Runtime autograd error due to inplace operations

#4
by xianbin - opened

Error

The following error was encountered while fine-tuning Gemma2 models with TRL:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [CUDABFloat16Type [1, 308, 256000]], which is output 0 of TanhBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Cause

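This was traced to in-place operations in the Gemma2 model definition that modify tensors needed for gradient computation. For example, torch.tanh saves its output for the backward pass (the derivative of tanh(x) is 1 - tanh(x)^2), so scaling that output afterwards with mul_ increments its version counter. That is exactly what the "output 0 of TanhBackward0, is at version 1; expected version 0" message reports.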
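
The failure can be reproduced outside Gemma2 with a few lines of PyTorch. This is a minimal sketch on toy tensor sizes: the helper names (softcap_inplace, softcap_out_of_place, try_backward) and the cap of 30.0 are illustrative, not taken from the model code. The in-place variant mirrors the original div_ / tanh / mul_ pattern, and the out-of-place variant mirrors the suggested replacement.

```python
import torch


def softcap_inplace(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Original pattern: div_ -> tanh -> mul_.
    logits.div_(cap)
    logits = torch.tanh(logits)  # TanhBackward0 saves this output for backward
    logits.mul_(cap)             # overwrites the saved output in place
    return logits


def softcap_out_of_place(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Suggested pattern: every step allocates a new tensor.
    logits = torch.div(logits, cap)
    logits = torch.tanh(logits)
    logits = torch.mul(logits, cap)
    return logits


def try_backward(softcap) -> str:
    torch.manual_seed(0)
    weight = torch.randn(4, 8, requires_grad=True)
    hidden = torch.randn(2, 4)
    # A non-leaf tensor, like real logits, so the first div_ is permitted.
    logits = hidden @ weight
    out = softcap(logits, 30.0)
    try:
        out.sum().backward()
        return "ok"
    except RuntimeError as err:
        return str(err)


if __name__ == "__main__":
    print("in-place:    ", try_backward(softcap_inplace))
    print("out-of-place:", try_backward(softcap_out_of_place))
```

The in-place run should fail in backward() with the same "modified by an inplace operation ... TanhBackward0" message, while the out-of-place run completes.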

Possible solution

The following lines of code should be modified in diff_gemma2.py (and, by extension, modeling_gemma2.py):

Lines 163-165:

            attention_mask *= torch.tril(
                torch.ones_like(attention_mask),
                diagonal=(self.sliding_window - cache_position[-1]),
            )

Replacement:

            attention_mask = torch.mul(
                attention_mask,
                torch.tril(
                    torch.ones_like(attention_mask),
                    diagonal=(self.sliding_window - cache_position[-1]),
                ),
            )
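
A small check that the torch.mul form produces the same mask values as the in-place `*=` while leaving the input tensor untouched. This is a sketch under stated assumptions: sliding_window = 3 and last_position = 4 (a stand-in for cache_position[-1]) are hypothetical toy values, the mask is random rather than a real attention mask, and Tensor._version is an internal PyTorch attribute used here only to show which operation mutates its input.

```python
import torch

sliding_window = 3   # hypothetical toy value
last_position = 4    # hypothetical stand-in for cache_position[-1]
diagonal = sliding_window - last_position

torch.manual_seed(0)
attention_mask = torch.randn(5, 5)
version_before = attention_mask._version  # internal counter, illustration only

# Suggested form: torch.mul returns a new tensor; attention_mask is not modified.
tril_mask = torch.tril(torch.ones_like(attention_mask), diagonal=diagonal)
out_of_place = torch.mul(attention_mask, tril_mask)

# Original form, applied to a copy so the two results can be compared.
in_place = attention_mask.clone()
in_place_version_before = in_place._version
in_place *= tril_mask  # mutates in_place and bumps its version counter

if __name__ == "__main__":
    print("same values:", torch.equal(out_of_place, in_place))
    print("input version unchanged:", attention_mask._version == version_before)
    print("in-place version bumped:", in_place._version > in_place_version_before)
```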

Lines 119-121:

            attn_weights.div_(self.config.attn_logit_softcapping)
            attn_weights = torch.tanh(attn_weights)
            attn_weights.mul_(self.config.attn_logit_softcapping)

Replacement:

            attn_weights = torch.div(attn_weights, self.config.attn_logit_softcapping)
            attn_weights = self.attn_weights_tanh(attn_weights)
            attn_weights = torch.mul(attn_weights, self.config.attn_logit_softcapping)

Add this to the __init__ of Gemma2Attention:

            self.attn_weights_tanh = nn.Tanh()
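
To sanity-check the suggested attention soft-capping, here is a toy module that mirrors the replacement lines above and compares the autograd gradient of cap * tanh(x / cap) against its analytic form, 1 - tanh(x / cap)^2. SoftCappedScores is a hypothetical class written for this illustration (it is not Gemma2Attention), and the cap of 50.0 is an illustrative value.

```python
import torch
from torch import nn


class SoftCappedScores(nn.Module):
    """Hypothetical toy module mirroring the suggested Gemma2Attention change."""

    def __init__(self, softcap: float):
        super().__init__()
        self.softcap = softcap
        self.attn_weights_tanh = nn.Tanh()

    def forward(self, attn_weights: torch.Tensor) -> torch.Tensor:
        # All three steps are out-of-place, so the tanh output saved for
        # backward is never overwritten.
        attn_weights = torch.div(attn_weights, self.softcap)
        attn_weights = self.attn_weights_tanh(attn_weights)
        return torch.mul(attn_weights, self.softcap)


cap = 50.0  # illustrative cap
layer = SoftCappedScores(cap)

torch.manual_seed(0)
# Leaf tensor of raw scores, scaled up so some of them exceed the cap.
raw_scores = (torch.randn(2, 4, 4) * 100).requires_grad_()
capped = layer(raw_scores)
capped.sum().backward()

# Analytic gradient of cap * tanh(x / cap) with respect to x.
expected_grad = 1 - torch.tanh(raw_scores.detach() / cap) ** 2

if __name__ == "__main__":
    print("max |capped|:", capped.abs().max().item())
    print("grad matches analytic:", torch.allclose(raw_scores.grad, expected_grad))
```

Every output stays within [-cap, cap], and backward() runs without the version-counter error.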

Lines 202-204:

            logits.div_(self.config.final_logit_softcapping)
            logits = torch.tanh(logits)
            logits.mul_(self.config.final_logit_softcapping)

Replacement:

            logits = torch.div(logits, self.config.final_logit_softcapping)
            logits = self.final_logit_tanh(logits)
            logits = torch.mul(logits, self.config.final_logit_softcapping)

Add this to the __init__ of Gemma2ForCausalLM:

            self.final_logit_tanh = nn.Tanh()
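
A sketch suggesting that the module form nn.Tanh() and the functional torch.tanh behave identically for the final logit soft-capping. What appears to remove the error is replacing the in-place div_/mul_ with out-of-place torch.div/torch.mul, so either tanh form should work. Assumptions: toy sizes, an illustrative cap of 30.0, and a hypothetical nn.Linear standing in for the LM head.

```python
import torch
from torch import nn

final_logit_softcapping = 30.0   # illustrative value
final_logit_tanh = nn.Tanh()     # module form, as suggested in the post


def cap_with_module(logits: torch.Tensor) -> torch.Tensor:
    logits = torch.div(logits, final_logit_softcapping)
    logits = final_logit_tanh(logits)
    return torch.mul(logits, final_logit_softcapping)


def cap_with_function(logits: torch.Tensor) -> torch.Tensor:
    logits = torch.div(logits, final_logit_softcapping)
    logits = torch.tanh(logits)   # functional form
    return torch.mul(logits, final_logit_softcapping)


torch.manual_seed(0)
hidden = torch.randn(1, 6, 16)            # toy [batch, seq_len, hidden_size]
lm_head = nn.Linear(16, 32, bias=False)   # hypothetical stand-in, toy vocab of 32

weight_grads = []
for cap_fn in (cap_with_module, cap_with_function):
    lm_head.zero_grad()
    logits = cap_fn(lm_head(hidden))
    logits.sum().backward()               # succeeds: nothing saved was mutated
    weight_grads.append(lm_head.weight.grad.clone())

if __name__ == "__main__":
    print("same gradients:", torch.allclose(weight_grads[0], weight_grads[1]))
```
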

Google org

Yes, will fix this in a bit!
