# SCAR / modeling_llama3_SAE.py
from typing import List, Optional, Tuple, Union, Callable, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from configuration_llama3_SAE import LLama3_SAE_Config
except ImportError:
    from .configuration_llama3_SAE import LLama3_SAE_Config
from transformers import (
LlamaPreTrainedModel,
LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LLama3_SAE(LlamaPreTrainedModel):
config_class = LLama3_SAE_Config
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config: LLama3_SAE_Config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Choose the SAE activation from the config; the TopK variants optionally
        # apply a post-activation non-linearity to the surviving values.
        if config.activation == "topk":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k))
            else:
                activation = TopK(config.activation_k)
        elif config.activation == "topk-tanh":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Tanh())
            else:
                activation = TopK(config.activation_k, nn.Tanh())
        elif config.activation == "topk-sigmoid":
            if isinstance(config.activation_k, int):
                activation = TopK(torch.tensor(config.activation_k), nn.Sigmoid())
            else:
                activation = TopK(config.activation_k, nn.Sigmoid())
        elif config.activation == "jumprelu":
            activation = JumpReLu()
        elif config.activation == "relu":
            activation = ACTIVATIONS_CLASSES["ReLU"]()
        elif config.activation == "identity":
            activation = ACTIVATIONS_CLASSES["Identity"]()
        else:
            raise NotImplementedError(
                f"Activation '{config.activation}' not implemented."
            )
self.SAE = Autoencoder(
n_inputs=config.n_inputs,
n_latents=config.n_latents,
activation=activation,
tied=False,
normalize=True,
)
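        # Register the SAE as a forward hook on the configured decoder block so
        # that the block output is replaced by the (optionally modified) SAE
        # reconstruction during the forward pass.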
self.hook = HookedTransformer_with_SAE_suppresion(
block=config.hook_block_num,
sae=self.SAE,
mod_features=config.mod_features,
mod_threshold=config.mod_threshold,
mod_replacement=config.mod_replacement,
mod_scaling=config.mod_scaling,
).register_with(self.model, config.site)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(
self.vocab_size // self.config.pretraining_tp, dim=0
)
logits = [
F.linear(hidden_states, lm_head_slices[i])
for i in range(self.config.pretraining_tp)
]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss(reduction="none")
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
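            # Per-sequence loss: positions ignored by the loss (label -100) yield
            # zero loss entries, so average only over the non-zero positions.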
loss = loss.view(logits.size(0), -1)
mask = loss != 0
loss = loss.sum(dim=-1) / mask.sum(dim=-1)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
use_cache=True,
**kwargs,
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = (
cache_position[0]
if cache_position is not None
else past_key_values.get_seq_length()
)
max_cache_length = (
torch.tensor(
past_key_values.get_max_length(), device=input_ids.device
)
if past_key_values.get_max_length() is not None
else None
)
cache_length = (
past_length
if max_cache_length is None
else torch.min(max_cache_length, past_length)
)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
if (
attention_mask is not None
and attention_mask.shape[1] > input_ids.shape[1]
):
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
# recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
# TODO: use `next_tokens` directly instead.
model_inputs = {"input_ids": input_ids.contiguous()}
input_length = (
position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
)
if cache_position is None:
cache_position = torch.arange(
past_length, past_length + input_length, device=input_ids.device
)
elif use_cache:
cache_position = cache_position[-input_length:]
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
),
)
return reordered_past
def LN(
x: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
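    """Layer normalization helper used by the sparse autoencoder.

    Returns the normalized input together with the per-sample mean and std so
    that `Autoencoder.decode` can map reconstructions back to the input scale.
    """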
mu = x.mean(dim=-1, keepdim=True)
x = x - mu
std = x.std(dim=-1, keepdim=True)
x = x / (std + eps)
return x, mu, std
class Autoencoder(nn.Module):
"""Sparse autoencoder
Implements:
latents = activation(encoder(x - pre_bias) + latent_bias)
recons = decoder(latents) + pre_bias
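
    Example (shape sketch, assuming a TopK activation):
        sae = Autoencoder(n_latents=8, n_inputs=4, activation=TopK(2))
        pre_act, latents, recons = sae(torch.randn(3, 4))
        # pre_act, latents: [3, 8]; recons: [3, 4]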
"""
def __init__(
self,
n_latents: int,
n_inputs: int,
activation: Callable = nn.ReLU(),
tied: bool = False,
normalize: bool = False,
) -> None:
"""
:param n_latents: dimension of the autoencoder latent
:param n_inputs: dimensionality of the original data (e.g residual stream, number of MLP hidden units)
:param activation: activation function
:param tied: whether to tie the encoder and decoder weights
"""
super().__init__()
self.n_inputs = n_inputs
self.n_latents = n_latents
self.pre_bias = nn.Parameter(torch.zeros(n_inputs))
self.encoder: nn.Module = nn.Linear(n_inputs, n_latents, bias=False)
self.latent_bias = nn.Parameter(torch.zeros(n_latents))
self.activation = activation
if isinstance(activation, JumpReLu):
self.threshold = nn.Parameter(torch.empty(n_latents))
torch.nn.init.constant_(self.threshold, 0.001)
self.forward = self.forward_jumprelu
elif isinstance(activation, TopK):
self.forward = self.forward_topk
else:
logger.warning(
f"Using TopK forward function even if activation is not TopK, but is {activation}"
)
self.forward = self.forward_topk
if tied:
# self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
self.decoder.weight.data = self.encoder.weight.data.T.clone()
else:
self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
self.normalize = normalize
def encode_pre_act(
self, x: torch.Tensor, latent_slice: slice = slice(None)
) -> torch.Tensor:
"""
:param x: input data (shape: [batch, n_inputs])
:param latent_slice: slice of latents to compute
Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
:return: autoencoder latents before activation (shape: [batch, n_latents])
"""
x = x - self.pre_bias
latents_pre_act = F.linear(
x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
)
return latents_pre_act
def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
if not self.normalize:
return x, dict()
x, mu, std = LN(x)
return x, dict(mu=mu, std=std)
def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]:
"""
:param x: input data (shape: [batch, n_inputs])
:return: autoencoder latents (shape: [batch, n_latents])
"""
x, info = self.preprocess(x)
return self.activation(self.encode_pre_act(x)), info
def decode(
self, latents: torch.Tensor, info: dict[str, Any] | None = None
) -> torch.Tensor:
"""
:param latents: autoencoder latents (shape: [batch, n_latents])
:return: reconstructed data (shape: [batch, n_inputs])
"""
ret = self.decoder(latents) + self.pre_bias
if self.normalize:
assert info is not None
ret = ret * info["std"] + info["mu"]
return ret
def forward_topk(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param x: input data (shape: [batch, n_inputs])
:return: autoencoder latents pre activation (shape: [batch, n_latents])
autoencoder latents (shape: [batch, n_latents])
reconstructed data (shape: [batch, n_inputs])
"""
x, info = self.preprocess(x)
latents_pre_act = self.encode_pre_act(x)
latents = self.activation(latents_pre_act)
recons = self.decode(latents, info)
return latents_pre_act, latents, recons
def forward_jumprelu(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param x: input data (shape: [batch, n_inputs])
:return: autoencoder latents pre activation (shape: [batch, n_latents])
autoencoder latents (shape: [batch, n_latents])
reconstructed data (shape: [batch, n_inputs])
"""
x, info = self.preprocess(x)
latents_pre_act = self.encode_pre_act(x)
latents = self.activation(F.relu(latents_pre_act), torch.exp(self.threshold))
recons = self.decode(latents, info)
return latents_pre_act, latents, recons
class TiedTranspose(nn.Module):
def __init__(self, linear: nn.Linear):
super().__init__()
self.linear = linear
def forward(self, x: torch.Tensor) -> torch.Tensor:
assert self.linear.bias is None
# torch.nn.parameter.Parameter(layer_e.weights.T)
return F.linear(x, self.linear.weight.t(), None)
@property
def weight(self) -> torch.Tensor:
return self.linear.weight.t()
@property
def bias(self) -> torch.Tensor:
return self.linear.bias
class TopK(nn.Module):
def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
super().__init__()
self.k = k
self.postact_fn = postact_fn
def forward(self, x: torch.Tensor) -> torch.Tensor:
topk = torch.topk(x, k=self.k, dim=-1)
values = self.postact_fn(topk.values)
# make all other values 0
result = torch.zeros_like(x)
result.scatter_(-1, topk.indices, values)
return result
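# Illustration: TopK(2)(torch.tensor([[1.0, 3.0, 2.0, 5.0]])) keeps only the two
# largest pre-activations and zeroes the rest -> tensor([[0.0, 3.0, 0.0, 5.0]]).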
class JumpReLu(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, threshold):
return JumpReLUFunction.apply(input, threshold)
class HeavyStep(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, threshold):
return HeavyStepFunction.apply(input, threshold)
def rectangle(x):
return (x > -0.5) & (x < 0.5)
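# `rectangle` is the window used by the straight-through estimators below: it is
# non-zero only within +/- bandwidth/2 of the threshold (after rescaling by the
# bandwidth), which is where the threshold parameter receives a gradient.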
class JumpReLUFunction(torch.autograd.Function):
@staticmethod
def forward(input, threshold):
output = input * (input > threshold)
return output
@staticmethod
def setup_context(ctx, inputs, output):
input, threshold = inputs
ctx.save_for_backward(input, threshold)
@staticmethod
def backward(ctx, grad_output):
        bandwidth = 0.001
        # bandwidth = 0.0001
        input, threshold = ctx.saved_tensors
        # Straight-through estimator: the input gradient passes through only
        # where the unit is active; the threshold gradient uses a rectangular
        # window of width `bandwidth` around the threshold.
        grad_input = (input > threshold) * grad_output
        grad_threshold = (
            -(threshold / bandwidth)
            * rectangle((input - threshold) / bandwidth)
            * grad_output
        )
        return grad_input, grad_threshold
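# Illustration: JumpReLUFunction.apply(torch.tensor([0.1, 0.5]), torch.tensor(0.3))
# zeroes the entry below the threshold and passes the other through unchanged,
# giving tensor([0.0, 0.5]).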
class HeavyStepFunction(torch.autograd.Function):
@staticmethod
    def forward(input, threshold):
        # Heaviside step: 1 where the input exceeds the threshold, 0 otherwise
        # (consistent with the zero input-gradient straight-through estimator below).
        output = (input > threshold).to(input.dtype)
        return output
@staticmethod
def setup_context(ctx, inputs, output):
input, threshold = inputs
ctx.save_for_backward(input, threshold)
@staticmethod
def backward(ctx, grad_output):
bandwidth = 0.001
# bandwidth = 0.0001
input, threshold = ctx.saved_tensors
grad_input = grad_threshold = None
grad_input = torch.zeros_like(input)
grad_threshold = (
-(1.0 / bandwidth)
* rectangle((input - threshold) / bandwidth)
* grad_output
)
return grad_input, grad_threshold
ACTIVATIONS_CLASSES = {
"ReLU": nn.ReLU,
"Identity": nn.Identity,
"TopK": TopK,
"JumpReLU": JumpReLu,
}
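# Registry mapping the activation names accepted in the config ("relu",
# "identity") to their module classes; `LLama3_SAE.__init__` instantiates
# these entries for the corresponding config values.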
class HookedTransformer_with_SAE:
"""Auxilliary class used to extract mlp activations from transformer models."""
def __init__(self, block: int, sae) -> None:
self.block = block
self.sae = sae
self.remove_handle = (
None # Can be used to remove this hook from the model again
)
self._features = None
def register_with(self, model):
# At the moment only activations from Feed Forward MLP layer
self.remove_handle = model.layers[self.block].mlp.register_forward_hook(self)
return self
def pop(self) -> torch.Tensor:
"""Remove and return extracted feature from this hook.
        Features are only exposed through this method so that no lingering references to them are kept.
"""
assert self._features is not None, "Feature extractor was not called yet!"
features = self._features
self._features = None
return features
def __call__(self, module, inp, outp) -> None:
self._features = outp
return self.sae(outp)[2]
class HookedTransformer_with_SAE_suppresion:
"""Auxilliary class used to extract mlp activations from transformer models."""
def __init__(
self,
block: int,
sae: Autoencoder,
mod_features: list = None,
mod_threshold: list = None,
mod_replacement: list = None,
mod_scaling: list = None,
mod_balance: bool = False,
multi_feature: bool = False,
) -> None:
self.block = block
self.sae = sae
self.remove_handle = (
None # Can be used to remove this hook from the model again
)
self._features = None
self.mod_features = mod_features
self.mod_threshold = mod_threshold
self.mod_replacement = mod_replacement
self.mod_scaling = mod_scaling
self.mod_balance = mod_balance
self.mod_vector = None
self.mod_vec_factor = 1.0
if multi_feature:
self.modify = self.modify_list
else:
self.modify = self.modify_single
if isinstance(self.sae.activation, JumpReLu):
logger.info("Setting __call__ function for JumpReLU.")
setattr(self, "call", self.__call__jumprelu)
elif isinstance(self.sae.activation, TopK):
logger.info("Setting __call__ function for TopK.")
setattr(self, "call", self.__call__topk)
else:
logger.warning(
f"Using TopK forward function even if activation is not TopK, but is {self.sae.activation}"
)
setattr(self, "call", self.__call__topk)
def register_with(self, model, site="mlp"):
self.site = site
# Decision on where to extract activations from
if site == "mlp": # output of the FF module of block
self.remove_handle = model.layers[self.block].mlp.register_forward_hook(
self
)
elif (
site == "block"
): # output of the residual connection AFTER it is added to the FF output
self.remove_handle = model.layers[self.block].register_forward_hook(self)
elif site == "attention":
raise NotImplementedError
else:
raise NotImplementedError
# self.remove_handle = model.model.layers[self.block].mlp.act_fn.register_forward_hook(self)
return self
def modify_list(self, latents: torch.Tensor) -> torch.Tensor:
if self.mod_replacement is not None:
for feat, thresh, mod in zip(
self.mod_features, self.mod_threshold, self.mod_replacement
):
latents[:, :, feat][latents[:, :, feat] > thresh] = mod
elif self.mod_scaling is not None:
for feat, thresh, mod in zip(
self.mod_features, self.mod_threshold, self.mod_scaling
):
latents[:, :, feat][latents[:, :, feat] > thresh] *= mod
elif self.mod_vector is not None:
latents = latents + self.mod_vec_factor * self.mod_vector
else:
pass
return latents
def modify_single(self, latents: torch.Tensor) -> torch.Tensor:
old_cond_feats = latents[:, :, self.mod_features]
if self.mod_replacement is not None:
# latents[:, :, self.mod_features][
# latents[:, :, self.mod_features] > self.mod_threshold
# ] = self.mod_replacement
latents[:, :, self.mod_features] = self.mod_replacement
elif self.mod_scaling is not None:
latents_scaled = latents.clone()
latents_scaled[:, :, self.mod_features][
latents[:, :, self.mod_features] > 0
] *= self.mod_scaling
latents_scaled[:, :, self.mod_features][
latents[:, :, self.mod_features] < 0
] *= -1 * self.mod_scaling
latents = latents_scaled
# latents[:, :, self.mod_features] *= self.mod_scaling
elif self.mod_vector is not None:
latents = latents + self.mod_vec_factor * self.mod_vector
else:
pass
if self.mod_balance:
# logger.warning("The balancing does not work yet!!!")
# TODO: Look into it more closely, not sure if this is correct
num_feat = latents.shape[2] - 1
diff = old_cond_feats - latents[:, :, self.mod_features]
if self.mod_features != 0:
latents[:, :, : self.mod_features] += (diff / num_feat)[:, :, None]
latents[:, :, self.mod_features + 1 :] += (diff / num_feat)[:, :, None]
return latents
def pop(self) -> torch.Tensor:
"""Remove and return extracted feature from this hook.
        Features are only exposed through this method so that no lingering references to them are kept.
"""
assert self._features is not None, "Feature extractor was not called yet!"
if isinstance(self._features, tuple):
features = self._features[0]
else:
features = self._features
self._features = None
return features
def __call__topk(self, module, inp, outp) -> torch.Tensor:
self._features = outp
if isinstance(self._features, tuple):
features = self._features[0]
else:
features = self._features
if self.mod_features is None:
recons = features
else:
x, info = self.sae.preprocess(features)
latents_pre_act = self.sae.encode_pre_act(x)
latents = self.sae.activation(latents_pre_act)
# latents[:, :, self.mod_features] = F.sigmoid(
# latents_pre_act[:, :, self.mod_features]
# )
# latents[:, :, self.mod_features] = torch.abs(latents_pre_act[:, :, self.mod_features])
# latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]
mod_latents = self.modify(latents)
# mod_latents[:, :, self.mod_features] = F.sigmoid(
# mod_latents[:, :, self.mod_features]
# )
recons = self.sae.decode(mod_latents, info)
if isinstance(self._features, tuple):
outp = list(outp)
outp[0] = recons
return tuple(outp)
else:
return recons
def __call__jumprelu(self, module, inp, outp) -> torch.Tensor:
self._features = outp
if self.mod_features is None:
recons = outp
else:
x, info = self.sae.preprocess(outp)
latents_pre_act = self.sae.encode_pre_act(x)
latents = self.sae.activation(
F.relu(latents_pre_act), torch.exp(self.sae.threshold)
)
latents[:, :, self.mod_features] = latents_pre_act[:, :, self.mod_features]
mod_latents = self.modify(latents)
recons = self.sae.decode(mod_latents, info)
return recons
def __call__(self, module, inp, outp) -> torch.Tensor:
return self.call(module, inp, outp)
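

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model code): the repository
# id below is a hypothetical placeholder, and steering via `mod_scaling`
# assumes the checkpoint's config sets `mod_features` for the hook.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo_id = "<user>/<SCAR-checkpoint>"  # placeholder, replace with a real repo id
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

    # Rescale the hooked SAE feature(s) at inference time.
    model.hook.mod_scaling = 2.0

    inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))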