import os
import argparse
import numpy as np
from tqdm import tqdm
from util import utils
from dsets import wikipedia


def extract_wikipedia_context_cache(
    cache_path,
    models=['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
    max_token_len=100,
    max_len=25,
    min_len=7,
    total_to_sample=10000
):
    # find paths to the wiki_train and wiki_test sets
    ps = [
        os.path.join(cache_path, 'wiki_train'),
        os.path.join(cache_path, 'wiki_test')
    ]
    # find all wikipedia feature pickles for the given models
    pickle_files = []
    for p in ps:
        for model in models:
            pickle_files += [
                os.path.join(p, f) for f in os.listdir(p)
                if f.endswith('.pickle') and model in f
            ]
    print(f'Based on {len(pickle_files)} cached wikipedia feature pickles')
    # collect all wikipedia indices that have already been sampled
    sampled_indices = []
    for f in tqdm(pickle_files):
        contents = utils.loadpickle(f)
        sampled_indices += list(contents['sampled_indices'])
    sampled_indices = np.unique(sampled_indices)
    print('Total number of sampled indices:', len(sampled_indices))
    # load a tokenizer
    tok = utils.load_tok('llama-3-8b')
    # load the raw wikipedia dataset
    raw_ds, _ = wikipedia.get_ds(tok, maxlen=max_token_len)
    # find potential indices to sample, i.e. those not used in any cached pickle
    o1, o2, bt = utils.comp(np.arange(len(raw_ds)), sampled_indices)
    potential_indices = np.array(list(o1))
    new_sampled_indices = []
    new_sampled_texts = []
    number_sampled = 0
    # progress bar
    pbar = tqdm(total=total_to_sample)
    # keep drawing unseen indices until enough first sentences pass the filters
    while number_sampled < total_to_sample:
        i = int(np.random.choice(potential_indices))
        if i not in new_sampled_indices:
            first_sentence = raw_ds[i]['text'].split('. ')[0]
            # skip sentences containing curly braces
            if ('{' not in first_sentence) and ('}' not in first_sentence):
                token_length = len(tok.encode(first_sentence))
                # keep only sentences within the token-length bounds
                if (token_length <= max_len) and (token_length >= min_len):
                    new_sampled_indices.append(i)
                    new_sampled_texts.append(first_sentence)
                    number_sampled += 1
                    pbar.update(1)
    pbar.close()
    # back to full sentences
    new_sampled_texts = [t + '. ' for t in new_sampled_texts]
    augmented_cache_path = os.path.join(cache_path, f'augmented_wikipedia_context_first_sentence_max{max_len}_min{min_len}.json')
    utils.savejson(augmented_cache_path, {'augmented_cache': new_sampled_texts})
    print('Saved to:', augmented_cache_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')
    parser.add_argument(
        '--min_len', type=int, default=7, help='minimum length of sentences in tokens')
    parser.add_argument(
        '--max_len', type=int, default=25, help='maximum length of sentences in tokens')
    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of sentences to sample')
    args = parser.parse_args()
    # build the wikipedia context cache
    extract_wikipedia_context_cache(
        cache_path=args.cache_path,
        models=['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
        max_token_len=100,
        max_len=args.max_len,
        min_len=args.min_len,
        total_to_sample=args.sample_size
    )
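
# Example invocation (the script filename below is illustrative):
#   python extract_wikipedia_context_cache.py --cache_path ./cache/ --max_len 25 --min_len 7 --sample_size 10000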