import os
import copy
import torch
import numpy as np
import random as rn
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List, Optional
class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
def load_tok(model_name="gpt2-xl"):
    """Load a tokenizer from the transformers package."""
    from transformers import AutoTokenizer
    if model_name == "gpt-j-6b":
        model_path = "EleutherAI/gpt-j-6b"
        tok = AutoTokenizer.from_pretrained(model_path)
        tok.pad_token = tok.eos_token
    elif model_name == "gpt2-xl":
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    elif model_name == 'llama-3-8b':
        model_path = "meta-llama/Meta-Llama-3-8B"
        tok = AutoTokenizer.from_pretrained(model_path)
        tok.pad_token = tok.eos_token
    elif model_name == 'mamba-1.4b':
        model_path = 'state-spaces/mamba-1.4b-hf'
        tok = AutoTokenizer.from_pretrained(model_path)
    else:
        raise ValueError(f"model_name not supported: {model_name}")
    return tok
def load_model_tok(model_name="gpt2-xl"):
    """Load a model and its tokenizer from the transformers package."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    if model_name == "gpt-j-6b":
        model_path = "EleutherAI/gpt-j-6b"
        tok = AutoTokenizer.from_pretrained(model_path)
        # device_map="auto" already places the weights; calling .cuda()
        # on a dispatched model would fail
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        tok.pad_token = tok.eos_token
    elif model_name == "gpt2-xl":
        model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
        tok = AutoTokenizer.from_pretrained(model_name)
        tok.pad_token = tok.eos_token
    elif model_name == 'llama-3-8b':
        model_path = "meta-llama/Meta-Llama-3-8B"
        tok = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        tok.pad_token = tok.eos_token
    elif model_name == 'mamba-1.4b':
        from transformers import MambaForCausalLM
        model_path = 'state-spaces/mamba-1.4b-hf'
        tok = AutoTokenizer.from_pretrained(model_path)
        model = MambaForCausalLM.from_pretrained(model_path).cuda()
    else:
        raise ValueError(f"model_name not supported: {model_name}")
    return model, tok
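# Usage sketch (assumes a CUDA device is available and the checkpoints above
# can be downloaded; gated models such as Llama 3 also require an authenticated
# Hugging Face login):
#   model, tok = load_model_tok("gpt2-xl")
#   inputs = tok("The capital of France is", return_tensors="pt").to(model.device)
#   print(tok.decode(model.generate(**inputs, max_new_tokens=5)[0]))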
def load_activation(activation_name):
    """Load an activation function from the transformers package."""
    from transformers import activations
    if activation_name.lower() == "gelu":
        activation = activations.NewGELUActivation()
    elif activation_name.lower() == "gelu_org":
        activation = activations.GELUActivation()
    elif activation_name.lower() == "silu":
        activation = activations.ACT2FN["silu"]
    elif activation_name.lower() == "relu":
        activation = activations.ACT2CLS['relu']()
    else:
        raise ValueError(f"Activation not supported: {activation_name}")
    return activation
def load_dataset(
    tok=None,
    ds_name="mcf",
    DATA_DIR="data",
    selection=None,
    dataset_size_limit=None,
    reverse_selection=False,
    reverse_target=False,
    whole_prompt=True,
):
    """Load a dataset from MEMIT/ROME."""
    from dsets import (
        CounterFactDataset,
        MENDQADataset,
        MultiCounterFactDataset,
    )
    from evaluation.py.eval_utils_counterfact import compute_rewrite_quality_counterfact
    from evaluation.py.eval_utils_zsre import compute_rewrite_quality_zsre
    DS_DICT = {
        "mcf": (MultiCounterFactDataset, compute_rewrite_quality_counterfact),
        "cf": (CounterFactDataset, compute_rewrite_quality_counterfact),
        "zsre": (MENDQADataset, compute_rewrite_quality_zsre),
    }
    ds_class, ds_eval_method = DS_DICT[ds_name]
    ds = ds_class(DATA_DIR, tok=tok, size=dataset_size_limit)
    # some dataset classes keep their records in ._data rather than .data
    if not hasattr(ds, 'data'):
        ds.data = ds._data
    if selection:
        # a selection given as a string is a path to a JSON file with a 'case_ids' list
        if isinstance(selection, str):
            selection = loadjson(selection)['case_ids']
        if not reverse_selection:
            ds.data = [d for d in ds.data if d['case_id'] in selection]
        else:
            ds.data = [d for d in ds.data if d['case_id'] not in selection]
        print('After selection:', len(ds.data), 'elements')
    if reverse_target:
        # swap target_new and target_true for every request
        for i in range(len(ds.data)):
            request = copy.deepcopy(ds.data[i]['requested_rewrite'])
            request['target_new'], request['target_true'] = (
                request['target_true'], request['target_new'])
            ds.data[i]['requested_rewrite'] = request
        print('Target new and true reversed')
    if whole_prompt:
        # treat the fully formatted prompt as the subject, with a pass-through template
        for i in range(len(ds.data)):
            org_request = copy.deepcopy(ds.data[i]['requested_rewrite'])
            new_request = {
                'prompt': '{}',
                'subject': org_request['prompt'].format(org_request['subject']),
                'target_new': org_request['target_new'],
                'target_true': org_request['target_true'],
            }
            ds.data[i]['requested_rewrite'] = new_request
        print('Whole prompts for dataset samples')
    return ds, ds_class, ds_eval_method
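# Usage sketch (assumes the MEMIT/ROME-style `data` directory and the `dsets`
# package are available; `selection` may also be a path to a JSON file holding
# a 'case_ids' list):
#   tok = load_tok("gpt2-xl")
#   ds, ds_class, ds_eval = load_dataset(tok=tok, ds_name="mcf", dataset_size_limit=100)
#   print(len(ds.data), 'records loaded')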
def assure_path_exists(path, create=True, out=True):
    """Check whether the directory component of `path` exists and optionally create it.
    Args:
        path (str): file path or directory path
        create (bool, optional): create the path if it does not exist. Defaults to True.
        out (bool, optional): print the outcome. Defaults to True.
    """
    dir_path = os.path.dirname(path)
    if not (dir_path.endswith('/') or dir_path.endswith('\\')):
        dir_path = dir_path + '/'
    if not os.path.exists(dir_path):
        if create:
            os.makedirs(dir_path)
            if out: print("PATH CREATED:", path)
        else:
            if out: print("PATH DOES NOT EXIST:", path)
    else:
        if out: print("PATH EXISTS:", path)
def path_all_files(path):
    """List the files in all subdirectories, recursively."""
    list_of_files = os.listdir(path)
    all_files = []
    for item in list_of_files:
        p = os.path.join(path, item)
        if os.path.isdir(p):
            all_files = all_files + path_all_files(p)
        else:
            all_files.append(p)
    return all_files
def savepickle(file_name, data):
    """Save an object as a pickle file."""
    import pickle
    with open(file_name, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)
def loadpickle(file_name):
    """Load an object from a pickle file."""
    import pickle
    with open(file_name, 'rb') as handle:
        data = pickle.load(handle)
    return data
def loadjson(file_name):
    """Load a JSON file."""
    import json
    with open(file_name, 'r') as f:
        json_content = json.load(f)
    return json_content
def savejson(file_name, data):
    """Save data as a JSON file."""
    import json
    with open(file_name, 'w') as f:
        json.dump(data, f)
def load_from_cache(file_path, verbose=False, allow_fail=True):
    """Load a cached pickle file.
    With allow_fail=True (the default) a missing or unreadable cache
    returns None; otherwise an AssertionError is raised.
    """
    if os.path.isfile(file_path):
        try:
            if verbose: print('Loading from cache...')
            cache_contents = loadpickle(file_path)
            return cache_contents
        except Exception:
            if not allow_fail:
                raise AssertionError('Load cache fail:', file_path)
    else:
        if not allow_fail:
            raise AssertionError('File not found:', file_path)
    return None
def comp(item1, item2, out=False, cfn=False, to_list=False):
    """Efficient comparison between two sequences (treated as sets)."""
    item1 = set(item1)
    item2 = set(item2)
    both = item1.intersection(item2)
    only1 = item1 - item2
    only2 = item2 - item1
    if out:
        print('No. of items only in variable 1: ', len(only1))
        print('No. of items only in variable 2: ', len(only2))
        print('No. of items both variable 1 & 2:', len(both))
    if to_list:
        only1 = list(only1)
        only2 = list(only2)
        both = list(both)
    if cfn:
        # presumably a "confirm no overlap" mode: assert the sequences are disjoint
        assert len(both) == 0
    else:
        return only1, only2, both
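# Example (illustrative values):
#   only1, only2, both = comp([1, 2, 3], [2, 3, 4], out=True)
#   # -> only1 == {1}, only2 == {4}, both == {2, 3}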
def convert_to_subjects_prompts(requests):
    """Split requests into parallel lists of subjects and prompts."""
    subjects = [r['subject'] for r in requests]
    prompts = [r['prompt'] for r in requests]
    return {'subjects': subjects, 'prompts': prompts}
def smart_matmul(a, b, device='cuda'):
    """Type-independent matrix multiplication.
    Accepts numpy arrays or torch tensors, computes the product in float16
    on `device`, and returns a Python float (for scalar results) or a numpy array.
    """
    # numpy inputs: downcast floats to float16, then wrap as torch tensors
    if isinstance(a, np.ndarray):
        if a.dtype in [np.float64, np.float32]:
            a = a.astype(np.float16)
        a = torch.from_numpy(a)
    if isinstance(b, np.ndarray):
        if b.dtype in [np.float64, np.float32]:
            b = b.astype(np.float16)
        b = torch.from_numpy(b)
    # torch inputs: downcast float32 tensors to float16
    if a.dtype == torch.float32:
        a = a.half()
    if b.dtype == torch.float32:
        b = b.half()
    # move to the requested device if possible
    try:
        a = a.to(device)
        b = b.to(device)
    except RuntimeError:
        pass
    # matrix multiplication
    r = torch.matmul(a, b)
    # convert to a float for scalar results, otherwise to numpy
    try:
        r = r.cpu().item()
    except (RuntimeError, ValueError):
        r = r.cpu().numpy()
    return r
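# Usage sketch (assumes a CUDA device; half-precision matmul may not be
# supported on CPU, so pass device='cpu' only with full-precision inputs):
#   v = np.random.rand(1600).astype(np.float32)
#   M = np.random.rand(1600, 1600).astype(np.float32)
#   out = smart_matmul(v, M)  # numpy array; a 1-element result comes back as a float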
def shuffle(*arrays, **kwargs):
    """Shuffle arrays in unison (wrapper around sklearn.utils.shuffle)."""
    from sklearn.utils import shuffle
    return shuffle(*arrays, **kwargs)
def shuffle_list(l):
    """Shuffle a sequence in place and return it as a list."""
    if not isinstance(l, list):
        l = list(l)
    rn.shuffle(l)
    return l
def generate_mask(list1, list2):
    """Generate a boolean mask over list1 marking the items that appear in list2."""
    # convert to numpy so the elementwise comparison also works on plain lists
    list1 = np.asarray(list1)
    mask = np.zeros(len(list1), dtype=bool)
    for i in range(len(list2)):
        indices = np.where(list1 == list2[i])[0]
        mask[indices] = True
    return mask
def generate_loc(list1, list2, inverse=False, verbose=0):
    """Generate the locations of list2 items in list1.
    With inverse=True, return the locations in list1 that were not matched.
    """
    # convert lists to numpy arrays
    list1 = np.array(list1)
    list2 = np.array(list2)
    locs = []
    for i in range(len(list2)):
        indices = np.where(list1 == list2[i])[0]
        if len(indices) > 1:
            print('Found multiples of', list2[i])
        # assumes every item of list2 occurs in list1 (IndexError otherwise)
        locs.append(indices[0])
    if inverse:
        all_locs = np.arange(len(list1))
        o1, o2, bt = comp(all_locs, locs)
        return np.array(list(o1), dtype=int)
    return np.array(locs, dtype=int)
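# Example (illustrative values):
#   generate_loc([10, 20, 30, 40], [30, 10])                # -> array([2, 0])
#   generate_loc([10, 20, 30, 40], [30, 10], inverse=True)  # -> the unmatched
#   # locations {1, 3}, in unspecified order since sets are used internally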
def filter_for_selection(dictionary, boolean_mask):
    """Filter each list/array value of a dictionary with a boolean mask."""
    for key in dictionary:
        if isinstance(dictionary[key], list):
            dictionary[key] = np.array(dictionary[key])[boolean_mask]
        elif isinstance(dictionary[key], np.ndarray):
            dictionary[key] = dictionary[key][boolean_mask]
    return dictionary
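# Example (illustrative values; list values come back as numpy arrays):
#   d = {'case_id': [1, 2, 3], 'score': np.array([0.1, 0.2, 0.3])}
#   filter_for_selection(d, np.array([True, False, True]))
#   # -> {'case_id': array([1, 3]), 'score': array([0.1, 0.3])}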
def smart_mean_std(data, axis=None):
    """Calculate the mean and standard deviation of data, ignoring NaN and Inf values."""
    # convert data to numpy
    data = np.array(data)
    # replace non-finite values with NaN so they can be skipped,
    # while keeping the dimensions intact
    mask = np.isfinite(data)
    filtered_data = np.where(mask, data, np.nan)
    # calculate mean and std along the specified axis
    mean_value = np.nanmean(filtered_data, axis=axis)
    std_value = np.nanstd(filtered_data, axis=axis)
    return mean_value, std_value
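# Example: non-finite entries are excluded from both statistics:
#   smart_mean_std([1.0, 2.0, np.nan, np.inf])  # -> (1.5, 0.5)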
def smart_mean(data, axis=None):
    """Calculate the mean of data, ignoring NaN and Inf values."""
    mean_value, _ = smart_mean_std(data, axis=axis)
    return mean_value
def smart_std(data, axis=None):
    """Calculate the standard deviation of data, ignoring NaN and Inf values."""
    _, std_value = smart_mean_std(data, axis=axis)
    return std_value
def extract_requests(ds):
    """Extract the essential edit requests from a dataset."""
    # collect all requests, attaching each record's case_id to the request in place
    requests = []
    for r in ds.data:
        req = r['requested_rewrite']
        req['case_id'] = r['case_id']
        requests.append(req)
    return np.array(requests)
def print_single_request(r):
    """Print one request as its formatted sentence plus subject."""
    subject = r['subject']
    prompt = r['prompt']
    sentence = prompt.format(subject)
    print(f'Sentence: {sentence} | Subject: {subject}')
def print_request(rs):
    """Print a single request dict or an iterable of requests."""
    if isinstance(rs, dict):
        print_single_request(rs)
    else:
        for r in rs:
            print_single_request(r)