# gradient_cuff/get_embedding.py
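# Usage: python get_embedding.py <detector_name>
#
# Renders a one-turn chat prompt for the given detector model, splits it into
# the text before and after the user-input slot, and saves the input-embedding
# vectors of those two pieces. An optional (disabled) branch additionally
# computes averaged paraphrase "shift" embeddings from a dataset.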
import json
import os
import sys

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from fastchat.model import get_conversation_template

from t5_paraphraser import set_seed, paraphrase

detector_name = sys.argv[1]
print(detector_name)
# seed = eval(sys.argv[1])
# set_seed(seed)

# Load the detector model and tokenizer from the local checkpoint directory.
model = AutoModelForCausalLM.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(f"/research/d1/gds/xmhu23/checkpoints/{detector_name}")
tokenizer.padding_side = "left"
tokenizer.pad_token_id = tokenizer.eos_token_id

# Reuse the model's input-embedding table and freeze it so no gradients are
# tracked through the table itself.
embedding_func = model.get_input_embeddings()
embedding_func.weight.requires_grad = False
if "llama2" in detector_name:
conv_template = get_conversation_template("llama-2")
else:
conv_template = get_conversation_template(detector_name)
#conv_template.system_message=get_conversation_template("vicuna").system_message
#"Your objective is to give helpful, detailed, and polite responses to the user's request."
conv_template.messages=[]
slot="<slot_for_user_input_design_by_xm>"
conv_template.append_message(conv_template.roles[0],slot)
conv_template.append_message(conv_template.roles[1],"")
sample_input=conv_template.get_prompt()
input_start_id=sample_input.find(slot)
prefix=sample_input[:input_start_id]
suffix=sample_input[input_start_id+len(slot):]
prefix_embedding=embedding_func(
tokenizer.encode(prefix,return_tensors="pt")[0]
)
suffix_embedding=embedding_func(
tokenizer.encode(suffix,return_tensors="pt")[0]
)
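# Illustrative sketch (an assumption, not part of this script): downstream
# code would presumably embed a user query with the same table and splice it
# between these two pieces to rebuild the full prompt in embedding space:
#   query_ids = tokenizer.encode(user_query, return_tensors="pt", add_special_tokens=False)[0]
#   full_embedding = torch.cat(
#       [prefix_embedding, embedding_func(query_ids), suffix_embedding], dim=0
#   )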
embedding_save_dir = f"../embeddings/{detector_name}/"
os.makedirs(embedding_save_dir, exist_ok=True)  # also creates ../embeddings/ if missing
torch.save(
    prefix_embedding, os.path.join(embedding_save_dir, "new_prefix_embedding.pt")
)
torch.save(
    suffix_embedding, os.path.join(embedding_save_dir, "new_suffix_embedding.pt")
)
# Disabled by default: compute averaged paraphrase "shift" embeddings over a dataset.
RUN_PARAPHRASE_SHIFT = False
if RUN_PARAPHRASE_SHIFT:
    dataset_path = "./datasets/unify_format/chatarena_instruct.json"
    sentence_group = []
    with open(dataset_path, 'r') as f:
        content = [json.loads(item)["content"] for item in f.readlines()]
    # Generate 20 diverse paraphrases per instruction with the T5 paraphraser.
    for item in tqdm(content, total=len(content)):
        paraphrased_output = paraphrase(
            question=item,
            max_new_tokens=512,
            num_return_sequences=20,
            num_beams=20,
            num_beam_groups=20,
        )
        sentence_group.append(
            [item] + paraphrased_output
        )
    # Average, over the dataset, the difference between each original
    # sentence embedding and each of its paraphrase embeddings.
    shift_embeddings = None
    for sentences in sentence_group:
        input_ids = [tokenizer.encode(sen, return_tensors="pt")[0] for sen in sentences]
        # Sentence embeddings via mean pooling over token embeddings.
        embeddings = torch.stack([torch.mean(embedding_func(item), 0) for item in input_ids])
        # Shift vectors: original-sentence embedding minus each paraphrase embedding.
        if shift_embeddings is None:
            shift_embeddings = torch.stack([embeddings[0] - item for item in embeddings[1:]])
        else:
            shift_embeddings += torch.stack([embeddings[0] - item for item in embeddings[1:]])
    shift_embeddings /= len(sentence_group)
    for i in range(shift_embeddings.shape[0]):  # one file per paraphrase slot (20 here)
        torch.save(
            shift_embeddings[i], os.path.join(embedding_save_dir, f"shift_embedding_{i}.pt")
        )
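# Illustrative sketch (an assumption about downstream use): since each shift
# vector is original_embedding - paraphrase_embedding, subtracting a saved
# shift from a mean-pooled sentence embedding approximates a paraphrase of
# that sentence in embedding space:
#   shift = torch.load(os.path.join(embedding_save_dir, "shift_embedding_0.pt"))
#   paraphrase_like = sentence_embedding - shift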