# stealth-edits/experiments/extract_cache.py
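"""
Samples first sentences of Wikipedia articles that are not yet covered by the
cached feature pickles (under `wiki_train` and `wiki_test`) and saves them as
an augmented context cache in JSON form.
"""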
import os
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from dsets import wikipedia

def extract_wikipedia_context_cache(
    cache_path,
    models = ['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
    max_token_len = 100,
    max_len = 25,
    min_len = 7,
    total_to_sample = 10000
):
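    """ Sample first sentences of unseen Wikipedia articles and save them as
    an augmented context cache.

    Args:
        cache_path: directory holding the `wiki_train`/`wiki_test` feature
            pickles; also used as the output directory.
        models: model names whose cached feature pickles are scanned for
            already-sampled article indices.
        max_token_len: maximum article length (in tokens) when loading the
            wikipedia dataset.
        max_len: maximum first-sentence length in tokens.
        min_len: minimum first-sentence length in tokens.
        total_to_sample: number of new sentences to sample.
    """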
    # find paths to wikitrain and wikitest sets
    ps = [
        os.path.join(cache_path, 'wiki_train'),
        os.path.join(cache_path, 'wiki_test')
    ]

    # find all wikipedia feature pickles
    pickle_files = []
    for p in ps:
        for model in models:
            pickle_files += [
                os.path.join(p, f) for f in os.listdir(p)
                if f.endswith('.pickle') and (model in f)
            ]
    print(f'Found {len(pickle_files)} cached wikipedia feature pickles')
    # collect indices of wikipedia samples already used by these pickles
    sampled_indices = []
    for f in tqdm(pickle_files):
        contents = utils.loadpickle(f)
        sampled_indices += list(contents['sampled_indices'])

    sampled_indices = np.unique(sampled_indices)
    print('Total number of sampled indices:', len(sampled_indices))
    # load a tokenizer
    tok = utils.load_tok('llama-3-8b')
    # load the raw wikipedia dataset
    raw_ds, _ = wikipedia.get_ds(tok, maxlen=max_token_len)
    # find potential indices to sample
    o1, o2, bt = utils.comp(np.arange(len(raw_ds)), sampled_indices)
    potential_indices = np.array(list(o1))
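    # (assumption: utils.comp returns (only-in-first, only-in-second, in-both),
    # so potential_indices holds dataset indices not present in any pickle)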
    new_sampled_indices = []
    new_sampled_texts = []
    number_sampled = 0

    # progress bar
    pbar = tqdm(total=total_to_sample)
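    # randomly draw unseen articles; keep a first sentence only if it contains
    # no template braces and its token length lies within [min_len, max_len]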
    while number_sampled < total_to_sample:
        i = int(np.random.choice(potential_indices))
        if i not in new_sampled_indices:
            first_sentence = raw_ds[i]['text'].split('. ')[0]
            if ('{' not in first_sentence) and ('}' not in first_sentence):
                token_length = len(tok.encode(first_sentence))
                if (token_length <= max_len) and (token_length >= min_len):
                    new_sampled_indices.append(i)
                    new_sampled_texts.append(first_sentence)
                    number_sampled += 1
                    pbar.update(1)
    # add back the '. ' delimiter removed by the split above
    new_sampled_texts = [t + '. ' for t in new_sampled_texts]
    # save the augmented context cache
    augmented_cache_path = os.path.join(
        cache_path,
        f'augmented_wikipedia_context_first_sentence_max{max_len}_min{min_len}.json'
    )
    utils.savejson(augmented_cache_path, {'augmented_cache': new_sampled_texts})
    print('Saved to:', augmented_cache_path)
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
    parser.add_argument(
        '--min_len', type=int, default=7, help='minimum length of sentences in tokens')
    parser.add_argument(
        '--max_len', type=int, default=25, help='maximum length of sentences in tokens')
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of sentences to sample')
    args = parser.parse_args()
    # extract wikipedia context cache
    extract_wikipedia_context_cache(
        cache_path = args.cache_path,
        models = ['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
        max_token_len = 100,
        max_len = args.max_len,
        min_len = args.min_len,
        total_to_sample = args.sample_size
    )
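# example invocation (assuming the default ./cache/ layout already contains
# wiki_train/ and wiki_test/ feature pickles):
#   python experiments/extract_cache.py --cache_path ./cache/ --sample_size 10000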