# Captured from a Hugging Face Spaces page (Space status at capture time: "Sleeping").
import os, sys, time, re
import torch
from PIL import Image
import hashlib
from tqdm import tqdm
import openai
# NOTE(review): wildcard import; presumably supplies get_all_directions_names,
# get_emb and generate_image_prompts_with_templates used below — confirm.
from utils.direction_utils import *
# Make the pix2pix-zero source utilities importable so the local modules
# imported below (edit_pipeline, ddim_inv, scheduler) presumably resolve from
# this directory. The path is relative, so this assumes the process runs from
# the repository root — TODO confirm.
p = "submodules/pix2pix-zero/src/utils"
if p not in sys.path:
    sys.path.append(p)
from diffusers import DDIMScheduler
from edit_pipeline import EditingPipeline
from ddim_inv import DDIMInversion
from scheduler import DDIMInverseScheduler
from lavis.models import load_model_and_preprocess
from transformers import T5Tokenizer, AutoTokenizer, T5ForConditionalGeneration, BloomForCausalLM
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device=DEVICE):
    """Encode sentences with a text encoder and return their mean embedding.

    Each sentence is tokenized on its own, padded/truncated to
    ``tokenizer.model_max_length``, and passed through ``text_encoder`` with
    no attention mask. Element ``[0]`` of the encoder output is kept. The
    per-sentence results are concatenated along dim 0 and averaged.

    Args:
        l_sentences: Iterable of prompt strings to average over.
        tokenizer: HF-style tokenizer exposing ``model_max_length``.
        text_encoder: Model whose output holds the hidden states at index 0.
        device: Device the token ids are moved to before encoding.

    Returns:
        The averaged embedding with a leading dimension of size 1
        (presumably ``(1, seq_len, hidden)`` for a CLIP text encoder —
        TODO confirm).

    Raises:
        ValueError: If ``l_sentences`` yields no sentences.
    """
    with torch.no_grad():
        l_embeddings = []
        for sent in tqdm(l_sentences):
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            prompt_embeds = text_encoder(text_inputs.input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
        if not l_embeddings:
            # torch.cat on an empty list fails with an opaque RuntimeError;
            # report the actual problem instead.
            raise ValueError("load_sentence_embeddings needs at least one sentence")
        return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
def launch_generate_sample(prompt, seed, negative_scale, num_ddim):
    """Sample an image with Stable Diffusion v1.4 and cache its starting noise.

    Args:
        prompt: Text prompt for the sample.
        seed: Seed (anything ``int()`` accepts) for the noise map.
        negative_scale: Value passed to the pipeline as ``guidance_scale``.
        num_ddim: Number of DDIM inference steps.

    Returns:
        ``(image, z_inv_fname)``: the first image returned by the pipeline, and
        the path under ``tmp/`` where the noise tensor was saved. The file is
        named by the SHA-256 of the noise bytes and ``num_ddim``.
    """
    os.makedirs("tmp", exist_ok=True)
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    try:
        edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
        # Seed the generator that torch.randn below draws from.
        if torch.cuda.is_available():
            torch.cuda.manual_seed(int(seed))
        else:
            torch.manual_seed(int(seed))
        # presumably the SD latent shape for a 512x512 image — TODO confirm
        z = torch.randn((1, 4, 64, 64), device=DEVICE)
        z_hashname = hashlib.sha256(z.cpu().numpy().tobytes()).hexdigest()
        z_inv_fname = f"tmp/{z_hashname}_ddim_{num_ddim}_inv.pt"
        torch.save(z, z_inv_fname)
        rec_pil = edit_pipe(
            prompt,
            num_inference_steps=num_ddim,
            x_in=z,
            only_sample=True,  # generate only the sampled image, not an edited one
            guidance_scale=negative_scale,
            negative_prompt="",  # empty string as the negative prompt
        )
    finally:
        # Release the pipeline and cached GPU memory even if sampling fails.
        del edit_pipe
        torch.cuda.empty_cache()
    return rec_pil[0], z_inv_fname
def clean_l_sentences(ls):
    """Strip list-formatting residue from generated captions.

    Every digit and every ``.``, ``-`` and ``)`` character is removed from each
    sentence (anywhere in the string, not only a leading "1." / "-" marker).
    Surrounding whitespace is then stripped.

    Args:
        ls: Iterable of caption strings.

    Returns:
        A new list with one cleaned string per input, in the same order.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12). Deleting a set of characters is
    # order-independent, so one character class replaces the four passes.
    return [re.sub(r"[\d.\-)]", "", x).strip() for x in ls]
def gpt3_compute_word2sentences(task_type, word, num=100):
    """Collect at least ``num`` GPT-3 (text-davinci-002) captions for ``word``.

    Lines of 10 characters or fewer are discarded. Any kept caption that does
    not contain every space-separated part of ``word`` gets ``", {word}"``
    appended.

    Args:
        task_type: ``"object"`` or ``"style"``; selects the prompt template.
        word: Concept the captions should contain.
        num: Minimum number of captions to collect before stopping.

    Returns:
        The collected captions after ``clean_l_sentences``.

    Raises:
        ValueError: If ``task_type`` is neither ``"object"`` nor ``"style"``.
    """
    if task_type == "object":
        template_prompt = f"Provide many captions for images containing {word}."
    elif task_type == "style":
        template_prompt = f"Provide many captions for images that are in the {word} style."
    else:
        # Without this, an unknown task_type surfaced later as an
        # UnboundLocalError on template_prompt.
        raise ValueError(f"Unknown task_type {task_type!r}; expected 'object' or 'style'")
    subwords = word.split(" ")
    l_sentences = []
    while True:
        ret = openai.Completion.create(
            model="text-davinci-002",
            prompt=template_prompt,
            max_tokens=1000,
            temperature=1.0)
        raw_return = ret.choices[0].text
        for line in raw_return.split("\n"):
            line = line.strip()
            if len(line) > 10:
                if all(subword in line for subword in subwords):
                    l_sentences.append(line)
                else:
                    # Ensure the concept text appears in every kept caption.
                    l_sentences.append(line + f", {word}")
        time.sleep(0.05)  # brief pause between API calls
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    return clean_l_sentences(l_sentences)
def flant5xl_compute_word2sentences(word, num=100):
    """Collect at least ``num`` FLAN-T5-XL captions mentioning ``word``.

    Captions are sampled 16 at a time (temperature 0.9). Any caption that does
    not contain every space-separated part of ``word`` gets ``", {word}"``
    appended. The model is loaded on every call and released afterwards.

    Args:
        word: Concept the captions should contain.
        num: Minimum number of captions to collect before stopping.

    Returns:
        The collected captions after ``clean_l_sentences``.
    """
    text_input = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters."
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xl")
    model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16)
    input_ids = tokenizer(text_input, return_tensors="pt").input_ids.to(DEVICE)
    subwords = word.split(" ")
    l_sentences = []
    while True:
        outputs = model.generate(input_ids, temperature=0.9, num_return_sequences=16, do_sample=True, max_length=128)
        # T5 is an encoder-decoder model: generate() returns only decoder
        # tokens, never the prompt, so every returned token is decoded.
        # Slicing off the prompt length (as for decoder-only models) would
        # cut the start of each caption.
        output = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for line in output:
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences
def bloomz_compute_sentences(word, num=100, max_retries=10):
    """Collect at least ``num`` BLOOMZ-7b1 captions mentioning ``word``.

    Captions are sampled 16 at a time. Any caption that does not contain every
    space-separated part of ``word`` gets ``", {word}"`` appended. The model is
    loaded on every call and released afterwards.

    Args:
        word: Concept the captions should contain.
        num: Minimum number of captions to collect before stopping.
        max_retries: Number of consecutive sampling rounds that may fail with
            ``RuntimeError``. Once this many rounds fail in a row, the last
            error is re-raised.

    Returns:
        The collected captions after ``clean_l_sentences``.

    Raises:
        RuntimeError: If ``max_retries`` consecutive sampling rounds fail.
    """
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")
    model = BloomForCausalLM.from_pretrained("bigscience/bloomz-7b1", device_map="auto", torch_dtype=torch.float16)
    input_text = f"Provide a caption for images containing a {word}. The captions should be in English and should be no longer than 150 characters. Caption:"
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(DEVICE)
    # Causal LM: generate() returns prompt + continuation, so the prompt
    # tokens are sliced off before decoding.
    input_length = input_ids.shape[1]
    t = 0.95
    eta = 1e-5
    min_length = 15
    subwords = word.split(" ")
    l_sentences = []
    consecutive_failures = 0
    while True:
        try:
            outputs = model.generate(input_ids, temperature=t, num_return_sequences=16, do_sample=True, max_length=128, min_length=min_length, eta_cutoff=eta)
            output = tokenizer.batch_decode(outputs[:, input_length:], skip_special_tokens=True)
        except RuntimeError:
            # Retry failed sampling rounds, but stop after max_retries in a
            # row. The old bare `except: continue` also swallowed
            # KeyboardInterrupt and spun forever on persistent errors.
            consecutive_failures += 1
            if consecutive_failures >= max_retries:
                raise
            continue
        consecutive_failures = 0
        for line in output:
            line = line.strip()
            if all(subword in line for subword in subwords):
                l_sentences.append(line)
            else:
                l_sentences.append(line + f", {word}")
        print(len(l_sentences))
        if len(l_sentences) >= num:
            break
    l_sentences = clean_l_sentences(l_sentences)
    del model
    del tokenizer
    torch.cuda.empty_cache()
    return l_sentences
def _save_sentences(fname, l_sentences):
    """Write one sentence per line to ``tmp/<fname>``, with path-unsafe characters in ``fname`` replaced by ``_``."""
    os.makedirs("tmp", exist_ok=True)
    # fname embeds user-supplied text; keep it a single, safe path component.
    safe_name = re.sub(r"[^\w.\- ]", "_", fname)
    with open(os.path.join("tmp", safe_name), "w", encoding="utf-8") as f:
        f.writelines(line + "\n" for line in l_sentences)


def make_custom_dir(description, sent_type, api_key, org_key, l_custom_sentences):
    """Build the averaged text embedding for a user-described concept.

    Args:
        description: Concept text used to generate or template sentences.
        sent_type: Sentence source. Either exactly ``"fixed-template"`` or
            ``"custom sentences"``, or a string containing ``"GPT3"``,
            ``"flan-t5-xl"`` or ``"BLOOMZ-7B"``.
        api_key: OpenAI API key (GPT3 only).
        org_key: OpenAI organization id (GPT3 only).
        l_custom_sentences: Newline-separated sentences ("custom sentences"
            only). Blank lines are ignored.

    Returns:
        ``load_sentence_embeddings`` of the collected sentences, computed with
        the Stable Diffusion v1.4 tokenizer and text encoder.

    Raises:
        ValueError: For an unrecognized ``sent_type``, or when no non-blank
            custom sentences are given.
    """
    if sent_type == "fixed-template":
        l_sentences = generate_image_prompts_with_templates(description)
    elif "GPT3" in sent_type:
        import openai
        openai.organization = org_key
        openai.api_key = api_key
        # presumably a quick credential/model check before the caption loop
        _ = openai.Model.retrieve("text-davinci-002")
        l_sentences = gpt3_compute_word2sentences("object", description, num=1000)
    elif "flan-t5-xl" in sent_type:
        l_sentences = flant5xl_compute_word2sentences(description, num=1000)
        _save_sentences(f"flant5xl_sentences_{description}.txt", l_sentences)
    elif "BLOOMZ-7B" in sent_type:
        l_sentences = bloomz_compute_sentences(description, num=1000)
        _save_sentences(f"bloomz_sentences_{description}.txt", l_sentences)
    elif sent_type == "custom sentences":
        # Blank lines (e.g. a trailing newline) would otherwise be embedded as
        # empty prompts and dilute the averaged embedding.
        l_sentences = [s.strip() for s in l_custom_sentences.split("\n") if s.strip()]
        if not l_sentences:
            raise ValueError("No non-empty custom sentences were provided")
        print(f"length of new sentence is {len(l_sentences)}")
    else:
        # Previously fell through to an UnboundLocalError on l_sentences.
        raise ValueError(f"Unknown sentence type: {sent_type!r}")
    pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    try:
        emb = load_sentence_embeddings(l_sentences, pipe.tokenizer, pipe.text_encoder, device=DEVICE)
    finally:
        # Free the pipeline even if encoding fails.
        del pipe
        torch.cuda.empty_cache()
    return emb
def _direction_embedding(choice, custom_desc, sent_type, api_key, org_key, custom_sentences, d_desc2name):
    """Return one endpoint embedding of the edit direction (preset via get_emb, or custom and cached under tmp/)."""
    if choice != "make your own!":
        return get_emb(d_desc2name[choice])
    outf_name = f"tmp/template_emb_{custom_desc}_{sent_type}.pt"
    if os.path.exists(outf_name):
        return torch.load(outf_name, map_location=torch.device('cpu'), weights_only=True)
    emb = make_custom_dir(custom_desc, sent_type, api_key, org_key, custom_sentences)
    torch.save(emb, outf_name)
    return emb


def _resize_longer_side(img, target=512):
    """Resize a PIL image with LANCZOS so that its longer side equals ``target``."""
    width, height = img.size
    scale_factor = target / max(width, height)
    new_size = (int(width * scale_factor), int(height * scale_factor))
    return img.resize(new_size, Image.Resampling.LANCZOS)


def _caption_image(img, caption_fname):
    """Return a BLIP caption for ``img``; computed once, then read back from ``caption_fname``."""
    if os.path.exists(caption_fname):
        with open(caption_fname, "r") as f:
            return f.read().strip()
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device(DEVICE))
    _image = vis_processors["eval"](img).unsqueeze(0).to(DEVICE)
    prompt_str = model_blip.generate({"image": _image})[0]
    del model_blip
    torch.cuda.empty_cache()
    with open(caption_fname, "w") as f:
        f.write(prompt_str)
    return prompt_str


def _invert_image(img, prompt_str, num_ddim, inv_fname):
    """Return the DDIM-inverted latent of ``img``; computed once, then loaded from ``inv_fname``."""
    if os.path.exists(inv_fname):
        return torch.load(inv_fname, map_location=torch.device('cpu'), weights_only=True)
    pipe_inv = DDIMInversion.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    pipe_inv.scheduler = DDIMInverseScheduler.from_config(pipe_inv.scheduler.config)
    x_inv, _x_inv_image, _x_dec_img = pipe_inv(prompt_str,
                                               guidance_scale=1, num_inversion_steps=num_ddim,
                                               img=img, torch_dtype=torch.float32)
    x_inv = x_inv.detach()
    torch.save(x_inv, inv_fname)
    del pipe_inv
    torch.cuda.empty_cache()
    return x_inv


def _run_edit(prompt, x_inv, num_ddim, text_dir, xa_guidance, guidance_scale, negative_prompt):
    """Run the editing pipeline from latent ``x_inv`` along ``text_dir`` and return the first edited image."""
    edit_pipe = EditingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float32).to(DEVICE)
    edit_pipe.scheduler = DDIMScheduler.from_config(edit_pipe.scheduler.config)
    _rec_pil, edit_pil = edit_pipe(prompt,
                                   num_inference_steps=num_ddim,
                                   x_in=x_inv,
                                   edit_dir=text_dir,
                                   guidance_amount=xa_guidance,
                                   guidance_scale=guidance_scale,
                                   negative_prompt=negative_prompt)
    del edit_pipe
    torch.cuda.empty_cache()
    return edit_pil[0]


def launch_main(img_in_real, img_in_synth, src, src_custom, dest, dest_custom, num_ddim, xa_guidance, edit_mul, fpath_z_gen, gen_prompt, sent_type_src, sent_type_dest, api_key, org_key, custom_sentences_src, custom_sentences_dest):
    """Edit a real or a previously generated image from concept ``src`` toward ``dest``.

    The edit direction is ``(dest_emb - src_emb) * edit_mul``. Each embedding
    is either a preset (looked up by description) or, for
    ``"make your own!"``, built with ``make_custom_dir`` and cached under
    ``tmp/``.

    Exactly one of ``img_in_real`` / ``img_in_synth`` must be non-None:

    * Real image: the image is resized so its longer side is 512, then
      captioned with BLIP and DDIM-inverted (both cached in ``tmp/`` by image
      hash). The result is edited, using its caption as the negative prompt.
    * Synthetic image: the noise latent saved at ``fpath_z_gen`` is edited
      with ``gen_prompt``.

    Returns:
        The first edited image produced by the editing pipeline.

    Raises:
        ValueError: If both or neither of the image inputs are provided.
    """
    d_name2desc = get_all_directions_names()
    d_desc2name = {v: k for k, v in d_name2desc.items()}
    os.makedirs("tmp", exist_ok=True)
    # generate custom direction first
    src_emb = _direction_embedding(src, src_custom, sent_type_src, api_key, org_key, custom_sentences_src, d_desc2name)
    dest_emb = _direction_embedding(dest, dest_custom, sent_type_dest, api_key, org_key, custom_sentences_dest, d_desc2name)
    text_dir = (dest_emb.to(DEVICE) - src_emb.to(DEVICE)) * edit_mul
    if img_in_real is not None and img_in_synth is None:
        print("using real image")
        img_in_real = _resize_longer_side(img_in_real, 512)
        # `img_hash`, not `hash`, to avoid shadowing the builtin.
        img_hash = hashlib.sha256(img_in_real.tobytes()).hexdigest()
        inv_fname = f"tmp/{img_hash}_ddim_{num_ddim}_inv.pt"
        caption_fname = f"tmp/{img_hash}_caption.txt"
        prompt_str = _caption_image(img_in_real, caption_fname)
        print(f"CAPTION: {prompt_str}")
        x_inv = _invert_image(img_in_real, prompt_str, num_ddim, inv_fname)
        # The unedited caption serves as the negative prompt.
        return _run_edit(prompt_str, x_inv, num_ddim, text_dir, xa_guidance, 5.0, prompt_str)
    elif img_in_real is None and img_in_synth is not None:
        print("using synthetic image")
        x_inv = torch.load(fpath_z_gen, map_location=torch.device('cpu'), weights_only=True)
        # The empty string serves as the negative prompt.
        return _run_edit(gen_prompt, x_inv, num_ddim, text_dir, xa_guidance, 5, "")
    else:
        raise ValueError(f"Invalid image type found: {img_in_real} {img_in_synth}")
if __name__ == "__main__":
    # Manual smoke run: print FLAN-T5-XL captions for a sample concept.
    demo_captions = flant5xl_compute_word2sentences("cat wearing sunglasses", num=100)
    print(demo_captions)