fffiloni committed on
Commit
7a7e7aa
1 Parent(s): 305c948

Create app.py

Files changed (1)
  1. app.py +144 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import os
+ import einops
+ import gradio as gr
+ import numpy as np
+ import torch
+ import random
+ from PIL import Image
+ from pathlib import Path
+ from torchvision import transforms
+ import torch.nn.functional as F
+ from torchvision.models import resnet50, ResNet50_Weights
+
+ from pytorch_lightning import seed_everything
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
+ from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler
+
+ from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline
+ from myutils.misc import load_dreambooth_lora, rand_name
+ from myutils.wavelet_color_fix import wavelet_color_fix
+ from annotator.retinaface import RetinaFaceDetection
+
+ use_pasd_light = False
+ face_detector = RetinaFaceDetection()
+
+ if use_pasd_light:
+     from models.pasd_light.unet_2d_condition import UNet2DConditionModel
+     from models.pasd_light.controlnet import ControlNetModel
+ else:
+     from models.pasd.unet_2d_condition import UNet2DConditionModel
+     from models.pasd.controlnet import ControlNetModel
+
+ pretrained_model_path = "checkpoints/stable-diffusion-v1-5"
+ ckpt_path = "runs/pasd/checkpoint-100000"
+ #dreambooth_lora_path = "checkpoints/personalized_models/toonyou_beta3.safetensors"
+ dreambooth_lora_path = "checkpoints/personalized_models/majicmixRealistic_v6.safetensors"
+ #dreambooth_lora_path = "checkpoints/personalized_models/Realistic_Vision_V5.1.safetensors"
+ weight_dtype = torch.float16
+ device = "cuda"
+
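+ # Load the frozen Stable Diffusion v1.5 components and the PASD UNet/ControlNet checkpoint.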
+ scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
+ feature_extractor = CLIPImageProcessor.from_pretrained(f"{pretrained_model_path}/feature_extractor")
+ unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
+ controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+ controlnet.requires_grad_(False)
+
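+ # Load the personalized DreamBooth/LoRA checkpoint into the UNet, VAE and text encoder.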
+ unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)
+
+ text_encoder.to(device, dtype=weight_dtype)
+ vae.to(device, dtype=weight_dtype)
+ unet.to(device, dtype=weight_dtype)
+ controlnet.to(device, dtype=weight_dtype)
+
+ validation_pipeline = StableDiffusionControlNetPipeline(
+     vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
+     unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
+ )
+ #validation_pipeline.enable_vae_tiling()
+ validation_pipeline._init_tiled_vae(decoder_tile_size=224)
+
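+ # ImageNet-pretrained ResNet-50 classifier, used in inference() to auto-tag the input image.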
+ weights = ResNet50_Weights.DEFAULT
+ preprocess = weights.transforms()
+ resnet = resnet50(weights=weights)
+ resnet.eval()
+
+ def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
+     process_size = 768
+     resize_preproc = transforms.Compose([
+         transforms.Resize(process_size, interpolation=transforms.InterpolationMode.BILINEAR),
+     ])
+
+     with torch.no_grad():
+         seed_everything(seed)
+         generator = torch.Generator(device=device)
+
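+         # Classify the input with ResNet-50 and append the top-1 category to the prompt when the score is high enough.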
+         input_image = input_image.convert('RGB')
+         batch = preprocess(input_image).unsqueeze(0)
+         prediction = resnet(batch).squeeze(0).softmax(0)
+         class_id = prediction.argmax().item()
+         score = prediction[class_id].item()
+         category_name = weights.meta["categories"][class_id]
+         if score >= 0.1:
+             prompt += f"{category_name}" if prompt=='' else f", {category_name}"
+
+         prompt = a_prompt if prompt=='' else f"{prompt}, {a_prompt}"
+
+         ori_width, ori_height = input_image.size
+         resize_flag = False
+
+         # Upscale by the requested factor, then make sure the short side reaches process_size.
+         rscale = upscale
+         input_image = input_image.resize((input_image.size[0]*rscale, input_image.size[1]*rscale))
+
+         if min(input_image.size) < process_size:
+             input_image = resize_preproc(input_image)
+
+         # Snap both dimensions to multiples of 8 for the VAE.
+         input_image = input_image.resize((input_image.size[0]//8*8, input_image.size[1]//8*8))
+         width, height = input_image.size
+         resize_flag = True
+
+         try:
+             image = validation_pipeline(
+                 None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator, height=height, width=width, guidance_scale=cfg,
+                 negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
+             ).images[0]
+
+             # Correct color shifts in the output with the wavelet-based color fix.
+             if True: #alpha<1.0:
+                 image = wavelet_color_fix(image, input_image)
+
+             if resize_flag:
+                 image = image.resize((ori_width*rscale, ori_height*rscale))
+         except Exception as e:
+             print(e)
+             image = Image.new(mode="RGB", size=(512, 512))
+
+     return image
+
122
+ title = "Pixel-Aware Stable Diffusion for Real-ISR"
123
+ description = "Gradio Demo for PASD Real-ISR. To use it, simply upload your image, or click one of the examples to load them."
124
+ article = "<p style='text-align: center'><a href='https://github.com/yangxy/PASD' target='_blank'>Github Repo Pytorch</a></p>"
125
+ examples=[['samples/27d38eeb2dbbe7c9.png'],['samples/629e4da70703193b.png']]
126
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs=[gr.Image(type="pil"),
+             gr.Textbox(label="Prompt", value="Asian"),
+             gr.Textbox(label="Added Prompt", value='clean, high-resolution, 8k, best quality, masterpiece'),
+             gr.Textbox(label="Negative Prompt", value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'),
+             gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1),
+             gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1),
+             gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1),
+             gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1),
+             gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)],
+     outputs=gr.Image(type="pil"),
+     title=title,
+     description=description,
+     article=article,
+     examples=examples).queue(concurrency_count=1)
+
+ demo.launch()
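For a quick smoke test outside the Gradio UI, inference() can be called directly on one of the bundled samples, for example by temporarily replacing the demo.launch() call with a sketch like the one below. This snippet is hypothetical and not part of the commit; it assumes the checkpoints referenced above are in place and uses prompts similar to the UI defaults with a fixed seed.

    # Hypothetical smoke test: restore a bundled sample at 2x with default-style prompts.
    sample = Image.open("samples/27d38eeb2dbbe7c9.png")
    restored = inference(sample, "", "clean, high-resolution, 8k, best quality, masterpiece",
                         "noise, blur, lowres, worst quality, low quality",
                         denoise_steps=20, upscale=2, alpha=1.1, cfg=7.5, seed=42)
    restored.save("restored_sample.png")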