Clarification on Output Neuron Pruning Method in "Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time"
Hello,
I am attempting to replicate the findings of "Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time" and have some questions about how sparsity is predicted and applied in the MLP layers of LLMs, specifically for models such as Llama 2 7B.
Both the "Deja Vu" paper and subsequent works such as "ProSparse: Introducing and Enhancing Intrinsic Activation Sparsity within Large Language Models" propose using two small, low-rank layers as a predictor of which output neurons of each large MLP layer will be active. These approaches suggest replacing the usual activation functions (SiLU or GeLU) with ReLU and then applying the Deja Vu method for sparsity prediction.
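For readers unfamiliar with the predictor, here is a minimal PyTorch sketch of a two-layer, low-rank predictor of the kind described above. The class name `LowRankPredictor`, the bottleneck width `rank`, and the choice to place no nonlinearity between the two layers are assumptions for illustration, not taken from either paper's code. The only sizes taken from a real model are Llama 2 7B's hidden size (4096) and MLP intermediate size (11008).

```python
import torch
import torch.nn as nn


class LowRankPredictor(nn.Module):
    """Hypothetical two-layer, low-rank predictor of MLP neuron activity.

    Maps each hidden state (d_model) to one logit per MLP neuron (d_ff).
    `rank` is an assumed bottleneck width; this sketch places no
    nonlinearity between the two layers (also an assumption).
    """

    def __init__(self, d_model: int, d_ff: int, rank: int):
        super().__init__()
        self.down = nn.Linear(d_model, rank, bias=False)  # d_model -> rank
        self.up = nn.Linear(rank, d_ff, bias=False)       # rank -> d_ff

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # One score per MLP neuron; higher means "more likely to be active".
        return self.up(self.down(hidden_states))


# Llama 2 7B sizes: d_model = 4096, d_ff = 11008 (rank = 1024 is a guess).
predictor = LowRankPredictor(d_model=4096, d_ff=11008, rank=1024)
logits = predictor(torch.randn(2, 4096))
print(logits.shape)  # torch.Size([2, 11008])
```

Because the bottleneck is narrow, the predictor is much cheaper than the MLP it gates: with `rank=1024` it has about 15.5M weights, versus roughly 135M in one Llama 2 7B gated MLP layer (3 × 4096 × 11008).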
However, it is unclear how the predictor's output is used to decide which output neurons to prune. Should the top-k indices of the predictor output be treated as the active neurons, or should a threshold be applied, so that any predictor output below zero means the corresponding output of the real MLP + ReLU is assumed to be zero?
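The two candidate rules can be written side by side. This is a minimal PyTorch sketch; `threshold_mask` and `topk_mask` are hypothetical helper names, and `threshold=0.0` on raw predictor logits is the assumed reading of "below zero". The key practical difference is that a threshold lets the number of active neurons vary per token, while top-k always keeps exactly `k`.

```python
import torch


def threshold_mask(logits: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    """Active where the predictor logit is >= threshold (count varies per token)."""
    return logits >= threshold


def topk_mask(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Active for exactly the k highest-scoring neurons of each token."""
    _, top_idx = logits.topk(k, dim=-1)
    # Same scatter approach as the Deja Vu modeling snippet in this post, cast to bool.
    return torch.zeros_like(logits).scatter(-1, top_idx, 1.0).bool()


logits = torch.tensor([[2.0, -1.0, 0.5, -3.0],
                       [-0.2, -0.1, -4.0, 1.5]])
print(threshold_mask(logits).sum(dim=-1))  # tensor([2, 1])
print(topk_mask(logits, k=2).sum(dim=-1))  # tensor([2, 2])
```

In the second row, top-k keeps the neuron with logit -0.1 even though the threshold rule treats it as inactive; conversely, a token with more than `k` positive logits would lose some of them under top-k.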
Code Snippet from Deja Vu's Training:
...
x, y = batch
y = y.float().to(device)
logits = model(x.to(device))
probs = logits.sigmoid()
preds = probs >= 0.5
dif = y.int() - preds.int()
miss = dif > 0.0 # classifier didn't activate target neuron
weight = (y.sum() / y.numel()) + 0.005
loss_weight = y * (1 - weight) + weight
eval["Loss Weight"] += [weight.item()]
eval["Loss"] += [
    torch.nn.functional.binary_cross_entropy(probs, y, weight=loss_weight).item()
]
...
This code binarizes the predictor's sigmoid output at 0.5 (`preds = probs >= 0.5`), which suggests a threshold-based approach to neuron pruning.
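Two details of this training snippet are worth making explicit. First, `probs >= 0.5` is mathematically the same decision as `logits >= 0`, because the sigmoid equals 0.5 exactly at zero, so the 0.5 cutoff here and the "below zero" rule describe the same threshold. Second, `loss_weight` gives every truly active neuron a weight of 1 and every inactive neuron a weight equal to the fraction of active neurons plus 0.005, so missing an active neuron costs more than a false alarm, which biases the predictor toward high recall. The sketch below restates that weighting as standalone functions; the helper names and the assumption that the labels `y` mark neurons whose ReLU output is nonzero are illustrative, not taken from the Deja Vu repository.

```python
import torch
import torch.nn.functional as F


def activity_labels(pre_activation: torch.Tensor) -> torch.Tensor:
    """Assumed labeling: 1.0 where ReLU(pre_activation) is nonzero, else 0.0."""
    return (pre_activation > 0).float()


def weighted_bce(logits: torch.Tensor, y: torch.Tensor, eps: float = 0.005) -> torch.Tensor:
    """Weighted BCE restating the snippet: positives weigh 1, negatives mean(y) + eps."""
    probs = logits.sigmoid()
    weight = y.mean() + eps                  # same as y.sum() / y.numel() + eps
    loss_weight = y * (1 - weight) + weight  # 1 where y == 1, `weight` where y == 0
    return F.binary_cross_entropy(probs, y, weight=loss_weight)


pre = torch.tensor([[1.0, -2.0, 0.3, -0.1]])
y = activity_labels(pre)
print(y)  # tensor([[1., 0., 1., 0.]])
print(weighted_bce(torch.tensor([[3.0, -3.0, 2.0, -2.0]]), y))
```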
Deja Vu's Python modeling code:
...
def prepare_fc_weights(self, hidden_states: torch.Tensor):
    with torch.no_grad():
        self.predictor = self.predictor.float()
        _logit = self.predictor(hidden_states.reshape(-1, self.embed_dim).float())
        _, _top_indices = _logit.topk(self.topk, dim=1)
        _top_k_indices = _top_indices[:, :self.topk]
        self._mask = torch.zeros_like(_logit)
        self._mask = self._mask.scatter(1, _top_k_indices, 1).bool().half()
...
hidden_states = self.fc1(hidden_states)
if self.predictor != None:
    hidden_states = hidden_states * self._mask
...
In contrast, this snippet uses `topk` to select a fixed number of active neurons per token and masks out all the others.
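Whichever rule is used, masking after `fc1` only matters where it removes a neuron that is actually active: zeroing a neuron whose pre-activation is already non-positive does not change the ReLU output. The sketch below checks this on a toy ReLU MLP in PyTorch; `relu_mlp` and `masked_relu_mlp` are hypothetical helpers, the weights are random, and the `(out, in)` weight layout mimics `nn.Linear`. It suggests that the choice between top-k and a threshold comes down to which rule keeps the predictor's recall of truly active neurons high at the lowest cost.

```python
import torch
import torch.nn.functional as F


def relu_mlp(x, w1, w2):
    """Dense reference: fc2(relu(fc1(x))), weights stored as (out, in)."""
    return F.relu(x @ w1.T) @ w2.T


def masked_relu_mlp(x, w1, w2, mask):
    """Zero predicted-inactive neurons after fc1, as in `hidden_states * self._mask`."""
    return F.relu((x @ w1.T) * mask) @ w2.T


torch.manual_seed(0)
x, w1, w2 = torch.randn(4, 8), torch.randn(32, 8), torch.randn(8, 32)
pre = x @ w1.T
oracle = (pre > 0).float()  # a perfect predictor: exactly the active neurons

# Covering every active neuron reproduces the dense output.
print(torch.allclose(masked_relu_mlp(x, w1, w2, oracle), relu_mlp(x, w1, w2)))  # True

# Dropping the strongest active neuron changes it.
row, col = divmod(int(pre.argmax()), pre.shape[1])
missed = oracle.clone()
missed[row, col] = 0.0
print(torch.allclose(masked_relu_mlp(x, w1, w2, missed), relu_mlp(x, w1, w2)))  # False
```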
PowerInfer's Implementation Snippet:
...
float *ffdata = (float *)dst->src[2]->data;
int *gid = (int *)dst->src[3]->data;
float *predictor_data = (float *)dst->src[2]->data;
const size_t predictor_row_size = dst->src[2]->ne[0]*ggml_type_size(GGML_TYPE_F32)/ggml_blck_size(GGML_TYPE_F32);
...
ffdata = (float *)((char *)predictor_data + (i11 + i12*ne11 + i13*ne12*ne11)*predictor_row_size);
float *dst_col = (float *)((char *)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
if (gid[ir0] == 1 || ffdata[ir0] < threshold) {
    dst_col[ir0] = 0;
    continue;
}
vec_dot(ne00, &dst_col[ir0], src0_row + ir0 * nb01, src1_col);
...
This implementation appears to adopt a threshold-based method (rows whose predictor score is below `threshold` are skipped and set to zero), yet it is unclear how it aligns with the methods described in Deja Vu or ProSparse.
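The PowerInfer branch can be read as "do not compute the rows the predictor rules out" rather than "compute everything, then mask". The following PyTorch sketch shows that, for a single token, the two give the same `fc1` output; `sparse_fc1` is a hypothetical helper, `threshold=0.0` is an assumed value, and the `gid` condition from the C snippet is deliberately left out.

```python
import torch


def sparse_fc1(x: torch.Tensor, w1: torch.Tensor, scores: torch.Tensor,
               threshold: float = 0.0) -> torch.Tensor:
    """fc1 for one token, computing only rows whose predictor score >= threshold.

    x: (d_model,), w1: (d_ff, d_model), scores: (d_ff,).
    Skipped rows stay 0, like `dst_col[ir0] = 0; continue;` in the C code.
    """
    out = torch.zeros(w1.shape[0], dtype=x.dtype)
    active = torch.nonzero(scores >= threshold).squeeze(-1)  # row indices to compute
    out[active] = w1[active] @ x  # one smaller matvec over the gathered rows
    return out


torch.manual_seed(0)
x, w1, scores = torch.randn(8), torch.randn(16, 8), torch.randn(16)
dense_then_mask = (w1 @ x) * (scores >= 0)
print(torch.allclose(sparse_fc1(x, w1, scores), dense_then_mask, atol=1e-5))  # True
```

Gathering rows only saves work when few rows pass the threshold; the output is the same either way, so the accuracy question is again about the predictor, not the kernel.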
So how does the predictor work?
Given these observations and my own experiments, in which the top-k method proved effective but I found no clear guidance on choosing "k" because ProSparse does not report these details, I would like clarification on two points:
1. What is the recommended method for deciding which neurons to prune: top-k or threshold-based?
2. How is "k" chosen in practice for the top-k method, given that sparsity levels vary across models and tasks?
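On the second question, one practical heuristic (not a procedure taken from Deja Vu or ProSparse) is to calibrate k per layer: run the model on a small calibration set, record each layer's predictor logits and true ReLU activity, and take the smallest k whose top-k mask recovers a target fraction of the truly active neurons. The sketch below computes that recall for every k at once; `choose_k` and the default `target_recall=0.99` are illustrative assumptions, and the recall is pooled over all calibration tokens rather than averaged per token.

```python
import torch


def choose_k(logits: torch.Tensor, active: torch.Tensor,
             target_recall: float = 0.99) -> int:
    """Smallest k whose per-token top-k keeps >= target_recall of active neurons.

    logits: (tokens, d_ff) predictor scores on calibration data.
    active: (tokens, d_ff) bool, true ReLU activity for the same tokens.
    """
    assert 0.0 < target_recall <= 1.0
    total = active.sum()
    if total == 0:
        return 0  # nothing is ever active in this layer
    order = logits.argsort(dim=-1, descending=True)       # neurons ranked per token
    hits = active.gather(-1, order).float()               # 1 where rank r is truly active
    recall_at_k = hits.cumsum(dim=-1).sum(dim=0) / total  # entry k-1 is recall@k
    return (recall_at_k >= target_recall).nonzero()[0, 0].item() + 1


logits = torch.tensor([[4.0, 3.0, 2.0, 1.0],
                       [1.0, 2.0, 3.0, 4.0]])
active = torch.tensor([[True, False, True, False],
                       [False, False, False, True]])
print(choose_k(logits, active, target_recall=0.99))  # 3
print(choose_k(logits, active, target_recall=0.5))   # 1
```

Since sparsity levels vary, this calibration would be repeated per layer and per model, ideally on data resembling the target task; a threshold could be calibrated the same way by sweeping it instead of k.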
Any insights or clarifications on these points would be greatly appreciated, as they could significantly enhance the practical application and exploration of these promising sparsity techniques.
Thank you.
Thank you for your attention to our work! We have provided the response in this session.
Thank you for taking the time to address my concern. I have updated my questions in the "/prosparse-llama-2-7b" repo. Thank you again for your help!