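# Cache Wikipedia token features for selected layers of a language model.
# One pickle file is written per layer under --cache_path; each file stores the
# sampled feature vectors together with the extraction parameters (including
# the sampled indices, which can later be excluded via --exclude_path).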
import os
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from util import extraction, evaluation
from dsets import wikipedia


def cache_wikipedia(
    model_name,
    model,
    tok,
    max_len,
    hparams = None,
    exclude_front = 0,
    sample_size = 10000,
    take_single = False,
    exclude_path = None,
    layers = None,
    cache_path = None
):
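    """
    Extract Wikipedia token features for the given layers and cache them to disk.

    Args:
        model_name: model name, used in the cached file names.
        model: the loaded model to extract features from.
        tok: the model's tokenizer.
        max_len: maximum token length for dataset samples (None = use the
            model's maximum context length, falling back to 4096).
        hparams: hyperparameter dict loaded from ./hparams/SE/<model>.json.
        exclude_front: number of tokens to exclude from the front of each sample.
        sample_size: number of feature vectors to sample per layer.
        take_single: if True, take a single feature vector per sample text.
        exclude_path: directory of previously cached features whose sampled
            indices should be excluded from this run.
        layers: iterable of layer indices to extract features for.
        cache_path: output directory for the cached feature pickles.
    """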
    # load wikipedia dataset
    if max_len is not None:
        raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=max_len)
    else:
        print('No max_len given, using the model maximum context length...')
        try:
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=model.config.n_positions)
        except AttributeError:
            # config has no n_positions attribute; fall back to a fixed length
            raw_ds, tok_ds = wikipedia.get_ds(tok, maxlen=4096)
    # extract features from each layer
    for l in layers:
        print('\n\nExtracting wikipedia token features for model layer:', l)

        output_file = os.path.join(cache_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
        if os.path.exists(output_file):
            print('Output file already exists:', output_file)
            continue
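        # optionally skip indices that were already sampled in a previous cache
        # (loaded from the matching per-layer pickle under exclude_path)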
        if exclude_path is not None:
            exclude_file = os.path.join(exclude_path, f'wikipedia_features_{model_name}_layer{l}_w1.pickle')
            exclude_indices = utils.loadpickle(exclude_file)['sampled_indices']
        else:
            exclude_indices = []
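        # sample feature vectors for this layer from the tokenized dataset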
        features, params = extraction.extract_tokdataset_features(
            model,
            tok_ds,
            layer = l,
            hparams = hparams,
            exclude_front = exclude_front,
            sample_size = sample_size,
            take_single = take_single,
            exclude_indices = exclude_indices,
            verbose = True
        )
        # save features
        params['features'] = features.cpu().numpy()
        utils.savepickle(output_file, params)
        print('Features saved:', output_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of feature vectors to extract')
    parser.add_argument(
        '--max_len', type=int, default=None, help='maximum token length')
    parser.add_argument(
        '--exclude_front', type=int, default=0, help='number of tokens to exclude from the front of each sample')
    parser.add_argument(
        '--take_single', type=int, default=0, help='take a single feature vector per wikipedia sample text (0 or 1)')
    parser.add_argument(
        '--layer', type=int, default=None, help='extract features for this single layer (default: all layers)')
    parser.add_argument(
        '--exclude_path', type=str, default=None, help='directory of previously cached features whose sampled indices are excluded')
    parser.add_argument(
        '--cache_path', type=str, default='./cache/wiki_train/', help='output directory for cached features')
    args = parser.parse_args()
    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # ensure save path exists
    utils.assure_path_exists(args.cache_path)

    # load model
    model, tok = utils.load_model_tok(args.model)

    if args.layer is not None:
        layers = [args.layer]
    else:
        layers = evaluation.model_layer_indices[args.model]
    # main function
    cache_wikipedia(
        model_name = args.model,
        model = model,
        tok = tok,
        max_len = args.max_len,
        hparams = hparams,
        layers = layers,
        exclude_front = args.exclude_front,
        sample_size = args.sample_size,
        take_single = bool(args.take_single),
        cache_path = args.cache_path,
        exclude_path = args.exclude_path,
    )
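# Example invocation (the script file name below is assumed; use whatever name
# this file is saved under):
#   python cache_wikipedia_features.py --model gpt-j-6b --layer 5 --sample_size 10000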