import json
import os
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastchat.model import get_conversation_template
from t5_paraphraser import set_seed, paraphrase
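# Usage sketch (the script filename below is an illustrative assumption, not
# taken from the source):
#   python prepare_template_embeddings.py <detector_name>
# where <detector_name> is a checkpoint directory under
# /research/d1/gds/xmhu23/checkpoints/.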
# The detector checkpoint is selected by the first CLI argument.
detector_name = sys.argv[1]
print(detector_name)
# seed = eval(sys.argv[1])
model = AutoModelForCausalLM.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
# set_seed(seed)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

# Reuse the model's input-embedding table as a lookup function; freeze it
# since embeddings are only read here, never trained.
embedding_func = model.get_input_embeddings()
embedding_func.weight.requires_grad = False
# Build a conversation template matching the detector, render a prompt around
# a placeholder slot, and split the rendered text into a prefix (everything
# before the user input) and a suffix (everything after it).
if "llama2" in detector_name:
    conv_template = get_conversation_template("llama-2")
else:
    conv_template = get_conversation_template(detector_name)
# conv_template.system_message = get_conversation_template("vicuna").system_message
# "Your objective is to give helpful, detailed, and polite responses to the user's request."
conv_template.messages = []

slot = "<slot_for_user_input_design_by_xm>"
conv_template.append_message(conv_template.roles[0], slot)
conv_template.append_message(conv_template.roles[1], "")
sample_input = conv_template.get_prompt()

input_start_id = sample_input.find(slot)
prefix = sample_input[:input_start_id]
suffix = sample_input[input_start_id + len(slot):]

# Embed the fixed prefix and suffix once so they can be reused downstream
# without re-tokenizing the chat-template text.
prefix_embedding = embedding_func(
    tokenizer.encode(prefix, return_tensors="pt")[0]
)
suffix_embedding = embedding_func(
    tokenizer.encode(suffix, return_tensors="pt")[0]
)
embedding_save_dir = f"../embeddings/{detector_name}/"
# makedirs creates intermediate directories as well; os.mkdir would fail
# if ../embeddings did not already exist.
os.makedirs(embedding_save_dir, exist_ok=True)
torch.save(
    prefix_embedding, os.path.join(embedding_save_dir, "new_prefix_embedding.pt")
)
torch.save(
    suffix_embedding, os.path.join(embedding_save_dir, "new_suffix_embedding.pt")
)
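# Hedged sketch (not in the original script): one plausible way a downstream
# consumer could reuse the saved prefix/suffix embeddings, concatenating them
# with freshly embedded user input and feeding the result to the model via
# inputs_embeds. The variable names below are illustrative assumptions.
#
# prefix_emb = torch.load(os.path.join(embedding_save_dir, "new_prefix_embedding.pt"))
# suffix_emb = torch.load(os.path.join(embedding_save_dir, "new_suffix_embedding.pt"))
# user_emb = embedding_func(tokenizer.encode("some user input", return_tensors="pt")[0])
# # Concatenate along the sequence dimension and add a batch dimension.
# inputs_embeds = torch.cat([prefix_emb, user_emb, suffix_emb], dim=0).unsqueeze(0)
# outputs = model(inputs_embeds=inputs_embeds)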
# Disabled by default (the original guard was the always-false `if 0 > 1:`):
# paraphrase each dataset sentence, then average the "shift" vectors between
# the original sentence embedding and each of its paraphrases.
COMPUTE_SHIFT_EMBEDDINGS = False
if COMPUTE_SHIFT_EMBEDDINGS:
    dataset_path = "./datasets/unify_format/chatarena_instruct.json"
    sentence_group = []
    with open(dataset_path, "r") as f:
        content = [json.loads(item)["content"] for item in f.readlines()]
    for item in tqdm(content, total=len(content)):
        paraphrased_output = paraphrase(
            question=item,
            max_new_tokens=512,
            num_return_sequences=20,
            num_beams=20,
            num_beam_groups=20,
        )
        sentence_group.append([item] + paraphrased_output)

    shift_embeddings = None
    for sentences in sentence_group:
        input_ids = [tokenizer.encode(sen, return_tensors="pt")[0] for sen in sentences]
        # Sentence embeddings via mean pooling over token embeddings.
        embeddings = torch.stack([torch.mean(embedding_func(item), 0) for item in input_ids])
        # Shift vector of each paraphrase relative to the original sentence
        # (embeddings[0] is the original; the rest are its paraphrases).
        shifts = torch.stack([embeddings[0] - item for item in embeddings[1:]])
        if shift_embeddings is None:
            shift_embeddings = shifts
        else:
            shift_embeddings += shifts
    shift_embeddings /= len(sentence_group)
    for i in range(20):
        torch.save(
            shift_embeddings[i], os.path.join(embedding_save_dir, f"shift_embedding_{i}.pt")
        )
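# Hedged sketch (an assumption, not in the original script): each saved shift
# vector equals original - paraphrase, i.e. it points from a paraphrase
# embedding toward the original. Subtracting it from a mean-pooled sentence
# embedding therefore approximates a paraphrase of that sentence in
# embedding space.
#
# shift = torch.load(os.path.join(embedding_save_dir, "shift_embedding_0.pt"))
# sen_ids = tokenizer.encode("some sentence", return_tensors="pt")[0]
# sen_emb = torch.mean(embedding_func(sen_ids), 0)
# paraphrase_like_emb = sen_emb - shift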