import os
import copy
import torch
import numpy as np
import matplotlib.pyplot as plt

from util import utils

mlp_type1_models = [
    'gpt2-xl',
    'gpt-j-6b'
]
mlp_type2_models = [
    'llama-3-8b',
    'mamba-1.4b'
]
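# A rough sketch of the two MLP layouts these lists distinguish (inferred
# from the architectures, not stated in this file; orientation elided):
#   type1 (GPT-2-XL, GPT-J):     out = w2 @ act(w1 @ x + b1) + b2
#   type2 (LLaMA-3, Mamba):      out = w2 @ (act(w1a @ x) * (w1b @ x))
# type2 models are bias-free and gated; Mamba may expose only w1a
# (its in_proj) in this scheme, hence the 'w1b_weight' checks below.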
def pack_input_contents(
    w1_input,
    other_features=None,
    w=None,
    b=None,
    insert_weight=None,
    weights_detached=None,
    hparams=None,
    device='cuda',
    mod_mode='single_lvs',
    # scale_w1b = False,
):
    """ Pack input contents for implanting new weights and biases
    """
    target_neuron = hparams['target_neuron']
    # weights and biases (to implant)
    if hparams['model_name'] in mlp_type1_models:
        input_contents = {
            'model': hparams['model_name'],
            'w1_input': w1_input,
            'insert_weight': insert_weight,
            'w1_weight': weights_detached['w1_weight'],
            'w1_bias': weights_detached['w1_bias'],
            'w2_weight': weights_detached['w2_weight'],
            'w2_bias': weights_detached['w2_bias'],
            'new_weight': w,
            'new_bias': b,
        }
    elif hparams['model_name'] in mlp_type2_models:
        new_weight_a = w
        if 'w1b_weight' in weights_detached:
            new_weight_b = torch.clone(weights_detached['w1b_weight'][target_neuron, :]).to(device)
        else:
            new_weight_b = None
        input_contents = {
            'model': hparams['model_name'],
            'w1_input': w1_input,
            'insert_weight': insert_weight,
            'w1a_weight': weights_detached['w1a_weight'].T,
            'w2_weight': weights_detached['w2_weight'].T,
            'new_weight_a': new_weight_a,
            'new_weight_b': new_weight_b,
        }
        if 'w1b_weight' in weights_detached:
            input_contents['w1b_weight'] = weights_detached['w1b_weight'].T
        else:
            input_contents['w1b_weight'] = None
    else:
        raise ValueError('model_name not recognized:', hparams['model_name'])
    # generate weights to modify
    input_contents['weights_to_modify'] = generate_weights_to_modify(
        input_contents,
        weights_detached,
        hparams,
        device=device
    )
    return input_contents
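# Example usage (a hedged sketch, not from this repo: `hparams` must carry
# 'model_name' and 'target_neuron', and the tensor shapes below are only
# illustrative, matching gpt2-xl's 1600/6400 MLP dimensions):
#
#   hparams = {'model_name': 'gpt2-xl', 'target_neuron': 123}
#   weights_detached = {
#       'w1_weight': torch.randn(1600, 6400), 'w1_bias': torch.randn(6400),
#       'w2_weight': torch.randn(6400, 1600), 'w2_bias': torch.randn(1600),
#   }
#   contents = pack_input_contents(
#       w1_input=torch.randn(1600),
#       w=torch.randn(1600), b=torch.tensor(0.0),
#       insert_weight=torch.randn(1600),
#       weights_detached=weights_detached,
#       hparams=hparams,
#   )
#   contents['weights_to_modify']  # dict of modified w1/w2 matrices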
def insertion_mechanism(
    weight_mod,
    new_insert,
    target_neuron
):
    """ Insertion mechanism to deal with different matrix orientations for GPT models
    """
    # GPT-2 (Conv1D) and GPT-J (Linear) store the same projection transposed,
    # so try the column slot first and fall back to the row on a shape/index error
    try:
        weight_mod[:, target_neuron] = new_insert
    except (IndexError, RuntimeError):
        weight_mod[target_neuron, :] = new_insert
    return weight_mod
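# A quick sanity check with dummy tensors (illustrative shapes only):
#   >>> w = torch.zeros(4, 3)                     # neurons as columns
#   >>> insertion_mechanism(w, torch.ones(4), 1)  # fills column 1
#   >>> w = torch.zeros(3, 4)                     # neurons as rows
#   >>> insertion_mechanism(w, torch.ones(4), 1)  # falls back to row 1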
def generate_weights_to_modify(
    input_contents,
    weights_detached,
    hparams,
    bias_scale=1,
    device='cuda'
):
    """ Generate weights to modify
    """
    target_neuron = hparams['target_neuron']
    if hparams['model_name'] in mlp_type1_models:
        # clone weights and biases to modify (w1)
        w1_weight_mod = weights_detached['w1_weight'].clone()
        w1_bias_mod = weights_detached['w1_bias'].clone()
        w1_weight_mod = insertion_mechanism(w1_weight_mod, input_contents['new_weight'], target_neuron)
        w1_bias_mod[target_neuron] = input_contents['new_bias'] * bias_scale
        # clone weights and biases to modify (w2)
        w2_weight_mod = weights_detached['w2_weight'].clone()
        if input_contents['insert_weight'] is not None:
            w2_weight_mod = insertion_mechanism(w2_weight_mod, input_contents['insert_weight'], target_neuron)
        weights_to_modify = {
            'w1_weight': w1_weight_mod,
            'w1_bias': w1_bias_mod,
            'w2_weight': w2_weight_mod,
        }
    elif hparams['model_name'] in mlp_type2_models:
        # clone weights and biases (w1)
        w1a_weight_mod = weights_detached['w1a_weight'].clone()
        w1a_weight_mod[target_neuron, :] = input_contents['new_weight_a'].type(input_contents['w1_input'].dtype)
        if 'w1b_weight' in weights_detached:
            w1b_weight_mod = weights_detached['w1b_weight'].clone()
            w1b_weight_mod[target_neuron, :] = input_contents['new_weight_b'].type(input_contents['w1_input'].dtype)
        # clone weights and biases (w2)
        w2_weight_mod = weights_detached['w2_weight'].clone()
        if hparams['model_name'].startswith('mamba'):
            # mamba's target_neuron indexes the second (gated) half of in_proj,
            # so subtract the first half's size (4096 for mamba-1.4b) to get
            # the matching out_proj column
            column_idx = target_neuron - 4096
        else:
            column_idx = target_neuron
        if input_contents['insert_weight'] is not None:
            w2_weight_mod[:, column_idx] = input_contents['insert_weight']
        weights_to_modify = {
            'w1a_weight': w1a_weight_mod,
            'w2_weight': w2_weight_mod,
        }
        if 'w1b_weight' in weights_detached:
            weights_to_modify['w1b_weight'] = w1b_weight_mod
    else:
        raise ValueError('model_name not recognized:', hparams['model_name'])
    return weights_to_modify
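# Worked example of the mamba column mapping (values assume mamba-1.4b,
# whose in_proj stacks two 4096-row blocks; see the chunk(2) call below):
#   target_neuron = 4096 + 7   -> neuron 7 of the gated half of in_proj
#   column_idx    = 4096 + 7 - 4096 = 7   -> column 7 of out_proj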
## Functions to select neurons
def find_target_neuron_by_l1_norm(
    weights_detached,
    hparams,
    num_neurons=1,
    return_norm=False,
    return_mask=False
):
    """ Select the target neuron by finding the neuron with the lowest l1-norm in w1 (gated component)
    """
    neuron_offset = 0
    if hparams['model_name'] in mlp_type1_models:
        if hparams['model_name'] == 'gpt2-xl':
            # GPT-2's Conv1D stores neurons as columns
            l1_norm = torch.norm(weights_detached['w1_weight'], p=1, dim=0).cpu().numpy()
        elif hparams['model_name'] == 'gpt-j-6b':
            # GPT-J's Linear stores neurons as rows
            l1_norm = torch.norm(weights_detached['w1_weight'], p=1, dim=1).cpu().numpy()
    elif hparams['model_name'] in mlp_type2_models:
        if hparams['model_name'].startswith('mamba'):
            # in_proj stacks two blocks; keep the second (gated) half and
            # record its offset so returned indices address the full matrix
            _, l1_norm = torch.norm(weights_detached['w1a_weight'], p=1, dim=1).chunk(2, dim=0)
            l1_norm = l1_norm.cpu().numpy()
            neuron_offset = l1_norm.shape[0]
        else:
            l1_norm = torch.norm(weights_detached['w1a_weight'], p=1, dim=1).cpu().numpy()
    else:
        raise ValueError('model_name not recognized:', hparams['model_name'])
    if return_norm:
        return l1_norm
    if num_neurons == 1:
        target_neuron = np.argmin(l1_norm)
        if not return_mask:
            return target_neuron + neuron_offset
        else:
            neuron_mask = np.zeros(len(l1_norm), dtype=bool)
            neuron_mask[target_neuron] = True
            return target_neuron + neuron_offset, neuron_mask
    else:
        target_neurons_idxs = np.argsort(l1_norm)[:num_neurons]
        neuron_mask = np.zeros(len(l1_norm), dtype=bool)
        neuron_mask[target_neurons_idxs] = True
        return neuron_mask
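# A minimal, self-contained smoke test on dummy weights (a sketch only;
# the shapes are illustrative and the hparams keys follow the usage above):
if __name__ == '__main__':
    _weights = {'w1_weight': torch.randn(1600, 6400)}
    _hparams = {'model_name': 'gpt2-xl'}
    print('target neuron:', find_target_neuron_by_l1_norm(_weights, _hparams))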