{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "machine_shape": "hm", "gpuType": "L4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "source": [ "# https://github.com/inbarhub/DDPM_inversion" ], "metadata": { "id": "2pmc1ZdmtAQJ" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GsGhwPzb_RBH" }, "outputs": [], "source": [ "%pip install numpy\n", "%pip install matplotlib\n", "%pip install fastai\n", "%pip install accelerate\n", "%pip install -U transformers diffusers ftfy\n", "%pip install torch\n", "%pip install torchvision\n", "%pip install opencv-python\n", "%pip install ipywidgets" ] }, { "cell_type": "code", "source": [ "import inspect\n", "\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import torch\n", "from accelerate import Accelerator\n", "from diffusers import (\n", " AutoencoderKL,\n", " UNet2DConditionModel,\n", " DDIMScheduler,\n", " DPMSolverMultistepScheduler,\n", ")\n", "from huggingface_hub import notebook_login\n", "from PIL import Image\n", "from torchvision import transforms as tfms\n", "from tqdm.auto import tqdm\n", "from transformers import CLIPTextModel, CLIPTokenizer\n", "from typing import Optional\n", "import requests\n", "\n", "notebook_login()" ], "metadata": { "id": "sYCb0YhF_YqC" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "id": "W3Ik_48j_Y1q" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#init_image 즉, 인풋용 이미지 만드는 셀\n", "\n", "init_image = load_image(path=\"/content/DDPM_inversion/Input_Images/cherry blossom branch petal.png\") #fill your own directory\n", "\n", "init_path = \"/content/DDPM_inversion/Input_Images/cherry blossom branch petal.png\" #fill your own directory" ], "metadata": { "id": "tuhPV23T_Y4k" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n", "\n", "processor = Blip2Processor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n", "imagecaptioningmodel = Blip2ForConditionalGeneration.from_pretrained(\"Salesforce/blip2-opt-2.7b\").to(device)\n", "inputs = processor(init_image, return_tensors=\"pt\").to(device) #매개변수\n", "outputs = imagecaptioningmodel.generate(**inputs)\n", "print(processor.decode(outputs[0], skip_special_tokens=True))" ], "metadata": { "id": "WRyROFhX_Y7c" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "prompt = str(processor.decode(outputs[0], skip_special_tokens=True))" ], "metadata": { "id": "rh01KUQh_vW1" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import yaml\n", "data = [\n", " {\n", " \"init_img\": \"/content/DDPM_inversion/Input_Images/Cherry Blossoms.png\", #init_path 사용\n", " \"source_prompt\": \"\",\n", " \"target_prompts\": [\n", " \"\",\n", " ]\n", " },\n", "]\n", "\n", "file_path = '/content/DDPM_inversion/test.yaml' # 변경 가능한 파일 경로\n", "\n", "with open(file_path, 'w') as file:\n", " yaml.dump(data, file)\n", "with open(file_path, 'r') as file:\n", " print(file.read())" ], "metadata": { "id": "wZighP5oNL1X" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "!git clone https://github.com/Kangdongkyung/DDPM_inversion.git #do not use this. 
{ "cell_type": "code", "source": [ "%cd /content/DDPM_inversion # fill in your own path" ], "metadata": { "id": "mM7wwPjycqSK" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from easydict import EasyDict\n", "from diffusers import StableDiffusionPipeline\n", "from diffusers import DDIMScheduler\n", "import os\n", "from prompt_to_prompt.ptp_classes import AttentionStore, AttentionReplace, AttentionRefine, EmptyControl, load_512\n", "from prompt_to_prompt.ptp_utils import register_attention_control, text2image_ldm_stable, view_images\n", "from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process\n", "from ddm_inversion.utils import image_grid, dataset_from_yaml\n", "\n", "from torch import autocast, inference_mode\n", "from ddm_inversion.ddim_inversion import ddim_inversion\n", "\n", "import calendar\n", "import time\n", "\n", "if __name__ == \"__main__\":\n", "    # parser = argparse.ArgumentParser()\n", "    # parser.add_argument(\"--device_num\", type=int, default=0)\n", "    # parser.add_argument(\"--cfg_src\", type=float, default=3.5)\n", "    # parser.add_argument(\"--cfg_tar\", type=float, default=15)\n", "    # parser.add_argument(\"--num_diffusion_steps\", type=int, default=100)\n", "    # parser.add_argument(\"--dataset_yaml\", default=\"test.yaml\")\n", "    # parser.add_argument(\"--eta\", type=float, default=1)\n", "    # parser.add_argument(\"--mode\", default=\"our_inv\", help=\"modes: our_inv,p2pinv,p2pddim,ddim\")\n", "    # parser.add_argument(\"--skip\", type=int, default=36)\n", "    # parser.add_argument(\"--xa\", type=float, default=0.6)\n", "    # parser.add_argument(\"--sa\", type=float, default=0.2)\n", "\n", "    # args = parser.parse_args()\n", "    args = EasyDict()\n", "    args.dataset_yaml = file_path\n", "    args.cfg_src = 3.5\n", "    args.cfg_tar = 15\n", "    args.num_diffusion_steps = 100\n", "    args.eta = 1\n", "    args.mode = \"our_inv\"\n", "    args.skip = 36\n", "    args.xa = 0.6\n", "    args.sa = 0.2\n", "\n", "    full_data = dataset_from_yaml(args.dataset_yaml)\n", "\n", "    # create scheduler\n", "    # load diffusion model\n", "    model_id = \"CompVis/stable-diffusion-v1-4\"\n", "    # model_id = \"stable_diff_local\" # load local save of model (for internet problems)\n", "\n", "\n", "    cfg_scale_src = args.cfg_src\n", "    cfg_scale_tar_list = [args.cfg_tar]\n", "    eta = args.eta # = 1\n", "    skip_zs = [args.skip]\n", "    xa_sa_string = f'_xa_{args.xa}_sa{args.sa}_' if args.mode=='p2pinv' else '_'\n", "\n", "    current_GMT = time.gmtime()\n", "    time_stamp = calendar.timegm(current_GMT)\n", "\n", "    # load/reload model:\n", "    ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)\n", "\n", "    for i in range(len(full_data)):\n", "        current_image_data = full_data[i]\n", "        image_path = current_image_data['init_img']\n", "        image_path = image_path # the leading '.' was removed to mark that this is not a path relative to the current directory; this still needs a proper fix\n", "        image_folder = image_path.split('/')[1] # after '.'\n", "        prompt_src = current_image_data.get('source_prompt', \"\") # default empty string\n", "        prompt_tar_list = current_image_data['target_prompts']\n", "\n", "        if args.mode==\"p2pddim\" or args.mode==\"ddim\":\n", "            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n", "            ldm_stable.scheduler = scheduler\n", "        else:\n", "            ldm_stable.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder=\"scheduler\")\n", "\n", "        ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps)\n", "\n", "        # load image\n", "        offsets=(0,0,0,0)\n", "        x0 = load_512(image_path, *offsets, device)\n", "\n", "        # vae encode image\n", "        with autocast(\"cuda\"), inference_mode():\n", "            w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float()\n", "\n",
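"        # How the inversion works here (see the repository linked at the top): the forward call in the\n", "        # else-branch below extracts per-timestep noise maps zs and intermediate latents wts from the input\n", "        # latent w0 under the source prompt; the reverse call inside the loop further down then restarts\n", "        # sampling from wts[num_diffusion_steps - skip] with the target prompt while reusing the same zs,\n", "        # so the structure of the input image is preserved while its content follows the target prompt.\n",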
"        # find Zs and wts - forward process\n", "        if args.mode==\"p2pddim\" or args.mode==\"ddim\":\n", "            wT = ddim_inversion(ldm_stable, w0, prompt_src, cfg_scale_src)\n", "        else:\n", "            wt, zs, wts = inversion_forward_process(ldm_stable, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=args.num_diffusion_steps)\n", "\n", "        # iterate over decoder prompts\n", "        for k in range(len(prompt_tar_list)):\n", "            prompt_tar = prompt_tar_list[k]\n", "            save_path = os.path.join(f'./results/', args.mode+xa_sa_string+str(time_stamp), image_path.split(sep='.')[0], 'src_' + prompt_src.replace(\" \", \"_\"), 'dec_' + prompt_tar.replace(\" \", \"_\"))\n", "            os.makedirs(save_path, exist_ok=True)\n", "\n", "            # Check if number of words in encoder and decoder text are equal\n", "            src_tar_len_eq = (len(prompt_src.split(\" \")) == len(prompt_tar.split(\" \")))\n", "\n", "            for cfg_scale_tar in cfg_scale_tar_list:\n", "                for skip in skip_zs:\n", "                    if args.mode==\"our_inv\":\n", "                        # reverse process (via Zs and wT)\n", "                        controller = AttentionStore()\n", "                        register_attention_control(ldm_stable, controller)\n", "                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)\n", "\n", "                    elif args.mode==\"p2pinv\":\n", "                        # inversion with attention replace\n", "                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]\n", "                        prompts = [prompt_src, prompt_tar]\n", "                        if src_tar_len_eq:\n", "                            controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)\n", "                        else:\n", "                            # Should use Refine for target prompts with different number of tokens\n", "                            controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)\n", "\n", "                        register_attention_control(ldm_stable, controller)\n", "                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=prompts, cfg_scales=cfg_scale_list, prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)\n", "                        w0 = w0[1].unsqueeze(0)\n", "\n", "                    elif args.mode==\"p2pddim\" or args.mode==\"ddim\":\n", "                        # only z=0\n", "                        if skip != 0:\n", "                            continue\n", "                        prompts = [prompt_src, prompt_tar]\n", "                        if args.mode==\"p2pddim\":\n", "                            if src_tar_len_eq:\n", "                                controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)\n", "                            # Should use Refine for target prompts with different number of tokens\n", "                            else:\n", "                                controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)\n", "                        else:\n", "                            controller = EmptyControl()\n", "\n", "                        register_attention_control(ldm_stable, controller)\n", "                        # perform ddim inversion\n", "                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]\n", "                        w0, latent = text2image_ldm_stable(ldm_stable, prompts, controller, args.num_diffusion_steps, cfg_scale_list, None, wT)\n", "                        w0 = w0[1:2]\n", "                    else:\n", "                        raise NotImplementedError\n", "\n",
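"                    # Note: 0.18215 is the standard latent scaling factor of the Stable Diffusion v1 VAE;\n", "                    # the decode call below divides by it to undo the multiplication applied after encoding above.\n",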
"                    # vae decode image\n", "                    with autocast(\"cuda\"), inference_mode():\n", "                        x0_dec = ldm_stable.vae.decode(1 / 0.18215 * w0).sample\n", "                    if x0_dec.dim()<4:\n", "                        x0_dec = x0_dec[None,:,:,:]\n", "                    img = image_grid(x0_dec)\n", "\n", "                    # save output\n", "                    current_GMT = time.gmtime()\n", "                    time_stamp_name = calendar.timegm(current_GMT)\n", "                    image_name_png = f'cfg_d_{cfg_scale_tar}_' + f'skip_{skip}_{time_stamp_name}' + \".png\"\n", "\n", "                    save_full_path = os.path.join(save_path, image_name_png)\n", "                    img.save(save_full_path)" ], "metadata": { "id": "dcVYikEa_wQ1" }, "execution_count": null, "outputs": [] } ] }