mattb512 committed
Commit f65c76f
1 Parent(s): 360304e

make device a var, custom gradio interface

Files changed (2):
  1. app.py +37 -3
  2. image_generator.py +185 -0
app.py CHANGED
@@ -1,7 +1,41 @@
import gradio as gr
+ from image_generator import ImageGenerator

- def greet(name):
-     return "Hello " + name + "!!"
+ ig = ImageGenerator(g=7.5)
+ print(ig)
+ ig.load_models()
+ ig.load_scheduler()

- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ def greet(prompt, mix_prompt, mix_ratio, negative_prompt, steps, init_image):
+
+     print(f"{prompt=} {mix_prompt=} {mix_ratio=} {negative_prompt=} {steps=} {init_image=}")
+     generated_image, latents = ig.generate(
+         prompt=prompt,
+         secondary_prompt=mix_prompt,
+         prompt_mix_ratio=mix_ratio,
+         negative_prompt=negative_prompt,
+         steps=steps,
+         init_image=init_image,
+         latent_callback_mod=None)
+
+     if init_image is not None:
+         noisy_latent = latents[1]
+     else:
+         noisy_latent = None
+
+     return generated_image, noisy_latent
+
+ iface = gr.Interface(
+     fn=greet,
+     inputs=[
+         gr.Textbox(value="a cute dog", label="Prompt", info="primary prompt used to generate an image"),
+         gr.Textbox(value=None, label="Secondary Prompt", info="secondary prompt to mix with the primary embeddings"),
+         gr.Slider(0, 1, value=0.5, label="Mix Ratio", info="mix ratio between the primary and secondary prompt embeddings; 1 = primary only, 0 = secondary only"),
+         gr.Textbox(value=None, label="Negative Prompt", info="remove certain aspects from the picture"),
+         gr.Slider(10, 50, value=30, step=1, label="Generation Steps", info="how many denoising steps are used to generate the picture"),
+         gr.Image(type="pil", value=None, label="Starting Image"),  # info="start from this image instead of random noise"
+     ],
+     outputs=[
+         gr.Image(type="pil", label="Generated Image"),
+         gr.Image(type="pil", label="Starting Image with Added Noise")])
iface.launch()
image_generator.py ADDED
@@ -0,0 +1,185 @@
+ import logging
+ from pathlib import Path
+
+ import matplotlib.pyplot as plt
+ import torch
+ from diffusers import StableDiffusionPipeline
+ from fastcore.all import concat
+ from huggingface_hub import notebook_login
+ from PIL import Image
+ import numpy as np
+ # from IPython.display import display
+ from torchvision import transforms as tfms
+
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers import LMSDiscreteScheduler
+ from tqdm.auto import tqdm
+
+ logging.disable(logging.WARNING)
+ class ImageGenerator():
+     def __init__(self,
+                  g: float = 7.5,
+                  ):
+         self.latent_images = []
+         self.g = g
+         self.width = 512
+         self.height = 512
+         self.generator = torch.manual_seed(32)
+         self.bs = 1
+         if torch.cuda.is_available():
+             self.device = torch.device("cuda")
+             self.dtype = torch.float16
+         else:
+             self.device = torch.device("cpu")
+             self.dtype = torch.float32
+
+         print(f"Working on device: {self.device=}")
+     def __repr__(self):
+         return f"Image Generator with {self.g=}"
+
+     def load_models(self):
+         self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=self.dtype)
+         self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=self.dtype).to(self.device)
+         # vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=self.dtype).to(self.device)
+         self.vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(self.device)
+         self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(self.device)  # torch_dtype=torch.float16,
+
+     def load_scheduler(self,
+                        beta_start: float = 0.00085,
+                        beta_end: float = 0.012,
+                        beta_schedule: str = "scaled_linear",
+                        num_train_timesteps: int = 1000):
+
+         self.scheduler = LMSDiscreteScheduler(
+             beta_start=beta_start,
+             beta_end=beta_end,
+             beta_schedule=beta_schedule,
+             num_train_timesteps=num_train_timesteps)
+
+     def load_image(self, filepath: str) -> Image:
+         return Image.open(filepath).resize(size=(self.width, self.height))
+         # .convert("RGB") # RGB = 3 dimensions, RGBA = 4 dimensions
+
+     def nparray_to_pil(self, np_image: np.ndarray) -> Image:
+         return Image.fromarray(np_image).resize(size=(self.width, self.height))
+
+     def pil_to_latent(self, image: Image) -> torch.Tensor:
+         with torch.no_grad():
+             np_img = np.transpose(((np.array(image) / 255) - 0.5) * 2, (2, 0, 1))  # turn pil image into np array with values between -1 and 1
+             # print(f"{np_img.shape=}") # 3, 512, 512
+
+             np_images = np.repeat(np_img[np.newaxis, :, :], self.bs, axis=0)  # add a batch dimension and repeat the image for each prompt
+             # print(f"{np_images.shape=}")
+
+             decoded_latent = torch.from_numpy(np_images).to(self.device).float()  # <-- stability-ai vae uses half(), compvis vae uses float
+             # print(f"{decoded_latent.shape=}")
+
+             encoded_latent = 0.18215 * self.vae.encode(decoded_latent).latent_dist.sample()
+             # print(f"{encoded_latent.shape=}")
+
+             return encoded_latent
+
+     def add_noise(self, latent: torch.Tensor, scheduler_steps: int = 10) -> torch.FloatTensor:
+         # noise = torch.randn_like(latent) # missing generator parameter
+         noise = torch.randn(
+             size=(self.bs, self.unet.config.in_channels, self.height // 8, self.width // 8),
+             generator=self.generator).to(self.device)
+         timesteps = torch.tensor([self.scheduler.timesteps[scheduler_steps]])
+         noisy_latent = self.scheduler.add_noise(latent, noise, timesteps)
+         # print(f"add_noise: {timesteps.shape=} {timesteps=} {noisy_latent.shape=}")
+         return noisy_latent
+
+     def latent_to_pil(self, latent: torch.Tensor) -> Image:
+         # print(f"latent_to_pil {latent.dtype=}")
+         with torch.no_grad():
+             decoded = self.vae.decode(1 / 0.18215 * latent).sample[0]
+             # print(f"latent_to_pil {decoded.shape=}")
+             image = (decoded / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
+             return Image.fromarray((image * 255).round().astype("uint8"))
+
+     def image_grid(self, imgs: [Image]) -> Image:
+         w, h = imgs[0].size
+         cols = len(imgs)
+         grid = Image.new('RGB', size=(cols * w, h))
+         for i, img in enumerate(imgs):
+             # print(f"{img.size=}")
+             grid.paste(img, box=(i % cols * w, i // cols * h))
+         return grid
+
+     def text_enc(self, prompt: str, maxlen=None) -> torch.Tensor:
+         '''tokenize and encode a prompt'''
+         if maxlen is None: maxlen = self.tokenizer.model_max_length
+
+         inp = self.tokenizer([prompt], padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
+         return self.text_encoder(inp.input_ids.to(self.device))[0].float()
+
+     def tensor_to_pil(self, t: torch.Tensor) -> Image:
+         '''transform a tensor decoded by the vae into a pil image'''
+         # print(f"tensor_to_pil {t.shape=} {type(t)=}")
+         image = (t / 2 + 0.5).clamp(0, 1).detach().cpu().permute(1, 2, 0).numpy()
+         return Image.fromarray((image * 255).round().astype("uint8"))
+
+     def latent_callback(self, latent: torch.Tensor) -> None:
+         '''store decoded latents in a list so that we can inspect them later.'''
+         with torch.no_grad():
+             # print(f"cb {latent.shape=}")
+             decoded = self.vae.decode(1 / 0.18215 * latent).sample[0]
+             self.latent_images.append(self.tensor_to_pil(decoded))
+
+     def generate(self,
+                  prompt: str,
+                  secondary_prompt: str = None,
+                  prompt_mix_ratio: float = 0.5,
+                  negative_prompt="",
+                  seed: int = 32,
+                  steps: int = 30,
+                  start_step_ratio: float = 1/5,
+                  init_image: Image = None,
+                  latent_callback_mod: int = 10):
+         self.latent_images = []
+         if not negative_prompt: negative_prompt = ""
+
+         with torch.no_grad():
+             text = self.text_enc(prompt)
+             if secondary_prompt:
+                 sec_prompt_text = self.text_enc(secondary_prompt)
+                 text = text * prompt_mix_ratio + sec_prompt_text * (1 - prompt_mix_ratio)
+             uncond = self.text_enc(negative_prompt * self.bs, text.shape[1])
+             emb = torch.cat([uncond, text])
+         if seed: torch.manual_seed(seed)
+
+         self.scheduler.set_timesteps(steps)
+         self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)
+
+         if init_image is None:
+             start_steps = 0
+             latents = torch.randn(
+                 size=(self.bs, self.unet.config.in_channels, self.height // 8, self.width // 8),
+                 generator=self.generator)
+             latents = latents * self.scheduler.init_noise_sigma
+             # print(f"{latents.shape=}")
+         else:
+             start_steps = int(steps * start_step_ratio)  # 0%: too much noise, 100%: no noise
+             # print(f"{start_steps=}")
+             # img = self.load_image(init_image)
+             latents = self.pil_to_latent(init_image)
+             self.latent_callback(latents)
+             latents = self.add_noise(latents, start_steps).to(self.device).float()
+             self.latent_callback(latents)
+
+         latents = latents.to(self.device).float()
+
+         for i, ts in enumerate(tqdm(self.scheduler.timesteps, leave=False)):
+             if i >= start_steps:
+                 inp = self.scheduler.scale_model_input(torch.cat([latents] * 2), ts)
+                 with torch.no_grad():
+                     u, t = self.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)  # todo, grab these with callbacks
+                 pred = u + self.g * (t - u)
+                 # pred = u + self.g*(t-u)/torch.norm(t-u)*torch.norm(u)
+                 latents = self.scheduler.step(pred, ts, latents).prev_sample
+
+             if latent_callback_mod and i % latent_callback_mod == 0:
+                 self.latent_callback(latents)
+
+         return self.latent_to_pil(latents), self.latent_images
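A minimal usage sketch of the new ImageGenerator class outside the Gradio app (the input and output file names are illustrative only; model weights are downloaded on the first load_models() call):

    from PIL import Image
    from image_generator import ImageGenerator

    ig = ImageGenerator(g=7.5)   # g is the classifier-free guidance scale
    ig.load_models()             # CLIP tokenizer/text encoder, VAE and UNet
    ig.load_scheduler()          # LMSDiscreteScheduler with the defaults above

    # text-to-image: start from random noise
    image, intermediates = ig.generate(prompt="a cute dog", steps=30, latent_callback_mod=None)
    image.save("generated.png")

    # image-to-image: start from an existing 512x512 RGB picture with added noise
    init = Image.open("input.png").convert("RGB").resize((512, 512))   # illustrative input file
    repainted, intermediates = ig.generate(
        prompt="an oil painting of a cute dog",
        init_image=init,
        steps=30,
        latent_callback_mod=None)
    repainted.save("repainted.png")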