apolinario commited on
Commit
a0bd9cc
1 Parent(s): 7f1aa40

Initial attempt Hypetron v2

Browse files
Files changed (2) hide show
  1. app.py +2346 -8
  2. requirements.txt +20 -1
app.py CHANGED
@@ -1,11 +1,2349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
- import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- is_cuda = torch.cuda.is_available()
5
- def greet(name):
6
- if is_cuda:
7
- return "Hello cuda" + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  else:
9
- return "Hello ooops" + name + "!!"
10
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
11
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import math
4
+ from pathlib import Path
5
+ import sys
6
+ import pandas as pd
7
+ from base64 import b64encode
8
+ from omegaconf import OmegaConf
9
+ from PIL import Image
10
+ from taming.models import cond_transformer, vqgan
11
+ import torch
12
+ from os.path import exists as path_exists
13
+
14
+ torch.cuda.empty_cache()
15
+ from torch import nn
16
+ import torch.optim as optim
17
+ from torch import optim
18
+ from torch.nn import functional as F
19
+ from torchvision import transforms
20
+ from torchvision.transforms import functional as TF
21
+ import torchvision.transforms as T
22
+
23
+ from CLIP import clip
24
  import gradio as gr
25
+ import kornia.augmentation as K
26
+ import numpy as np
27
+ import subprocess
28
+ import imageio
29
+ from PIL import ImageFile, Image
30
+ import time
31
+
32
+ import hashlib
33
+ from PIL.PngImagePlugin import PngImageFile, PngInfo
34
+ import json
35
+ import IPython
36
+ from IPython.display import Markdown, display, Image, clear_output
37
+ import urllib.request
38
+ import random
39
+ from random import randint
40
+ from pathvalidate import sanitize_filename
41
+ from huggingface_hub import hf_hub_download
42
+
43
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
+ print("Using device:", device)
45
+
46
+ vqgan_model = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="model.ckpt")
47
+ vqgan_config = hf_hub_download(repo_id="boris/vqgan_f16_16384", filename="config.yaml")
48
+
49
+ def load_vqgan_model(config_path, checkpoint_path):
50
+ config = OmegaConf.load(config_path)
51
+ if config.model.target == "taming.models.vqgan.VQModel":
52
+ model = vqgan.VQModel(**config.model.params)
53
+ model.eval().requires_grad_(False)
54
+ model.init_from_ckpt(checkpoint_path)
55
+ elif config.model.target == "taming.models.cond_transformer.Net2NetTransformer":
56
+ parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
57
+ parent_model.eval().requires_grad_(False)
58
+ parent_model.init_from_ckpt(checkpoint_path)
59
+ model = parent_model.first_stage_model
60
+ elif config.model.target == "taming.models.vqgan.GumbelVQ":
61
+ model = vqgan.GumbelVQ(**config.model.params)
62
+ # print(config.model.params)
63
+ model.eval().requires_grad_(False)
64
+ model.init_from_ckpt(checkpoint_path)
65
+ else:
66
+ raise ValueError(f"unknown model type: {config.model.target}")
67
+ del model.loss
68
+ return model
69
+ model = load_vqgan_model(vqgan_config, vqgan_model).to(device)
70
+ perceptor = (
71
+ clip.load("ViT-B/32", jit=False)[0]
72
+ .eval()
73
+ .requires_grad_(False)
74
+ .to(device)
75
+ )
76
+ def run(user_input,num_steps, template, width,height):
77
+ #if uploaded_file is not None:
78
+ #uploaded_folder = f"{DefaultPaths.root_path}/uploaded"
79
+ #if not path_exists(uploaded_folder):
80
+ # os.makedirs(uploaded_folder)
81
+ #image_data = uploaded_file.read()
82
+ #f = open(f"{uploaded_folder}/{uploaded_file.name}", "wb")
83
+ #f.write(image_data)
84
+ #f.close()
85
+ #image_path = f"{uploaded_folder}/{uploaded_file.name}"
86
+ #pass
87
+ #else:
88
+ image_path = None
89
+ flavor = 'cumin'
90
+
91
+ args2 = argparse.Namespace(
92
+ prompt=user_input,
93
+ seed=int(seed),
94
+ sizex=width,
95
+ sizey=height,
96
+ flavor=flavor,
97
+ iterations=num_steps,
98
+ mse=True,
99
+ update=100,
100
+ template=template,
101
+ vqgan_model='ImageNet 16384',
102
+ seed_image=image_path,
103
+ image_file="progress.png",
104
+ #frame_dir=intermediary_folder,
105
+ )
106
+ if args2.seed is not None:
107
+ import torch
108
+
109
+ sys.stdout.write(f"Setting seed to {args2.seed} ...\n")
110
+ sys.stdout.flush()
111
+ import numpy as np
112
+
113
+ np.random.seed(args2.seed)
114
+ import random
115
+
116
+ random.seed(args2.seed)
117
+ # next line forces deterministic random values, but causes other issues with resampling (uncomment to see)
118
+ torch.manual_seed(args2.seed)
119
+ torch.cuda.manual_seed(args2.seed)
120
+ torch.cuda.manual_seed_all(args2.seed)
121
+ torch.backends.cudnn.deterministic = True
122
+ torch.backends.cudnn.benchmark = False
123
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
124
+ print("Using device:", device)
125
+
126
+ def noise_gen(shape, octaves=5):
127
+ n, c, h, w = shape
128
+ noise = torch.zeros([n, c, 1, 1])
129
+ max_octaves = min(octaves, math.log(h) / math.log(2), math.log(w) / math.log(2))
130
+ for i in reversed(range(max_octaves)):
131
+ h_cur, w_cur = h // 2**i, w // 2**i
132
+ noise = F.interpolate(
133
+ noise, (h_cur, w_cur), mode="bicubic", align_corners=False
134
+ )
135
+ noise += torch.randn([n, c, h_cur, w_cur]) / 5
136
+ return noise
137
+
138
+ def sinc(x):
139
+ return torch.where(
140
+ x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([])
141
+ )
142
+
143
+ def lanczos(x, a):
144
+ cond = torch.logical_and(-a < x, x < a)
145
+ out = torch.where(cond, sinc(x) * sinc(x / a), x.new_zeros([]))
146
+ return out / out.sum()
147
+
148
+ def ramp(ratio, width):
149
+ n = math.ceil(width / ratio + 1)
150
+ out = torch.empty([n])
151
+ cur = 0
152
+ for i in range(out.shape[0]):
153
+ out[i] = cur
154
+ cur += ratio
155
+ return torch.cat([-out[1:].flip([0]), out])[1:-1]
156
+
157
+ def resample(input, size, align_corners=True):
158
+ n, c, h, w = input.shape
159
+ dh, dw = size
160
+
161
+ input = input.view([n * c, 1, h, w])
162
+
163
+ if dh < h:
164
+ kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
165
+ pad_h = (kernel_h.shape[0] - 1) // 2
166
+ input = F.pad(input, (0, 0, pad_h, pad_h), "reflect")
167
+ input = F.conv2d(input, kernel_h[None, None, :, None])
168
+
169
+ if dw < w:
170
+ kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
171
+ pad_w = (kernel_w.shape[0] - 1) // 2
172
+ input = F.pad(input, (pad_w, pad_w, 0, 0), "reflect")
173
+ input = F.conv2d(input, kernel_w[None, None, None, :])
174
+
175
+ input = input.view([n, c, h, w])
176
+ return F.interpolate(input, size, mode="bicubic", align_corners=align_corners)
177
+
178
+ def lerp(a, b, f):
179
+ return (a * (1.0 - f)) + (b * f)
180
+
181
+ class ReplaceGrad(torch.autograd.Function):
182
+ @staticmethod
183
+ def forward(ctx, x_forward, x_backward):
184
+ ctx.shape = x_backward.shape
185
+ return x_forward
186
+
187
+ @staticmethod
188
+ def backward(ctx, grad_in):
189
+ return None, grad_in.sum_to_size(ctx.shape)
190
+
191
+ replace_grad = ReplaceGrad.apply
192
+
193
+ class ClampWithGrad(torch.autograd.Function):
194
+ @staticmethod
195
+ def forward(ctx, input, min, max):
196
+ ctx.min = min
197
+ ctx.max = max
198
+ ctx.save_for_backward(input)
199
+ return input.clamp(min, max)
200
+
201
+ @staticmethod
202
+ def backward(ctx, grad_in):
203
+ (input,) = ctx.saved_tensors
204
+ return (
205
+ grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0),
206
+ None,
207
+ None,
208
+ )
209
+
210
+ clamp_with_grad = ClampWithGrad.apply
211
+
212
+ def vector_quantize(x, codebook):
213
+ d = (
214
+ x.pow(2).sum(dim=-1, keepdim=True)
215
+ + codebook.pow(2).sum(dim=1)
216
+ - 2 * x @ codebook.T
217
+ )
218
+ indices = d.argmin(-1)
219
+ x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
220
+ return replace_grad(x_q, x)
221
+
222
+ class Prompt(nn.Module):
223
+ def __init__(self, embed, weight=1.0, stop=float("-inf")):
224
+ super().__init__()
225
+ self.register_buffer("embed", embed)
226
+ self.register_buffer("weight", torch.as_tensor(weight))
227
+ self.register_buffer("stop", torch.as_tensor(stop))
228
+
229
+ def forward(self, input):
230
+ input_normed = F.normalize(input.unsqueeze(1), dim=2)
231
+ embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
232
+ dists = (
233
+ input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
234
+ )
235
+ dists = dists * self.weight.sign()
236
+ return (
237
+ self.weight.abs()
238
+ * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
239
+ )
240
+
241
+ def parse_prompt(prompt):
242
+ if prompt.startswith("http://") or prompt.startswith("https://"):
243
+ vals = prompt.rsplit(":", 1)
244
+ vals = [vals[0] + ":" + vals[1], *vals[2:]]
245
+ else:
246
+ vals = prompt.rsplit(":", 1)
247
+ vals = vals + ["", "1", "-inf"][len(vals) :]
248
+ return vals[0], float(vals[1]), float(vals[2])
249
+
250
+ def one_sided_clip_loss(input, target, labels=None, logit_scale=100):
251
+ input_normed = F.normalize(input, dim=-1)
252
+ target_normed = F.normalize(target, dim=-1)
253
+ logits = input_normed @ target_normed.T * logit_scale
254
+ if labels is None:
255
+ labels = torch.arange(len(input), device=logits.device)
256
+ return F.cross_entropy(logits, labels)
257
+
258
+ class EMATensor(nn.Module):
259
+ """implmeneted by Katherine Crowson"""
260
+
261
+ def __init__(self, tensor, decay):
262
+ super().__init__()
263
+ self.tensor = nn.Parameter(tensor)
264
+ self.register_buffer("biased", torch.zeros_like(tensor))
265
+ self.register_buffer("average", torch.zeros_like(tensor))
266
+ self.decay = decay
267
+ self.register_buffer("accum", torch.tensor(1.0))
268
+ self.update()
269
+
270
+ @torch.no_grad()
271
+ def update(self):
272
+ if not self.training:
273
+ raise RuntimeError("update() should only be called during training")
274
+
275
+ self.accum *= self.decay
276
+ self.biased.mul_(self.decay)
277
+ self.biased.add_((1 - self.decay) * self.tensor)
278
+ self.average.copy_(self.biased)
279
+ self.average.div_(1 - self.accum)
280
+
281
+ def forward(self):
282
+ if self.training:
283
+ return self.tensor
284
+ return self.average
285
+
286
+ class MakeCutoutsCustom(nn.Module):
287
+ def __init__(self, cut_size, cutn, cut_pow, augs):
288
+ super().__init__()
289
+ self.cut_size = cut_size
290
+ # tqdm.write(f"cut size: {self.cut_size}")
291
+ self.cutn = cutn
292
+ self.cut_pow = cut_pow
293
+ self.noise_fac = 0.1
294
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
295
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
296
+ self.augs = nn.Sequential(
297
+ K.RandomHorizontalFlip(p=Random_Horizontal_Flip),
298
+ K.RandomSharpness(Random_Sharpness, p=Random_Sharpness_P),
299
+ K.RandomGaussianBlur(
300
+ (Random_Gaussian_Blur),
301
+ (Random_Gaussian_Blur_W, Random_Gaussian_Blur_W),
302
+ p=Random_Gaussian_Blur_P,
303
+ ),
304
+ K.RandomGaussianNoise(p=Random_Gaussian_Noise_P),
305
+ K.RandomElasticTransform(
306
+ kernel_size=(
307
+ Random_Elastic_Transform_Kernel_Size_W,
308
+ Random_Elastic_Transform_Kernel_Size_H,
309
+ ),
310
+ sigma=(Random_Elastic_Transform_Sigma),
311
+ p=Random_Elastic_Transform_P,
312
+ ),
313
+ K.RandomAffine(
314
+ degrees=Random_Affine_Degrees,
315
+ translate=Random_Affine_Translate,
316
+ p=Random_Affine_P,
317
+ padding_mode="border",
318
+ ),
319
+ K.RandomPerspective(Random_Perspective, p=Random_Perspective_P),
320
+ K.ColorJitter(
321
+ hue=Color_Jitter_Hue,
322
+ saturation=Color_Jitter_Saturation,
323
+ p=Color_Jitter_P,
324
+ ),
325
+ )
326
+ # K.RandomErasing((0.1, 0.7), (0.3, 1/0.4), same_on_batch=True, p=0.2),)
327
+
328
+ def set_cut_pow(self, cut_pow):
329
+ self.cut_pow = cut_pow
330
+
331
+ def forward(self, input):
332
+ sideY, sideX = input.shape[2:4]
333
+ max_size = min(sideX, sideY)
334
+ min_size = min(sideX, sideY, self.cut_size)
335
+ cutouts = []
336
+ cutouts_full = []
337
+ noise_fac = 0.1
338
+
339
+ min_size_width = min(sideX, sideY)
340
+ lower_bound = float(self.cut_size / min_size_width)
341
+
342
+ for ii in range(self.cutn):
343
+
344
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
345
+ randsize = (
346
+ torch.zeros(
347
+ 1,
348
+ )
349
+ .normal_(mean=0.8, std=0.3)
350
+ .clip(lower_bound, 1.0)
351
+ )
352
+ size_mult = randsize**self.cut_pow
353
+ size = int(
354
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
355
+ ) # replace .5 with a result for 224 the default large size is .95
356
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
357
+
358
+ offsetx = torch.randint(0, sideX - size + 1, ())
359
+ offsety = torch.randint(0, sideY - size + 1, ())
360
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
361
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
362
+
363
+ cutouts = torch.cat(cutouts, dim=0)
364
+ cutouts = clamp_with_grad(cutouts, 0, 1)
365
+
366
+ # if args.use_augs:
367
+ cutouts = self.augs(cutouts)
368
+ if self.noise_fac:
369
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
370
+ 0, self.noise_fac
371
+ )
372
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
373
+ return cutouts
374
+
375
+ class MakeCutoutsJuu(nn.Module):
376
+ def __init__(self, cut_size, cutn, cut_pow, augs):
377
+ super().__init__()
378
+ self.cut_size = cut_size
379
+ self.cutn = cutn
380
+ self.cut_pow = cut_pow
381
+ self.augs = nn.Sequential(
382
+ # K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
383
+ K.RandomHorizontalFlip(p=0.5),
384
+ K.RandomSharpness(0.3, p=0.4),
385
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
386
+ K.RandomPerspective(0.2, p=0.4),
387
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
388
+ K.RandomGrayscale(p=0.1),
389
+ )
390
+ self.noise_fac = 0.1
391
+
392
+ def forward(self, input):
393
+ sideY, sideX = input.shape[2:4]
394
+ max_size = min(sideX, sideY)
395
+ min_size = min(sideX, sideY, self.cut_size)
396
+ cutouts = []
397
+ for _ in range(self.cutn):
398
+ size = int(
399
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
400
+ )
401
+ offsetx = torch.randint(0, sideX - size + 1, ())
402
+ offsety = torch.randint(0, sideY - size + 1, ())
403
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
404
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
405
+ batch = self.augs(torch.cat(cutouts, dim=0))
406
+ if self.noise_fac:
407
+ facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
408
+ batch = batch + facs * torch.randn_like(batch)
409
+ return batch
410
+
411
+ class MakeCutoutsMoth(nn.Module):
412
+ def __init__(self, cut_size, cutn, cut_pow, augs, skip_augs=False):
413
+ super().__init__()
414
+ self.cut_size = cut_size
415
+ self.cutn = cutn
416
+ self.cut_pow = cut_pow
417
+ self.skip_augs = skip_augs
418
+ self.augs = T.Compose(
419
+ [
420
+ T.RandomHorizontalFlip(p=0.5),
421
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
422
+ T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
423
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
424
+ T.RandomPerspective(distortion_scale=0.4, p=0.7),
425
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
426
+ T.RandomGrayscale(p=0.15),
427
+ T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
428
+ # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
429
+ ]
430
+ )
431
+
432
+ def forward(self, input):
433
+ input = T.Pad(input.shape[2] // 4, fill=0)(input)
434
+ sideY, sideX = input.shape[2:4]
435
+ max_size = min(sideX, sideY)
436
+
437
+ cutouts = []
438
+ for ch in range(cutn):
439
+ if ch > cutn - cutn // 4:
440
+ cutout = input.clone()
441
+ else:
442
+ size = int(
443
+ max_size
444
+ * torch.zeros(
445
+ 1,
446
+ )
447
+ .normal_(mean=0.8, std=0.3)
448
+ .clip(float(self.cut_size / max_size), 1.0)
449
+ )
450
+ offsetx = torch.randint(0, abs(sideX - size + 1), ())
451
+ offsety = torch.randint(0, abs(sideY - size + 1), ())
452
+ cutout = input[
453
+ :, :, offsety : offsety + size, offsetx : offsetx + size
454
+ ]
455
+
456
+ if not self.skip_augs:
457
+ cutout = self.augs(cutout)
458
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
459
+ del cutout
460
+
461
+ cutouts = torch.cat(cutouts, dim=0)
462
+ return cutouts
463
+
464
+ class MakeCutoutsAaron(nn.Module):
465
+ def __init__(self, cut_size, cutn, cut_pow, augs):
466
+ super().__init__()
467
+ self.cut_size = cut_size
468
+ self.cutn = cutn
469
+ self.cut_pow = cut_pow
470
+ self.augs = augs
471
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
472
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
473
+
474
+ def set_cut_pow(self, cut_pow):
475
+ self.cut_pow = cut_pow
476
+
477
+ def forward(self, input):
478
+ sideY, sideX = input.shape[2:4]
479
+ max_size = min(sideX, sideY)
480
+ min_size = min(sideX, sideY, self.cut_size)
481
+ cutouts = []
482
+ cutouts_full = []
483
+
484
+ min_size_width = min(sideX, sideY)
485
+ lower_bound = float(self.cut_size / min_size_width)
486
+
487
+ for ii in range(self.cutn):
488
+ size = int(
489
+ min_size_width
490
+ * torch.zeros(
491
+ 1,
492
+ )
493
+ .normal_(mean=0.8, std=0.3)
494
+ .clip(lower_bound, 1.0)
495
+ ) # replace .5 with a result for 224 the default large size is .95
496
+
497
+ offsetx = torch.randint(0, sideX - size + 1, ())
498
+ offsety = torch.randint(0, sideY - size + 1, ())
499
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
500
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
501
+
502
+ cutouts = torch.cat(cutouts, dim=0)
503
+
504
+ return clamp_with_grad(cutouts, 0, 1)
505
+
506
+ class MakeCutoutsCumin(nn.Module):
507
+ # from https://colab.research.google.com/drive/1ZAus_gn2RhTZWzOWUpPERNC0Q8OhZRTZ
508
+ def __init__(self, cut_size, cutn, cut_pow, augs):
509
+ super().__init__()
510
+ self.cut_size = cut_size
511
+ # tqdm.write(f"cut size: {self.cut_size}")
512
+ self.cutn = cutn
513
+ self.cut_pow = cut_pow
514
+ self.noise_fac = 0.1
515
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
516
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
517
+ self.augs = nn.Sequential(
518
+ # K.RandomHorizontalFlip(p=0.5),
519
+ # K.RandomSharpness(0.3,p=0.4),
520
+ # K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
521
+ # K.RandomGaussianNoise(p=0.5),
522
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
523
+ K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode="border"),
524
+ K.RandomPerspective(0.7, p=0.7),
525
+ K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
526
+ K.RandomErasing((0.1, 0.4), (0.3, 1 / 0.3), same_on_batch=True, p=0.7),
527
+ )
528
+
529
+ def set_cut_pow(self, cut_pow):
530
+ self.cut_pow = cut_pow
531
+
532
+ def forward(self, input):
533
+ sideY, sideX = input.shape[2:4]
534
+ max_size = min(sideX, sideY)
535
+ min_size = min(sideX, sideY, self.cut_size)
536
+ cutouts = []
537
+ cutouts_full = []
538
+ noise_fac = 0.1
539
+
540
+ min_size_width = min(sideX, sideY)
541
+ lower_bound = float(self.cut_size / min_size_width)
542
+
543
+ for ii in range(self.cutn):
544
+
545
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
546
+ randsize = (
547
+ torch.zeros(
548
+ 1,
549
+ )
550
+ .normal_(mean=0.8, std=0.3)
551
+ .clip(lower_bound, 1.0)
552
+ )
553
+ size_mult = randsize**self.cut_pow
554
+ size = int(
555
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
556
+ ) # replace .5 with a result for 224 the default large size is .95
557
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
558
+
559
+ offsetx = torch.randint(0, sideX - size + 1, ())
560
+ offsety = torch.randint(0, sideY - size + 1, ())
561
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
562
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
563
+
564
+ cutouts = torch.cat(cutouts, dim=0)
565
+ cutouts = clamp_with_grad(cutouts, 0, 1)
566
+
567
+ # if args.use_augs:
568
+ cutouts = self.augs(cutouts)
569
+ if self.noise_fac:
570
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
571
+ 0, self.noise_fac
572
+ )
573
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
574
+ return cutouts
575
+
576
+ class MakeCutoutsHolywater(nn.Module):
577
+ def __init__(self, cut_size, cutn, cut_pow, augs):
578
+ super().__init__()
579
+ self.cut_size = cut_size
580
+ # tqdm.write(f"cut size: {self.cut_size}")
581
+ self.cutn = cutn
582
+ self.cut_pow = cut_pow
583
+ self.noise_fac = 0.1
584
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
585
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
586
+ self.augs = nn.Sequential(
587
+ # K.RandomGaussianNoise(mean=0.0, std=0.5, p=0.1),
588
+ K.RandomHorizontalFlip(p=0.5),
589
+ K.RandomSharpness(0.3, p=0.4),
590
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
591
+ K.RandomPerspective(0.2, p=0.4),
592
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
593
+ K.RandomGrayscale(p=0.1),
594
+ )
595
+
596
+ def set_cut_pow(self, cut_pow):
597
+ self.cut_pow = cut_pow
598
+
599
+ def forward(self, input):
600
+ sideY, sideX = input.shape[2:4]
601
+ max_size = min(sideX, sideY)
602
+ min_size = min(sideX, sideY, self.cut_size)
603
+ cutouts = []
604
+ cutouts_full = []
605
+ noise_fac = 0.1
606
+ min_size_width = min(sideX, sideY)
607
+ lower_bound = float(self.cut_size / min_size_width)
608
+
609
+ for ii in range(self.cutn):
610
+ size = int(
611
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
612
+ )
613
+ randsize = (
614
+ torch.zeros(
615
+ 1,
616
+ )
617
+ .normal_(mean=0.8, std=0.3)
618
+ .clip(lower_bound, 1.0)
619
+ )
620
+ size_mult = randsize**self.cut_pow * ii + size
621
+ size1 = int(
622
+ (min_size_width) * (size_mult.clip(lower_bound, 1.0))
623
+ ) # replace .5 with a result for 224 the default large size is .95
624
+ size2 = int(
625
+ (min_size_width)
626
+ * torch.zeros(
627
+ 1,
628
+ )
629
+ .normal_(mean=0.9, std=0.3)
630
+ .clip(lower_bound, 0.95)
631
+ ) # replace .5 with a result for 224 the default large size is .95
632
+ offsetx = torch.randint(0, sideX - size1 + 1, ())
633
+ offsety = torch.randint(0, sideY - size2 + 1, ())
634
+ cutout = input[
635
+ :, :, offsety : offsety + size2 + ii, offsetx : offsetx + size1 + ii
636
+ ]
637
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
638
+
639
+ cutouts = torch.cat(cutouts, dim=0)
640
+ cutouts = clamp_with_grad(cutouts, 0, 1)
641
+ cutouts = self.augs(cutouts)
642
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
643
+ 0, self.noise_fac
644
+ )
645
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
646
+ return cutouts
647
+
648
+ class MakeCutoutsOldHolywater(nn.Module):
649
+ def __init__(self, cut_size, cutn, cut_pow, augs):
650
+ super().__init__()
651
+ self.cut_size = cut_size
652
+ # tqdm.write(f"cut size: {self.cut_size}")
653
+ self.cutn = cutn
654
+ self.cut_pow = cut_pow
655
+ self.noise_fac = 0.1
656
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
657
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
658
+ self.augs = nn.Sequential(
659
+ # K.RandomHorizontalFlip(p=0.5),
660
+ # K.RandomSharpness(0.3,p=0.4),
661
+ # K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
662
+ # K.RandomGaussianNoise(p=0.5),
663
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
664
+ K.RandomAffine(
665
+ degrees=180, translate=0.5, p=0.2, padding_mode="border"
666
+ ),
667
+ K.RandomPerspective(0.6, p=0.9),
668
+ K.ColorJitter(hue=0.03, saturation=0.01, p=0.1),
669
+ K.RandomErasing((0.1, 0.7), (0.3, 1 / 0.4), same_on_batch=True, p=0.2),
670
+ )
671
+
672
+ def set_cut_pow(self, cut_pow):
673
+ self.cut_pow = cut_pow
674
+
675
+ def forward(self, input):
676
+ sideY, sideX = input.shape[2:4]
677
+ max_size = min(sideX, sideY)
678
+ min_size = min(sideX, sideY, self.cut_size)
679
+ cutouts = []
680
+ cutouts_full = []
681
+ noise_fac = 0.1
682
+
683
+ min_size_width = min(sideX, sideY)
684
+ lower_bound = float(self.cut_size / min_size_width)
685
+
686
+ for ii in range(self.cutn):
687
+
688
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
689
+ randsize = (
690
+ torch.zeros(
691
+ 1,
692
+ )
693
+ .normal_(mean=0.8, std=0.3)
694
+ .clip(lower_bound, 1.0)
695
+ )
696
+ size_mult = randsize**self.cut_pow
697
+ size = int(
698
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
699
+ ) # replace .5 with a result for 224 the default large size is .95
700
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
701
+
702
+ offsetx = torch.randint(0, sideX - size + 1, ())
703
+ offsety = torch.randint(0, sideY - size + 1, ())
704
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
705
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
706
+
707
+ cutouts = torch.cat(cutouts, dim=0)
708
+ cutouts = clamp_with_grad(cutouts, 0, 1)
709
+
710
+ # if args.use_augs:
711
+ cutouts = self.augs(cutouts)
712
+ if self.noise_fac:
713
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
714
+ 0, self.noise_fac
715
+ )
716
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
717
+ return cutouts
718
+
719
+ class MakeCutoutsGinger(nn.Module):
720
+ def __init__(self, cut_size, cutn, cut_pow, augs):
721
+ super().__init__()
722
+ self.cut_size = cut_size
723
+ # tqdm.write(f"cut size: {self.cut_size}")
724
+ self.cutn = cutn
725
+ self.cut_pow = cut_pow
726
+ self.noise_fac = 0.1
727
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
728
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
729
+ self.augs = augs
730
+ """
731
+ nn.Sequential(
732
+ K.RandomHorizontalFlip(p=0.5),
733
+ K.RandomSharpness(0.3,p=0.4),
734
+ K.RandomGaussianBlur((3,3),(10.5,10.5),p=0.2),
735
+ K.RandomGaussianNoise(p=0.5),
736
+ K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
737
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'), # padding_mode=2
738
+ K.RandomPerspective(0.2,p=0.4, ),
739
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),)
740
+ """
741
+
742
+ def set_cut_pow(self, cut_pow):
743
+ self.cut_pow = cut_pow
744
+
745
+ def forward(self, input):
746
+ sideY, sideX = input.shape[2:4]
747
+ max_size = min(sideX, sideY)
748
+ min_size = min(sideX, sideY, self.cut_size)
749
+ cutouts = []
750
+ cutouts_full = []
751
+ noise_fac = 0.1
752
+
753
+ min_size_width = min(sideX, sideY)
754
+ lower_bound = float(self.cut_size / min_size_width)
755
+
756
+ for ii in range(self.cutn):
757
+
758
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
759
+ randsize = (
760
+ torch.zeros(
761
+ 1,
762
+ )
763
+ .normal_(mean=0.8, std=0.3)
764
+ .clip(lower_bound, 1.0)
765
+ )
766
+ size_mult = randsize**self.cut_pow
767
+ size = int(
768
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
769
+ ) # replace .5 with a result for 224 the default large size is .95
770
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
771
+
772
+ offsetx = torch.randint(0, sideX - size + 1, ())
773
+ offsety = torch.randint(0, sideY - size + 1, ())
774
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
775
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
776
+
777
+ cutouts = torch.cat(cutouts, dim=0)
778
+ cutouts = clamp_with_grad(cutouts, 0, 1)
779
+
780
+ # if args.use_augs:
781
+ cutouts = self.augs(cutouts)
782
+ if self.noise_fac:
783
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
784
+ 0, self.noise_fac
785
+ )
786
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
787
+ return cutouts
788
+
789
+ class MakeCutoutsZynth(nn.Module):
790
+ def __init__(self, cut_size, cutn, cut_pow, augs):
791
+ super().__init__()
792
+ self.cut_size = cut_size
793
+ # tqdm.write(f"cut size: {self.cut_size}")
794
+ self.cutn = cutn
795
+ self.cut_pow = cut_pow
796
+ self.noise_fac = 0.1
797
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
798
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
799
+ self.augs = nn.Sequential(
800
+ K.RandomHorizontalFlip(p=0.5),
801
+ # K.RandomSolarize(0.01, 0.01, p=0.7),
802
+ K.RandomSharpness(0.3, p=0.4),
803
+ K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode="border"),
804
+ K.RandomPerspective(0.2, p=0.4),
805
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
806
+ )
807
+
808
+ def set_cut_pow(self, cut_pow):
809
+ self.cut_pow = cut_pow
810
+
811
+ def forward(self, input):
812
+ sideY, sideX = input.shape[2:4]
813
+ max_size = min(sideX, sideY)
814
+ min_size = min(sideX, sideY, self.cut_size)
815
+ cutouts = []
816
+ cutouts_full = []
817
+ noise_fac = 0.1
818
+
819
+ min_size_width = min(sideX, sideY)
820
+ lower_bound = float(self.cut_size / min_size_width)
821
+
822
+ for ii in range(self.cutn):
823
+
824
+ # size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
825
+ randsize = (
826
+ torch.zeros(
827
+ 1,
828
+ )
829
+ .normal_(mean=0.8, std=0.3)
830
+ .clip(lower_bound, 1.0)
831
+ )
832
+ size_mult = randsize**self.cut_pow
833
+ size = int(
834
+ min_size_width * (size_mult.clip(lower_bound, 1.0))
835
+ ) # replace .5 with a result for 224 the default large size is .95
836
+ # size = int(min_size_width*torch.zeros(1,).normal_(mean=.9, std=.3).clip(lower_bound, .95)) # replace .5 with a result for 224 the default large size is .95
837
+
838
+ offsetx = torch.randint(0, sideX - size + 1, ())
839
+ offsety = torch.randint(0, sideY - size + 1, ())
840
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
841
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
842
+
843
+ cutouts = torch.cat(cutouts, dim=0)
844
+ cutouts = clamp_with_grad(cutouts, 0, 1)
845
+
846
+ # if args.use_augs:
847
+ cutouts = self.augs(cutouts)
848
+ if self.noise_fac:
849
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
850
+ 0, self.noise_fac
851
+ )
852
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
853
+ return cutouts
854
+
855
+ class MakeCutoutsWyvern(nn.Module):
856
+ def __init__(self, cut_size, cutn, cut_pow, augs):
857
+ super().__init__()
858
+ self.cut_size = cut_size
859
+ # tqdm.write(f"cut size: {self.cut_size}")
860
+ self.cutn = cutn
861
+ self.cut_pow = cut_pow
862
+ self.noise_fac = 0.1
863
+ self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
864
+ self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
865
+ self.augs = augs
866
+
867
+ def forward(self, input):
868
+ sideY, sideX = input.shape[2:4]
869
+ max_size = min(sideX, sideY)
870
+ min_size = min(sideX, sideY, self.cut_size)
871
+ cutouts = []
872
+ for _ in range(self.cutn):
873
+ size = int(
874
+ torch.rand([]) ** self.cut_pow * (max_size - min_size) + min_size
875
+ )
876
+ offsetx = torch.randint(0, sideX - size + 1, ())
877
+ offsety = torch.randint(0, sideY - size + 1, ())
878
+ cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
879
+ cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
880
+ return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)
881
+
882
+
883
+ import PIL
884
+
885
+ def resize_image(image, out_size):
886
+ ratio = image.size[0] / image.size[1]
887
+ area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
888
+ size = round((area * ratio) ** 0.5), round((area / ratio) ** 0.5)
889
+ return image.resize(size, PIL.Image.LANCZOS)
890
+
891
+ class GaussianBlur2d(nn.Module):
892
+ def __init__(self, sigma, window=0, mode="reflect", value=0):
893
+ super().__init__()
894
+ self.mode = mode
895
+ self.value = value
896
+ if not window:
897
+ window = max(math.ceil((sigma * 6 + 1) / 2) * 2 - 1, 3)
898
+ if sigma:
899
+ kernel = torch.exp(
900
+ -((torch.arange(window) - window // 2) ** 2) / 2 / sigma**2
901
+ )
902
+ kernel /= kernel.sum()
903
+ else:
904
+ kernel = torch.ones([1])
905
+ self.register_buffer("kernel", kernel)
906
+
907
+ def forward(self, input):
908
+ n, c, h, w = input.shape
909
+ input = input.view([n * c, 1, h, w])
910
+ start_pad = (self.kernel.shape[0] - 1) // 2
911
+ end_pad = self.kernel.shape[0] // 2
912
+ input = F.pad(
913
+ input, (start_pad, end_pad, start_pad, end_pad), self.mode, self.value
914
+ )
915
+ input = F.conv2d(input, self.kernel[None, None, None, :])
916
+ input = F.conv2d(input, self.kernel[None, None, :, None])
917
+ return input.view([n, c, h, w])
918
+
919
+ BUF_SIZE = 65536
920
+
921
+ def get_digest(path, alg=hashlib.sha256):
922
+ hash = alg()
923
+ # print(path)
924
+ with open(path, "rb") as fp:
925
+ while True:
926
+ data = fp.read(BUF_SIZE)
927
+ if not data:
928
+ break
929
+ hash.update(data)
930
+ return b64encode(hash.digest()).decode("utf-8")
931
+
932
+ flavordict = {
933
+ "cumin": MakeCutoutsCumin,
934
+ "holywater": MakeCutoutsHolywater,
935
+ "old_holywater": MakeCutoutsOldHolywater,
936
+ "ginger": MakeCutoutsGinger,
937
+ "zynth": MakeCutoutsZynth,
938
+ "wyvern": MakeCutoutsWyvern,
939
+ "aaron": MakeCutoutsAaron,
940
+ "moth": MakeCutoutsMoth,
941
+ "juu": MakeCutoutsJuu,
942
+ "custom": MakeCutoutsCustom,
943
+ }
944
+
945
+ @torch.jit.script
946
+ def gelu_impl(x):
947
+ """OpenAI's gelu implementation."""
948
+ return (
949
+ 0.5
950
+ * x
951
+ * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
952
+ )
953
+
954
+ def gelu(x):
955
+ return gelu_impl(x)
956
+
957
+ class MSEDecayLoss(nn.Module):
958
+ def __init__(self, init_weight, mse_decay_rate, mse_epoches, mse_quantize):
959
+ super().__init__()
960
+
961
+ self.init_weight = init_weight
962
+ self.has_init_image = False
963
+ self.mse_decay = init_weight / mse_epoches if init_weight else 0
964
+ self.mse_decay_rate = mse_decay_rate
965
+ self.mse_weight = init_weight
966
+ self.mse_epoches = mse_epoches
967
+ self.mse_quantize = mse_quantize
968
+
969
+ @torch.no_grad()
970
+ def set_target(self, z_tensor, model):
971
+ z_tensor = z_tensor.detach().clone()
972
+ if self.mse_quantize:
973
+ z_tensor = vector_quantize(
974
+ z_tensor.movedim(1, 3), model.quantize.embedding.weight
975
+ ).movedim(
976
+ 3, 1
977
+ ) # z.average
978
+ self.z_orig = z_tensor
979
+
980
+ def forward(self, i, z):
981
+ if self.is_active(i):
982
+ return F.mse_loss(z, self.z_orig) * self.mse_weight / 2
983
+ return 0
984
+
985
+ def is_active(self, i):
986
+ if not self.init_weight:
987
+ return False
988
+ if i <= self.mse_decay_rate and not self.has_init_image:
989
+ return False
990
+ return True
991
+
992
+ @torch.no_grad()
993
+ def step(self, i):
994
+
995
+ if (
996
+ i % self.mse_decay_rate == 0
997
+ and i != 0
998
+ and i < self.mse_decay_rate * self.mse_epoches
999
+ ):
1000
+
1001
+ if (
1002
+ self.mse_weight - self.mse_decay > 0
1003
+ and self.mse_weight - self.mse_decay >= self.mse_decay
1004
+ ):
1005
+ self.mse_weight -= self.mse_decay
1006
+ else:
1007
+ self.mse_weight = 0
1008
+ # print(f"updated mse weight: {self.mse_weight}")
1009
+
1010
+ return True
1011
+
1012
+ return False
1013
+
1014
+ class TVLoss(nn.Module):
1015
+ def forward(self, input):
1016
+ input = F.pad(input, (0, 1, 0, 1), "replicate")
1017
+ x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
1018
+ y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
1019
+ diff = x_diff**2 + y_diff**2 + 1e-8
1020
+ return diff.mean(dim=1).sqrt().mean()
1021
+
1022
+ class MultiClipLoss(nn.Module):
1023
+ def __init__(
1024
+ self, clip_models, text_prompt, cutn, cut_pow=1.0, clip_weight=1.0
1025
+ ):
1026
+ super().__init__()
1027
+
1028
+ # Load Clip
1029
+ self.perceptors = []
1030
+ for cm in clip_models:
1031
+ sys.stdout.write(f"Loading {cm[0]} ...\n")
1032
+ sys.stdout.flush()
1033
+ c = (
1034
+ clip.load(cm[0], jit=False)[0]
1035
+ .eval()
1036
+ .requires_grad_(False)
1037
+ .to(device)
1038
+ )
1039
+ self.perceptors.append(
1040
+ {
1041
+ "res": c.visual.input_resolution,
1042
+ "perceptor": c,
1043
+ "weight": cm[1],
1044
+ "prompts": [],
1045
+ }
1046
+ )
1047
+ self.perceptors.sort(key=lambda e: e["res"], reverse=True)
1048
+
1049
+ # Make Cutouts
1050
+ self.max_cut_size = self.perceptors[0]["res"]
1051
+ # self.make_cuts = flavordict[flavor](self.max_cut_size, cutn, cut_pow)
1052
+ # cutouts = flavordict[flavor](self.max_cut_size, cutn, cut_pow=cut_pow, augs=args.augs)
1053
+
1054
+ # Get Prompt Embedings
1055
+ # texts = [phrase.strip() for phrase in text_prompt.split("|")]
1056
+ # if text_prompt == ['']:
1057
+ # texts = []
1058
+ texts = text_prompt
1059
+ self.pMs = []
1060
+ for prompt in texts:
1061
+ txt, weight, stop = parse_prompt(prompt)
1062
+ clip_token = clip.tokenize(txt).to(device)
1063
+ for p in self.perceptors:
1064
+ embed = p["perceptor"].encode_text(clip_token).float()
1065
+ embed_normed = F.normalize(embed.unsqueeze(0), dim=2)
1066
+ p["prompts"].append(
1067
+ {
1068
+ "embed_normed": embed_normed,
1069
+ "weight": torch.as_tensor(weight, device=device),
1070
+ "stop": torch.as_tensor(stop, device=device),
1071
+ }
1072
+ )
1073
+
1074
+ # Prep Augments
1075
+ self.normalize = transforms.Normalize(
1076
+ mean=[0.48145466, 0.4578275, 0.40821073],
1077
+ std=[0.26862954, 0.26130258, 0.27577711],
1078
+ )
1079
+
1080
+ self.augs = nn.Sequential(
1081
+ K.RandomHorizontalFlip(p=0.5),
1082
+ K.RandomSharpness(0.3, p=0.1),
1083
+ K.RandomAffine(
1084
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
1085
+ ), # padding_mode=2
1086
+ K.RandomPerspective(
1087
+ 0.2,
1088
+ p=0.4,
1089
+ ),
1090
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
1091
+ K.RandomGrayscale(p=0.15),
1092
+ )
1093
+ self.noise_fac = 0.1
1094
+
1095
+ self.clip_weight = clip_weight
1096
+
1097
+ def prepare_cuts(self, img):
1098
+ cutouts = self.make_cuts(img)
1099
+ cutouts = self.augs(cutouts)
1100
+ if self.noise_fac:
1101
+ facs = cutouts.new_empty([cutouts.shape[0], 1, 1, 1]).uniform_(
1102
+ 0, self.noise_fac
1103
+ )
1104
+ cutouts = cutouts + facs * torch.randn_like(cutouts)
1105
+ cutouts = self.normalize(cutouts)
1106
+ return cutouts
1107
+
1108
+ def forward(self, i, img):
1109
+ cutouts = checkpoint(self.prepare_cuts, img)
1110
+ loss = []
1111
+
1112
+ current_cuts = cutouts
1113
+ currentres = self.max_cut_size
1114
+ for p in self.perceptors:
1115
+ if currentres != p["res"]:
1116
+ current_cuts = resample(cutouts, (p["res"], p["res"]))
1117
+ currentres = p["res"]
1118
+
1119
+ iii = p["perceptor"].encode_image(current_cuts).float()
1120
+ input_normed = F.normalize(iii.unsqueeze(1), dim=2)
1121
+ for prompt in p["prompts"]:
1122
+ dists = (
1123
+ input_normed.sub(prompt["embed_normed"])
1124
+ .norm(dim=2)
1125
+ .div(2)
1126
+ .arcsin()
1127
+ .pow(2)
1128
+ .mul(2)
1129
+ )
1130
+ dists = dists * prompt["weight"].sign()
1131
+ l = (
1132
+ prompt["weight"].abs()
1133
+ * replace_grad(
1134
+ dists, torch.maximum(dists, prompt["stop"])
1135
+ ).mean()
1136
+ )
1137
+ loss.append(l * p["weight"])
1138
+
1139
+ return loss
1140
+
1141
+ class ModelHost:
1142
+ def __init__(self, args):
1143
+ self.args = args
1144
+ self.model, self.perceptor = None, None
1145
+ self.make_cutouts = None
1146
+ self.alt_make_cutouts = None
1147
+ self.imageSize = None
1148
+ self.prompts = None
1149
+ self.opt = None
1150
+ self.normalize = None
1151
+ self.z, self.z_orig, self.z_min, self.z_max = None, None, None, None
1152
+ self.metadata = None
1153
+ self.mse_weight = 0
1154
+ self.normal_flip_optim = None
1155
+ self.usealtprompts = False
1156
+
1157
+ def setup_metadata(self, seed):
1158
+ metadata = {k: v for k, v in vars(self.args).items()}
1159
+ del metadata["max_iterations"]
1160
+ del metadata["display_freq"]
1161
+ metadata["seed"] = seed
1162
+ if metadata["init_image"]:
1163
+ path = metadata["init_image"]
1164
+ digest = get_digest(path)
1165
+ metadata["init_image"] = (path, digest)
1166
+ if metadata["image_prompts"]:
1167
+ prompts = []
1168
+ for prompt in metadata["image_prompts"]:
1169
+ path = prompt
1170
+ digest = get_digest(path)
1171
+ prompts.append((path, digest))
1172
+ metadata["image_prompts"] = prompts
1173
+ self.metadata = metadata
1174
+
1175
+ def setup_model(self, x):
1176
+ i = x
1177
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1178
+
1179
+ #perceptor = (
1180
+ # clip.load(args.clip_model, jit=False)[0]
1181
+ # .eval()
1182
+ # .requires_grad_(False)
1183
+ # .to(device)
1184
+ #)
1185
+
1186
+ cut_size = perceptor.visual.input_resolution
1187
+
1188
+ if self.args.is_gumbel:
1189
+ e_dim = model.quantize.embedding_dim
1190
+ else:
1191
+ e_dim = model.quantize.e_dim
1192
 
1193
+ f = 2 ** (model.decoder.num_resolutions - 1)
1194
+
1195
+ make_cutouts = flavordict[flavor](
1196
+ cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow, augs=args.augs
1197
+ )
1198
+
1199
+ # make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.mse_cut_pow,augs=args.augs)
1200
+ if args.altprompts:
1201
+ self.usealtprompts = True
1202
+ self.alt_make_cutouts = flavordict[flavor](
1203
+ cut_size,
1204
+ args.mse_cutn,
1205
+ cut_pow=args.alt_mse_cut_pow,
1206
+ augs=args.altaugs,
1207
+ )
1208
+ # self.alt_make_cutouts = MakeCutouts(cut_size, args.mse_cutn, cut_pow=args.alt_mse_cut_pow,augs=args.altaugs)
1209
+
1210
+ if self.args.is_gumbel:
1211
+ n_toks = model.quantize.n_embed
1212
+ else:
1213
+ n_toks = model.quantize.n_e
1214
+
1215
+ toksX, toksY = args.size[0] // f, args.size[1] // f
1216
+ sideX, sideY = toksX * f, toksY * f
1217
+
1218
+ if self.args.is_gumbel:
1219
+ z_min = model.quantize.embed.weight.min(dim=0).values[
1220
+ None, :, None, None
1221
+ ]
1222
+ z_max = model.quantize.embed.weight.max(dim=0).values[
1223
+ None, :, None, None
1224
+ ]
1225
+ else:
1226
+ z_min = model.quantize.embedding.weight.min(dim=0).values[
1227
+ None, :, None, None
1228
+ ]
1229
+ z_max = model.quantize.embedding.weight.max(dim=0).values[
1230
+ None, :, None, None
1231
+ ]
1232
+
1233
+ from PIL import Image
1234
+ import cv2
1235
+
1236
+ # -------
1237
+ working_dir = self.args.folder_name
1238
+
1239
+ if self.args.init_image != "":
1240
+ img_0 = cv2.imread(init_image)
1241
+ z, *_ = model.encode(
1242
+ TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
1243
+ )
1244
+ elif not os.path.isfile(f"{working_dir}/steps/{i:04d}.png"):
1245
+ one_hot = F.one_hot(
1246
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1247
+ ).float()
1248
+ if self.args.is_gumbel:
1249
+ z = one_hot @ model.quantize.embed.weight
1250
+ else:
1251
+ z = one_hot @ model.quantize.embedding.weight
1252
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1253
+ else:
1254
+ center = (1 * img_0.shape[1] // 2, 1 * img_0.shape[0] // 2)
1255
+ trans_mat = np.float32([[1, 0, 10], [0, 1, 10]])
1256
+ rot_mat = cv2.getRotationMatrix2D(center, 10, 20)
1257
+
1258
+ trans_mat = np.vstack([trans_mat, [0, 0, 1]])
1259
+ rot_mat = np.vstack([rot_mat, [0, 0, 1]])
1260
+ transformation_matrix = np.matmul(rot_mat, trans_mat)
1261
+
1262
+ img_0 = cv2.warpPerspective(
1263
+ img_0,
1264
+ transformation_matrix,
1265
+ (img_0.shape[1], img_0.shape[0]),
1266
+ borderMode=cv2.BORDER_WRAP,
1267
+ )
1268
+ z, *_ = model.encode(
1269
+ TF.to_tensor(img_0).to(device).unsqueeze(0) * 2 - 1
1270
+ )
1271
+
1272
+ def save_output(i, img, suffix="zoomed"):
1273
+ filename = f"{working_dir}/steps/{i:04}{'_' + suffix if suffix else ''}.png"
1274
+ imageio.imwrite(filename, np.array(img))
1275
+
1276
+ save_output(i, img_0)
1277
+ # -------
1278
+ if args.init_image:
1279
+ pil_image = Image.open(args.init_image).convert("RGB")
1280
+ pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
1281
+ z, *_ = model.encode(
1282
+ TF.to_tensor(pil_image).to(device).unsqueeze(0) * 2 - 1
1283
+ )
1284
+ else:
1285
+ one_hot = F.one_hot(
1286
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1287
+ ).float()
1288
+ if self.args.is_gumbel:
1289
+ z = one_hot @ model.quantize.embed.weight
1290
+ else:
1291
+ z = one_hot @ model.quantize.embedding.weight
1292
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1293
+ z = EMATensor(z, args.ema_val)
1294
+
1295
+ if args.mse_with_zeros and not args.init_image:
1296
+ z_orig = torch.zeros_like(z.tensor)
1297
+ else:
1298
+ z_orig = z.tensor.clone()
1299
+ z.requires_grad_(True)
1300
+ # opt = optim.AdamW(z.parameters(), lr=args.mse_step_size, weight_decay=0.00000000)
1301
+ if self.normal_flip_optim == True:
1302
+ if randint(1, 2) == 1:
1303
+ opt = torch.optim.AdamW(
1304
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1305
+ )
1306
+ # opt = Ranger21(z.parameters(), lr=args.step_size, weight_decay=0.00000000)
1307
+ else:
1308
+ opt = optim.DiffGrad(
1309
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1310
+ )
1311
+ else:
1312
+ opt = torch.optim.AdamW(
1313
+ z.parameters(), lr=args.step_size, weight_decay=0.00000000
1314
+ )
1315
+
1316
+ self.cur_step_size = args.mse_step_size
1317
+
1318
+ normalize = transforms.Normalize(
1319
+ mean=[0.48145466, 0.4578275, 0.40821073],
1320
+ std=[0.26862954, 0.26130258, 0.27577711],
1321
+ )
1322
+
1323
+ pMs = []
1324
+ altpMs = []
1325
+
1326
+ for prompt in args.prompts:
1327
+ txt, weight, stop = parse_prompt(prompt)
1328
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
1329
+ pMs.append(Prompt(embed, weight, stop).to(device))
1330
+
1331
+ for prompt in args.altprompts:
1332
+ txt, weight, stop = parse_prompt(prompt)
1333
+ embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
1334
+ altpMs.append(Prompt(embed, weight, stop).to(device))
1335
+
1336
+ from PIL import Image
1337
+
1338
+ for prompt in args.image_prompts:
1339
+ path, weight, stop = parse_prompt(prompt)
1340
+ img = resize_image(Image.open(path).convert("RGB"), (sideX, sideY))
1341
+ batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
1342
+ embed = perceptor.encode_image(normalize(batch)).float()
1343
+ pMs.append(Prompt(embed, weight, stop).to(device))
1344
+
1345
+ for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
1346
+ gen = torch.Generator().manual_seed(seed)
1347
+ embed = torch.empty([1, perceptor.visual.output_dim]).normal_(
1348
+ generator=gen
1349
+ )
1350
+ pMs.append(Prompt(embed, weight).to(device))
1351
+ if self.usealtprompts:
1352
+ altpMs.append(Prompt(embed, weight).to(device))
1353
+
1354
+ self.model, self.perceptor = model, perceptor
1355
+ self.make_cutouts = make_cutouts
1356
+ self.imageSize = (sideX, sideY)
1357
+ self.prompts = pMs
1358
+ self.altprompts = altpMs
1359
+ self.opt = opt
1360
+ self.normalize = normalize
1361
+ self.z, self.z_orig, self.z_min, self.z_max = z, z_orig, z_min, z_max
1362
+ self.setup_metadata(args2.seed)
1363
+ self.mse_weight = self.args.init_weight
1364
+
1365
+ def synth(self, z):
1366
+ if self.args.is_gumbel:
1367
+ z_q = vector_quantize(
1368
+ z.movedim(1, 3), self.model.quantize.embed.weight
1369
+ ).movedim(3, 1)
1370
+ else:
1371
+ z_q = vector_quantize(
1372
+ z.movedim(1, 3), self.model.quantize.embedding.weight
1373
+ ).movedim(3, 1)
1374
+ return clamp_with_grad(self.model.decode(z_q).add(1).div(2), 0, 1)
1375
+
1376
+ def add_metadata(self, path, i):
1377
+ imfile = PngImageFile(path)
1378
+ meta = PngInfo()
1379
+ step_meta = {"iterations": i}
1380
+ step_meta.update(self.metadata)
1381
+ # meta.add_itxt('vqgan-params', json.dumps(step_meta), zip=True)
1382
+ imfile.save(path, pnginfo=meta)
1383
+ # Hey you. This one's for Glooperpogger#7353 on Discord (Gloop has a gun), they are a nice snek
1384
+
1385
+ @torch.no_grad()
1386
+ def checkin(self, i, losses, x):
1387
+ out = self.synth(self.z.average)
1388
+
1389
+ batchpath = "./"
1390
+ TF.to_pil_image(out[0].cpu()).save(args2.image_file)
1391
+
1392
+ def unique_index(self, batchpath):
1393
+ i = 0
1394
+ while i < 10000:
1395
+ if os.path.isfile(batchpath + "/" + str(i) + ".png"):
1396
+ i = i + 1
1397
+ else:
1398
+ return batchpath + "/" + str(i) + ".png"
1399
+
1400
+ def ascend_txt(self, i):
1401
+ out = self.synth(self.z.tensor)
1402
+ iii = self.perceptor.encode_image(
1403
+ self.normalize(self.make_cutouts(out))
1404
+ ).float()
1405
+
1406
+ result = []
1407
+ if self.args.init_weight and self.mse_weight > 0:
1408
+ result.append(
1409
+ F.mse_loss(self.z.tensor, self.z_orig) * self.mse_weight / 2
1410
+ )
1411
+
1412
+ for prompt in self.prompts:
1413
+ result.append(prompt(iii))
1414
+
1415
+ if self.usealtprompts:
1416
+ iii = self.perceptor.encode_image(
1417
+ self.normalize(self.alt_make_cutouts(out))
1418
+ ).float()
1419
+ for prompt in self.altprompts:
1420
+ result.append(prompt(iii))
1421
+
1422
+ return result
1423
+
1424
+ def train(self, i, x):
1425
+ self.opt.zero_grad()
1426
+ mse_decay = self.args.mse_decay
1427
+ mse_decay_rate = self.args.mse_decay_rate
1428
+ lossAll = self.ascend_txt(i)
1429
+
1430
+ sys.stdout.write("Iteration {}".format(i) + "\n")
1431
+ sys.stdout.flush()
1432
+
1433
+ if i % args2.update == 0:
1434
+ self.checkin(i, lossAll, x)
1435
+
1436
+ loss = sum(lossAll)
1437
+ loss.backward()
1438
+ self.opt.step()
1439
+ with torch.no_grad():
1440
+ if (
1441
+ self.mse_weight > 0
1442
+ and self.args.init_weight
1443
+ and i > 0
1444
+ and i % mse_decay_rate == 0
1445
+ ):
1446
+ if self.args.is_gumbel:
1447
+ self.z_orig = vector_quantize(
1448
+ self.z.average.movedim(1, 3),
1449
+ self.model.quantize.embed.weight,
1450
+ ).movedim(3, 1)
1451
+ else:
1452
+ self.z_orig = vector_quantize(
1453
+ self.z.average.movedim(1, 3),
1454
+ self.model.quantize.embedding.weight,
1455
+ ).movedim(3, 1)
1456
+ if self.mse_weight - mse_decay > 0:
1457
+ self.mse_weight = self.mse_weight - mse_decay
1458
+ # print(f"updated mse weight: {self.mse_weight}")
1459
+ else:
1460
+ self.mse_weight = 0
1461
+ self.make_cutouts = flavordict[flavor](
1462
+ self.perceptor.visual.input_resolution,
1463
+ args.cutn,
1464
+ cut_pow=args.cut_pow,
1465
+ augs=args.augs,
1466
+ )
1467
+ if self.usealtprompts:
1468
+ self.alt_make_cutouts = flavordict[flavor](
1469
+ self.perceptor.visual.input_resolution,
1470
+ args.cutn,
1471
+ cut_pow=args.alt_cut_pow,
1472
+ augs=args.altaugs,
1473
+ )
1474
+ self.z = EMATensor(self.z.average, args.ema_val)
1475
+ self.new_step_size = args.step_size
1476
+ self.opt = torch.optim.AdamW(
1477
+ self.z.parameters(),
1478
+ lr=args.step_size,
1479
+ weight_decay=0.00000000,
1480
+ )
1481
+ # print(f"updated mse weight: {self.mse_weight}")
1482
+ if i > args.mse_end:
1483
+ if (
1484
+ args.step_size != args.final_step_size
1485
+ and args.max_iterations > 0
1486
+ ):
1487
+ progress = (i - args.mse_end) / (args.max_iterations)
1488
+ self.cur_step_size = lerp(step_size, final_step_size, progress)
1489
+ for g in self.opt.param_groups:
1490
+ g["lr"] = self.cur_step_size
1491
+
1492
+ def run(self, x):
1493
+ j = 0
1494
+ try:
1495
+ before_start_time = time.perf_counter()
1496
+ total_steps = int(args.max_iterations + args.mse_end) - 1
1497
+ for _ in range(total_steps):
1498
+ self.train(j, x)
1499
+ if j > 0 and j % args.mse_decay_rate == 0 and self.mse_weight > 0:
1500
+ self.z = EMATensor(self.z.average, args.ema_val)
1501
+ self.opt = torch.optim.AdamW(
1502
+ self.z.parameters(),
1503
+ lr=args.mse_step_size,
1504
+ weight_decay=0.00000000,
1505
+ )
1506
+ if j >= total_steps:
1507
+ break
1508
+ self.z.update()
1509
+ j += 1
1510
+ time_past_seconds = time.perf_counter() - before_start_time
1511
+ iterations_per_second = j / time_past_seconds
1512
+ time_left = (total_steps - j) / iterations_per_second
1513
+ percentage = round((j / (total_steps + 1)) * 100)
1514
+
1515
+ import shutil
1516
+ import os
1517
+
1518
+ image_data = Image.open(args2.image_file)
1519
+ return(image_data)
1520
+
1521
+ except KeyboardInterrupt:
1522
+ pass
1523
+ except st.script_runner.StopException as e:
1524
+ torch.cuda.empty_cache()
1525
+ pass
1526
+ return j
1527
+
1528
+ def add_noise(img):
1529
+
1530
+ # Getting the dimensions of the image
1531
+ row, col = img.shape
1532
+
1533
+ # Randomly pick some pixels in the
1534
+ # image for coloring them white
1535
+ # Pick a random number between 300 and 10000
1536
+ number_of_pixels = random.randint(300, 10000)
1537
+ for i in range(number_of_pixels):
1538
+
1539
+ # Pick a random y coordinate
1540
+ y_coord = random.randint(0, row - 1)
1541
+
1542
+ # Pick a random x coordinate
1543
+ x_coord = random.randint(0, col - 1)
1544
+
1545
+ # Color that pixel to white
1546
+ img[y_coord][x_coord] = 255
1547
+
1548
+ # Randomly pick some pixels in
1549
+ # the image for coloring them black
1550
+ # Pick a random number between 300 and 10000
1551
+ number_of_pixels = random.randint(300, 10000)
1552
+ for i in range(number_of_pixels):
1553
+
1554
+ # Pick a random y coordinate
1555
+ y_coord = random.randint(0, row - 1)
1556
+
1557
+ # Pick a random x coordinate
1558
+ x_coord = random.randint(0, col - 1)
1559
+
1560
+ # Color that pixel to black
1561
+ img[y_coord][x_coord] = 0
1562
+
1563
+ return img
1564
+
1565
+ import io
1566
+ import base64
1567
+
1568
+ def image_to_data_url(img, ext):
1569
+ img_byte_arr = io.BytesIO()
1570
+ img.save(img_byte_arr, format=ext)
1571
+ img_byte_arr = img_byte_arr.getvalue()
1572
+ # ext = filename.split('.')[-1]
1573
+ prefix = f"data:image/{ext};base64,"
1574
+ return prefix + base64.b64encode(img_byte_arr).decode("utf-8")
1575
+
1576
+ import torch
1577
+ import math
1578
+
1579
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1580
+
1581
+ def rand_perlin_2d(
1582
+ shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3
1583
+ ):
1584
+ delta = (res[0] / shape[0], res[1] / shape[1])
1585
+ d = (shape[0] // res[0], shape[1] // res[1])
1586
+
1587
+ grid = (
1588
+ torch.stack(
1589
+ torch.meshgrid(
1590
+ torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
1591
+ ),
1592
+ dim=-1,
1593
+ )
1594
+ % 1
1595
+ )
1596
+ angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
1597
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
1598
+
1599
+ tile_grads = (
1600
+ lambda slice1, slice2: gradients[
1601
+ slice1[0] : slice1[1], slice2[0] : slice2[1]
1602
+ ]
1603
+ .repeat_interleave(d[0], 0)
1604
+ .repeat_interleave(d[1], 1)
1605
+ )
1606
+ dot = lambda grad, shift: (
1607
+ torch.stack(
1608
+ (
1609
+ grid[: shape[0], : shape[1], 0] + shift[0],
1610
+ grid[: shape[0], : shape[1], 1] + shift[1],
1611
+ ),
1612
+ dim=-1,
1613
+ )
1614
+ * grad[: shape[0], : shape[1]]
1615
+ ).sum(dim=-1)
1616
+
1617
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
1618
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
1619
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
1620
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
1621
+ t = fade(grid[: shape[0], : shape[1]])
1622
+ return math.sqrt(2) * torch.lerp(
1623
+ torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
1624
+ )
1625
+
1626
+ def rand_perlin_2d_octaves(desired_shape, octaves=1, persistence=0.5):
1627
+ shape = torch.tensor(desired_shape)
1628
+ shape = 2 ** torch.ceil(torch.log2(shape))
1629
+ shape = shape.type(torch.int)
1630
+
1631
+ max_octaves = int(
1632
+ min(
1633
+ octaves,
1634
+ math.log(shape[0]) / math.log(2),
1635
+ math.log(shape[1]) / math.log(2),
1636
+ )
1637
+ )
1638
+ res = torch.floor(shape / 2**max_octaves).type(torch.int)
1639
+
1640
+ noise = torch.zeros(list(shape))
1641
+ frequency = 1
1642
+ amplitude = 1
1643
+ for _ in range(max_octaves):
1644
+ noise += amplitude * rand_perlin_2d(
1645
+ shape, (frequency * res[0], frequency * res[1])
1646
+ )
1647
+ frequency *= 2
1648
+ amplitude *= persistence
1649
+
1650
+ return noise[: desired_shape[0], : desired_shape[1]]
1651
+
1652
+ def rand_perlin_rgb(desired_shape, amp=0.1, octaves=6):
1653
+ r = rand_perlin_2d_octaves(desired_shape, octaves)
1654
+ g = rand_perlin_2d_octaves(desired_shape, octaves)
1655
+ b = rand_perlin_2d_octaves(desired_shape, octaves)
1656
+ rgb = (torch.stack((r, g, b)) * amp + 1) * 0.5
1657
+ return rgb.unsqueeze(0).clip(0, 1).to(device)
1658
+
1659
+ def pyramid_noise_gen(shape, octaves=5, decay=1.0):
1660
+ n, c, h, w = shape
1661
+ noise = torch.zeros([n, c, 1, 1])
1662
+ max_octaves = int(min(math.log(h) / math.log(2), math.log(w) / math.log(2)))
1663
+ if octaves is not None and 0 < octaves:
1664
+ max_octaves = min(octaves, max_octaves)
1665
+ for i in reversed(range(max_octaves)):
1666
+ h_cur, w_cur = h // 2**i, w // 2**i
1667
+ noise = F.interpolate(
1668
+ noise, (h_cur, w_cur), mode="bicubic", align_corners=False
1669
+ )
1670
+ noise += (torch.randn([n, c, h_cur, w_cur]) / max_octaves) * decay ** (
1671
+ max_octaves - (i + 1)
1672
+ )
1673
+ return noise
1674
+
1675
+ def rand_z(model, toksX, toksY):
1676
+ e_dim = model.quantize.e_dim
1677
+ n_toks = model.quantize.n_e
1678
+ z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
1679
+ z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
1680
+
1681
+ one_hot = F.one_hot(
1682
+ torch.randint(n_toks, [toksY * toksX], device=device), n_toks
1683
+ ).float()
1684
+ z = one_hot @ model.quantize.embedding.weight
1685
+ z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
1686
+
1687
+ return z
1688
+
1689
+ def make_rand_init(
1690
+ mode,
1691
+ model,
1692
+ perlin_octaves,
1693
+ perlin_weight,
1694
+ pyramid_octaves,
1695
+ pyramid_decay,
1696
+ toksX,
1697
+ toksY,
1698
+ f,
1699
+ ):
1700
+
1701
+ if mode == "VQGAN ZRand":
1702
+ return rand_z(model, toksX, toksY)
1703
+ elif mode == "Perlin Noise":
1704
+ rand_init = rand_perlin_rgb(
1705
+ (toksY * f, toksX * f), perlin_weight, perlin_octaves
1706
+ )
1707
+ z, *_ = model.encode(rand_init * 2 - 1)
1708
+ return z
1709
+ elif mode == "Pyramid Noise":
1710
+ rand_init = pyramid_noise_gen(
1711
+ (1, 3, toksY * f, toksX * f), pyramid_octaves, pyramid_decay
1712
+ ).to(device)
1713
+ rand_init = (rand_init * 0.5 + 0.5).clip(0, 1)
1714
+ z, *_ = model.encode(rand_init * 2 - 1)
1715
+ return z
1716
+
1717
+ ##################### JUICY MESS ###################################
1718
+ import os
1719
+
1720
+ imagenet_1024 = False # @param {type:"boolean"}
1721
+ imagenet_16384 = True # @param {type:"boolean"}
1722
+ gumbel_8192 = False # @param {type:"boolean"}
1723
+ sber_gumbel = False # @param {type:"boolean"}
1724
+ # imagenet_cin = False #@param {type:"boolean"}
1725
+ coco = False # @param {type:"boolean"}
1726
+ coco_1stage = False # @param {type:"boolean"}
1727
+ faceshq = False # @param {type:"boolean"}
1728
+ wikiart_1024 = False # @param {type:"boolean"}
1729
+ wikiart_16384 = False # @param {type:"boolean"}
1730
+ wikiart_7mil = False # @param {type:"boolean"}
1731
+ sflckr = False # @param {type:"boolean"}
1732
+
1733
+ ##@markdown Experimental models (won't probably work, if you know how to make them work, go ahead :D):
1734
+ # celebahq = False #@param {type:"boolean"}
1735
+ # ade20k = False #@param {type:"boolean"}
1736
+ # drin = False #@param {type:"boolean"}
1737
+ # gumbel = False #@param {type:"boolean"}
1738
+ # gumbel_8192 = False #@param {type:"boolean"}
1739
+
1740
+ # Configure and run the model"""
1741
+
1742
+ # Commented out IPython magic to ensure Python compatibility.
1743
+ # @title <font color="lightgreen" size="+3">←</font> <font size="+2">🏃‍♂️</font> **Configure & Run** <font size="+2">🏃‍♂️</font>
1744
+
1745
+ import os
1746
+ import random
1747
+ import cv2
1748
+
1749
+ # from google.colab import drive
1750
+ from PIL import Image
1751
+ from importlib import reload
1752
+
1753
+ reload(PIL.TiffTags)
1754
+ # %cd /content/
1755
+ # @markdown >`prompts` is the list of prompts to give to the AI, separated by `|`. With more than one, it will attempt to mix them together. You can add weights to different parts of the prompt by adding a `p:x` at the end of a prompt (before a `|`) where `p` is the prompt and `x` is the weight.
1756
+
1757
+ # prompts = "A fantasy landscape, by Greg Rutkowski. A lush mountain.:1 | Trending on ArtStation, unreal engine. 4K HD, realism.:0.63" #@param {type:"string"}
1758
+
1759
+ prompts = args2.prompt
1760
+
1761
+ width = args2.sizex # @param {type:"number"}
1762
+ height = args2.sizey # @param {type:"number"}
1763
+
1764
+ # model = "ImageNet 16384" #@param ['ImageNet 16384', 'ImageNet 1024', "Gumbel 8192", "Sber Gumbel", 'WikiArt 1024', 'WikiArt 16384', 'WikiArt 7mil', 'COCO-Stuff', 'COCO 1 Stage', 'FacesHQ', 'S-FLCKR']
1765
+ #model = args2.vqgan_model
1766
+
1767
+ #if model == "Gumbel 8192" or model == "Sber Gumbel":
1768
+ # is_gumbel = True
1769
+ #else:
1770
+ # is_gumbel = False
1771
+ is_gumbel = False
1772
+ ##@markdown The flavor effects the output greatly. Each has it's own characteristics and depending on what you choose, you'll get a widely different result with the same prompt and seed. Ginger is the default, nothing special. Cumin results more of a painting, while Holywater makes everythng super funky and/or colorful. Custom is a custom flavor, use the utilities above.
1773
+ # Type "old_holywater" to use the old holywater flavor from Hypertron V1
1774
+ flavor = (
1775
+ args2.flavor
1776
+ ) #'ginger' #@param ["ginger", "cumin", "holywater", "zynth", "wyvern", "aaron", "moth", "juu", "custom"]
1777
+ template = (
1778
+ args2.template
1779
+ ) # @param ["none", "----------Parameter Tweaking----------", "Balanced", "Detailed", "Consistent Creativity", "Realistic", "Smooth", "Subtle MSE", "Hyper Fast Results", "----------Complete Overhaul----------", "flag", "planet", "creature", "human", "----------Sizes----------", "Size: Square", "Size: Landscape", "Size: Poster", "----------Prompt Modifiers----------", "Better - Fast", "Better - Slow", "Movie Poster", "Negative Prompt", "Better Quality"]
1780
+ ##@markdown To use initial or target images, upload it on the left in the file browser. You can also use previous outputs by putting its path below, e.g. `batch_01/0.png`. If your previous output is saved to drive, you can use the checkbox so you don't have to type the whole path.
1781
+ init = "default noise" # @param ["default noise", "image", "random image", "salt and pepper noise", "salt and pepper noise on init image"]
1782
+
1783
+ if args2.seed_image is None:
1784
+ init_image = "" # args2.seed_image #""#@param {type:"string"}
1785
  else:
1786
+ init_image = args2.seed_image # ""#@param {type:"string"}
1787
+
1788
+ if init == "random image":
1789
+ url = (
1790
+ "https://picsum.photos/"
1791
+ + str(width)
1792
+ + "/"
1793
+ + str(height)
1794
+ + "?blur="
1795
+ + str(random.randrange(5, 10))
1796
+ )
1797
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1798
+ init_image = "Init_Img/Image.png"
1799
+ elif init == "random image clear":
1800
+ url = "https://source.unsplash.com/random/" + str(width) + "x" + str(height)
1801
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1802
+ init_image = "Init_Img/Image.png"
1803
+ elif init == "random image clear 2":
1804
+ url = "https://loremflickr.com/" + str(width) + "/" + str(height)
1805
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1806
+ init_image = "Init_Img/Image.png"
1807
+ elif init == "salt and pepper noise":
1808
+ urllib.request.urlretrieve(
1809
+ "https://i.stack.imgur.com/olrL8.png", "Init_Img/Image.png"
1810
+ )
1811
+ import cv2
1812
+
1813
+ img = cv2.imread("Init_Img/Image.png", 0)
1814
+ cv2.imwrite("Init_Img/Image.png", add_noise(img))
1815
+ init_image = "Init_Img/Image.png"
1816
+ elif init == "salt and pepper noise on init image":
1817
+ img = cv2.imread(init_image, 0)
1818
+ cv2.imwrite("Init_Img/Image.png", add_noise(img))
1819
+ init_image = "Init_Img/Image.png"
1820
+ elif init == "perlin noise":
1821
+ # For some reason Colab started crashing from this
1822
+ import noise
1823
+ import numpy as np
1824
+ from PIL import Image
1825
+
1826
+ shape = (width, height)
1827
+ scale = 100
1828
+ octaves = 6
1829
+ persistence = 0.5
1830
+ lacunarity = 2.0
1831
+ seed = np.random.randint(0, 100000)
1832
+ world = np.zeros(shape)
1833
+ for i in range(shape[0]):
1834
+ for j in range(shape[1]):
1835
+ world[i][j] = noise.pnoise2(
1836
+ i / scale,
1837
+ j / scale,
1838
+ octaves=octaves,
1839
+ persistence=persistence,
1840
+ lacunarity=lacunarity,
1841
+ repeatx=1024,
1842
+ repeaty=1024,
1843
+ base=seed,
1844
+ )
1845
+ Image.fromarray(prep_world(world)).convert("L").save("Init_Img/Image.png")
1846
+ init_image = "Init_Img/Image.png"
1847
+ elif init == "black and white":
1848
+ url = "https://www.random.org/bitmaps/?format=png&width=300&height=300&zoom=1"
1849
+ urllib.request.urlretrieve(url, "Init_Img/Image.png")
1850
+ init_image = "Init_Img/Image.png"
1851
+
1852
+ seed = args2.seed # @param {type:"number"}
1853
+ # @markdown >iterations excludes iterations spent during the mse phase, if it is being used. The total iterations will be more if `mse_decay_rate` is more than 0.
1854
+ iterations = args2.iterations # @param {type:"number"}
1855
+ transparent_png = False # @param {type:"boolean"}
1856
+
1857
+ # @markdown <font size="+3">⚠</font> **ADVANCED SETTINGS** <font size="+3">⚠</font>
1858
+ # @markdown ---
1859
+ # @markdown ---
1860
+
1861
+ # @markdown >If you want to make multiple images with different prompts, use this. Seperate different prompts for different images with a `~` (example: `prompt1~prompt1~prompt3`). Iter is the iterations you want each image to run for. If you use MSE, I'd type a pretty low number (about 10).
1862
+ multiple_prompt_batches = False # @param {type:"boolean"}
1863
+ multiple_prompt_batches_iter = 300 # @param {type:"number"}
1864
+
1865
+ # @markdown >`folder_name` is the name of the folder you want to output your result(s) to. Previous outputs will NOT be overwritten. By default, it will be saved to the colab's root folder, but the `save_to_drive` checkbox will save it to `MyDrive\VQGAN_Output` instead.
1866
+ folder_name = "" # @param {type:"string"}
1867
+ save_to_drive = False # @param {type:"boolean"}
1868
+ prompt_experiment = "None" # @param ['None', 'Fever Dream', 'Philipuss’s Basement', 'Vivid Turmoil', 'Mad Dad', 'Platinum', 'Negative Energy']
1869
+ if prompt_experiment == "Fever Dream":
1870
+ prompts = "<|startoftext|>" + prompts + "<|endoftext|>"
1871
+ elif prompt_experiment == "Vivid Turmoil":
1872
+ prompts = prompts.replace(" ", "¡")
1873
+ prompts = "¬" + prompts + "®"
1874
+ elif prompt_experiment == "Mad Dad":
1875
+ prompts = prompts.replace(" ", "\\s+")
1876
+ elif prompt_experiment == "Platinum":
1877
+ prompts = "~!" + prompts + "!~"
1878
+ prompts = prompts.replace(" ", "</w>")
1879
+ elif prompt_experiment == "Philipuss’s Basement":
1880
+ prompts = "<|startoftext|>" + prompts
1881
+ prompts = prompts.replace(" ", "<|endoftext|><|startoftext|>")
1882
+ elif prompt_experiment == "Lowercase":
1883
+ prompts = prompts.lower()
1884
+
1885
+
1886
+ # @markdown >Target images work like prompts, write the name of the image. You can add multiple target images by seperating them with a `|`.
1887
+ target_images = "" # @param {type:"string"}
1888
+
1889
+ # @markdown ><font size="+2">☢</font> Advanced values. Values of cut_pow below 1 prioritize structure over detail, and vice versa for above 1. Step_size affects how wild the change between iterations is, and if final_step_size is not 0, step_size will interpolate towards it over time.
1890
+ # @markdown >Cutn affects on 'Creativity': less cutout will lead to more random/creative results, sometimes barely readable, while higher values (90+) lead to very stable, photo-like outputs
1891
+ cutn = 130 # @param {type:"number"}
1892
+ cut_pow = 1 # @param {type:"number"}
1893
+ # @markdown >Step_size is like weirdness. Lower: more accurate/realistic, slower; Higher: less accurate/more funky, faster.
1894
+ step_size = 0.1 # @param {type:"number"}
1895
+ # @markdown >Start_step_size is a temporary step_size that will be active only in the first 10 iterations. It (sometimes) helps with speed. If it's set to 0, it won't be used.
1896
+ start_step_size = 0 # @param {type:"number"}
1897
+ # @markdown >Final_step_size is a goal step_size which the AI will try and reach. If set to 0, it won't be used.
1898
+ final_step_size = 0 # @param {type:"number"}
1899
+ if start_step_size <= 0:
1900
+ start_step_size = step_size
1901
+ if final_step_size <= 0:
1902
+ final_step_size = step_size
1903
+
1904
+ # @markdown ---
1905
+
1906
+ # @markdown >EMA maintains a moving average of trained parameters. The number below is the rate of decay (higher means slower).
1907
+ ema_val = 0.98 # @param {type:"number"}
1908
+
1909
+ # @markdown >If you want to keep starting from the same point, set `gen_seed` to a positive number. `-1` will make it random every time.
1910
+ gen_seed = -1 # @param {type:'number'}
1911
+
1912
+ init_image_in_drive = False # @param {type:"boolean"}
1913
+ if init_image_in_drive and init_image:
1914
+ init_image = "/content/drive/MyDrive/VQGAN_Output/" + init_image
1915
+
1916
+ images_interval = args2.update # @param {type:"number"}
1917
+
1918
+ # I think you should give "Free Thoughts on the Proceedings of the Continental Congress" a read, really funny and actually well-written, Hamilton presented it in a bad light IMO.
1919
+
1920
+ batch_size = 1 # @param {type:"number"}
1921
+
1922
+ # @markdown ---
1923
+
1924
+ # @markdown <font size="+1">🔮</font> **MSE Regulization** <font size="+1">🔮</font>
1925
+ # Based off of this notebook: https://colab.research.google.com/drive/1gFn9u3oPOgsNzJWEFmdK-N9h_y65b8fj?usp=sharing - already in credits
1926
+ use_mse = args2.mse # @param {type:"boolean"}
1927
+ mse_images_interval = images_interval
1928
+ mse_init_weight = 0.2 # @param {type:"number"}
1929
+ mse_decay_rate = 160 # @param {type:"number"}
1930
+ mse_epoches = 10 # @param {type:"number"}
1931
+ ##@param {type:"number"}
1932
+
1933
+ # @markdown >Overwrites the usual values during the mse phase if included. If any value is 0, its normal counterpart is used instead.
1934
+ mse_with_zeros = True # @param {type:"boolean"}
1935
+ mse_step_size = 0.87 # @param {type:"number"}
1936
+ mse_cutn = 42 # @param {type:"number"}
1937
+ mse_cut_pow = 0.75 # @param {type:"number"}
1938
+
1939
+ # @markdown >normal_flip_optim flips between two optimizers during the normal (not MSE) phase. It can improve quality, but it's kind of experimental, use at your own risk.
1940
+ normal_flip_optim = True # @param {type:"boolean"}
1941
+ ##@markdown >Adding some TV may make the image blurrier but also helps to get rid of noise. A good value to try might be 0.1.
1942
+ # tv_weight = 0.1 #@param {type:'number'}
1943
+ # @markdown ---
1944
+
1945
+ # @markdown >`altprompts` is a set of prompts that take in a different augmentation pipeline, and can have their own cut_pow. At the moment, the default "alt augment" settings flip the picture cutouts upside down before evaluating. This can be good for optical illusion images. If either cut_pow value is 0, it will use the same value as the normal prompts.
1946
+ altprompts = "" # @param {type:"string"}
1947
+ altprompt_mode = "flipped"
1948
+ ##@param ["normal" , "flipped", "sideways"]
1949
+ alt_cut_pow = 0 # @param {type:"number"}
1950
+ alt_mse_cut_pow = 0 # @param {type:"number"}
1951
+ # altprompt_type = "upside-down" #@param ['upside-down', 'as']
1952
+
1953
+ ##@markdown ---
1954
+ ##@markdown <font size="+1">💫</font> **Zooming and Moving** <font size="+1">💫</font>
1955
+ zoom = False
1956
+ ##@param {type:"boolean"}
1957
+ zoom_speed = 100
1958
+ ##@param {type:"number"}
1959
+ zoom_frequency = 20
1960
+ ##@param {type:"number"}
1961
+
1962
+ # @markdown ---
1963
+ # @markdown On an unrelated note, if you get any errors while running this, restart the runtime and run the first cell again. If that doesn't work either, message me on Discord (Philipuss#4066).
1964
+
1965
+ model_names = {
1966
+ "vqgan_imagenet_f16_16384": "vqgan_imagenet_f16_16384",
1967
+ "ImageNet 1024": "vqgan_imagenet_f16_1024",
1968
+ "Gumbel 8192": "gumbel_8192",
1969
+ "Sber Gumbel": "sber_gumbel",
1970
+ "imagenet_cin": "imagenet_cin",
1971
+ "WikiArt 1024": "wikiart_1024",
1972
+ "WikiArt 16384": "wikiart_16384",
1973
+ "COCO-Stuff": "coco",
1974
+ "FacesHQ": "faceshq",
1975
+ "S-FLCKR": "sflckr",
1976
+ "WikiArt 7mil": "wikiart_7mil",
1977
+ "COCO 1 Stage": "coco_1stage",
1978
+ }
1979
+
1980
+ if template == "Better - Fast":
1981
+ prompts = prompts + ". Detailed artwork. ArtStationHQ. unreal engine. 4K HD."
1982
+ elif template == "Better - Slow":
1983
+ prompts = (
1984
+ prompts
1985
+ + ". Detailed artwork. Trending on ArtStation. unreal engine. | Rendered in Maya. "
1986
+ + prompts
1987
+ + ". 4K HD."
1988
+ )
1989
+ elif template == "Movie Poster":
1990
+ prompts = prompts + ". Movie poster. Rendered in unreal engine. ArtStationHQ."
1991
+ width = 400
1992
+ height = 592
1993
+ elif template == "flag":
1994
+ prompts = (
1995
+ "A photo of a flag of the country "
1996
+ + prompts
1997
+ + " | Flag of "
1998
+ + prompts
1999
+ + ". White background."
2000
+ )
2001
+ # import cv2
2002
+ # img = cv2.imread('templates/flag.png', 0)
2003
+ # cv2.imwrite('templates/final_flag.png', add_noise(img))
2004
+ init_image = "templates/flag.png"
2005
+ transparent_png = True
2006
+ elif template == "planet":
2007
+ import cv2
2008
+
2009
+ img = cv2.imread("templates/planet.png", 0)
2010
+ cv2.imwrite("templates/final_planet.png", add_noise(img))
2011
+ prompts = (
2012
+ "A photo of the planet "
2013
+ + prompts
2014
+ + ". Planet in the middle with black background. | The planet of "
2015
+ + prompts
2016
+ + ". Photo of a planet. Black background. Trending on ArtStation. | Colorful."
2017
+ )
2018
+ init_image = "templates/final_planet.png"
2019
+ elif template == "creature":
2020
+ # import cv2
2021
+ # img = cv2.imread('templates/planet.png', 0)
2022
+ # cv2.imwrite('templates/final_planet.png', add_noise(img))
2023
+ prompts = (
2024
+ "A photo of a creature with "
2025
+ + prompts
2026
+ + ". Animal in the middle with white background. | The creature has "
2027
+ + prompts
2028
+ + ". Photo of a creature/animal. White background. Detailed image of a creature. | White background."
2029
+ )
2030
+ init_image = "templates/creature.png"
2031
+ # transparent_png = True
2032
+ elif template == "Detailed":
2033
+ prompts = (
2034
+ prompts
2035
+ + ", by Puer Udger. Detailed artwork, trending on artstation. 4K HD, realism."
2036
+ )
2037
+ flavor = "cumin"
2038
+ elif template == "human":
2039
+ init_image = "/content/templates/human.png"
2040
+ elif template == "Realistic":
2041
+ cutn = 200
2042
+ step_size = 0.03
2043
+ cut_pow = 0.2
2044
+ flavor = "holywater"
2045
+ elif template == "Consistent Creativity":
2046
+ flavor = "cumin"
2047
+ cut_pow = 0.01
2048
+ cutn = 136
2049
+ step_size = 0.08
2050
+ mse_step_size = 0.41
2051
+ mse_cut_pow = 0.3
2052
+ ema_val = 0.99
2053
+ normal_flip_optim = False
2054
+ elif template == "Smooth":
2055
+ flavor = "wyvern"
2056
+ step_size = 0.10
2057
+ cutn = 120
2058
+ normal_flip_optim = False
2059
+ tv_weight = 10
2060
+ elif template == "Subtle MSE":
2061
+ mse_init_weight = 0.07
2062
+ mse_decay_rate = 130
2063
+ mse_step_size = 0.2
2064
+ mse_cutn = 100
2065
+ mse_cut_pow = 0.6
2066
+ elif template == "Balanced":
2067
+ cutn = 130
2068
+ cut_pow = 1
2069
+ step_size = 0.16
2070
+ final_step_size = 0
2071
+ ema_val = 0.98
2072
+ mse_init_weight = 0.2
2073
+ mse_decay_rate = 130
2074
+ mse_with_zeros = True
2075
+ mse_step_size = 0.9
2076
+ mse_cutn = 50
2077
+ mse_cut_pow = 0.8
2078
+ normal_flip_optim = True
2079
+ elif template == "Size: Square":
2080
+ width = 450
2081
+ height = 450
2082
+ elif template == "Size: Landscape":
2083
+ width = 480
2084
+ height = 336
2085
+ elif template == "Size: Poster":
2086
+ width = 336
2087
+ height = 480
2088
+ elif template == "Negative Prompt":
2089
+ prompts = prompts.replace(":", ":-")
2090
+ prompts = prompts.replace(":--", ":")
2091
+ elif template == "Hyper Fast Results":
2092
+ step_size = 1
2093
+ ema_val = 0.3
2094
+ cutn = 30
2095
+ elif template == "Better Quality":
2096
+ prompts = (
2097
+ prompts + ":1 | Watermark, blurry, cropped, confusing, cut, incoherent:-1"
2098
+ )
2099
+
2100
+ mse_decay = 0
2101
+
2102
+ if use_mse == False:
2103
+ mse_init_weight = 0.0
2104
+ else:
2105
+ mse_decay = mse_init_weight / mse_epoches
2106
+
2107
+
2108
+ if seed == -1:
2109
+ seed = None
2110
+ if init_image == "None":
2111
+ init_image = None
2112
+ if target_images == "None" or not target_images:
2113
+ target_images = []
2114
+ else:
2115
+ target_images = target_images.split("|")
2116
+ target_images = [image.strip() for image in target_images]
2117
+
2118
+ prompts = [phrase.strip() for phrase in prompts.split("|")]
2119
+ if prompts == [""]:
2120
+ prompts = []
2121
+
2122
+ altprompts = [phrase.strip() for phrase in altprompts.split("|")]
2123
+ if altprompts == [""]:
2124
+ altprompts = []
2125
+
2126
+ if mse_images_interval == 0:
2127
+ mse_images_interval = images_interval
2128
+ if mse_step_size == 0:
2129
+ mse_step_size = step_size
2130
+ if mse_cutn == 0:
2131
+ mse_cutn = cutn
2132
+ if mse_cut_pow == 0:
2133
+ mse_cut_pow = cut_pow
2134
+ if alt_cut_pow == 0:
2135
+ alt_cut_pow = cut_pow
2136
+ if alt_mse_cut_pow == 0:
2137
+ alt_mse_cut_pow = mse_cut_pow
2138
+
2139
+ augs = nn.Sequential(
2140
+ K.RandomHorizontalFlip(p=0.5),
2141
+ K.RandomSharpness(0.3, p=0.4),
2142
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2143
+ # K.RandomGaussianNoise(p=0.5),
2144
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2145
+ K.RandomAffine(
2146
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2147
+ ), # padding_mode=2
2148
+ K.RandomPerspective(
2149
+ 0.2,
2150
+ p=0.4,
2151
+ ),
2152
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2153
+ K.RandomGrayscale(p=0.1),
2154
+ )
2155
+
2156
+ if altprompt_mode == "normal":
2157
+ altaugs = nn.Sequential(
2158
+ K.RandomRotation(degrees=90.0, return_transform=True),
2159
+ K.RandomHorizontalFlip(p=0.5),
2160
+ K.RandomSharpness(0.3, p=0.4),
2161
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2162
+ # K.RandomGaussianNoise(p=0.5),
2163
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2164
+ K.RandomAffine(
2165
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2166
+ ), # padding_mode=2
2167
+ K.RandomPerspective(
2168
+ 0.2,
2169
+ p=0.4,
2170
+ ),
2171
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2172
+ K.RandomGrayscale(p=0.1),
2173
+ )
2174
+ elif altprompt_mode == "flipped":
2175
+ altaugs = nn.Sequential(
2176
+ K.RandomHorizontalFlip(p=0.5),
2177
+ # K.RandomRotation(degrees=90.0),
2178
+ K.RandomVerticalFlip(p=1),
2179
+ K.RandomSharpness(0.3, p=0.4),
2180
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2181
+ # K.RandomGaussianNoise(p=0.5),
2182
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2183
+ K.RandomAffine(
2184
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2185
+ ), # padding_mode=2
2186
+ K.RandomPerspective(
2187
+ 0.2,
2188
+ p=0.4,
2189
+ ),
2190
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2191
+ K.RandomGrayscale(p=0.1),
2192
+ )
2193
+ elif altprompt_mode == "sideways":
2194
+ altaugs = nn.Sequential(
2195
+ K.RandomHorizontalFlip(p=0.5),
2196
+ # K.RandomRotation(degrees=90.0),
2197
+ K.RandomVerticalFlip(p=1),
2198
+ K.RandomSharpness(0.3, p=0.4),
2199
+ K.RandomGaussianBlur((3, 3), (4.5, 4.5), p=0.3),
2200
+ # K.RandomGaussianNoise(p=0.5),
2201
+ # K.RandomElasticTransform(kernel_size=(33, 33), sigma=(7,7), p=0.2),
2202
+ K.RandomAffine(
2203
+ degrees=30, translate=0.1, p=0.8, padding_mode="border"
2204
+ ), # padding_mode=2
2205
+ K.RandomPerspective(
2206
+ 0.2,
2207
+ p=0.4,
2208
+ ),
2209
+ K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
2210
+ K.RandomGrayscale(p=0.1),
2211
+ )
2212
+
2213
+ if multiple_prompt_batches:
2214
+ prompts_all = str(prompts).split("~")
2215
+ else:
2216
+ prompts_all = prompts
2217
+ multiple_prompt_batches_iter = iterations
2218
+
2219
+ if multiple_prompt_batches:
2220
+ mtpl_prmpts_btchs = len(prompts_all)
2221
+ else:
2222
+ mtpl_prmpts_btchs = 1
2223
+
2224
+ # print(mtpl_prmpts_btchs)
2225
+
2226
+ steps_path = "./"
2227
+ zoom_path = "./"
2228
+
2229
+ path = "./"
2230
+
2231
+ iterations = multiple_prompt_batches_iter
2232
+
2233
+ for pr in range(0, mtpl_prmpts_btchs):
2234
+ # print(prompts_all[pr].replace('[\'', '').replace('\']', ''))
2235
+ if multiple_prompt_batches:
2236
+ prompts = prompts_all[pr].replace("['", "").replace("']", "")
2237
+
2238
+ if zoom:
2239
+ mdf_iter = round(iterations / zoom_frequency)
2240
+ else:
2241
+ mdf_iter = 2
2242
+ zoom_frequency = iterations
2243
+
2244
+ for iter in range(1, mdf_iter):
2245
+ if zoom:
2246
+ if iter != 0:
2247
+ image = Image.open("progress.png")
2248
+ area = (0, 0, width - zoom_speed, height - zoom_speed)
2249
+ cropped_img = image.crop(area)
2250
+ cropped_img.show()
2251
+
2252
+ new_image = cropped_img.resize((width, height))
2253
+ new_image.save("zoom.png")
2254
+ init_image = "zoom.png"
2255
+
2256
+ args = argparse.Namespace(
2257
+ prompts=prompts,
2258
+ altprompts=altprompts,
2259
+ image_prompts=target_images,
2260
+ noise_prompt_seeds=[],
2261
+ noise_prompt_weights=[],
2262
+ size=[width, height],
2263
+ init_image=init_image,
2264
+ png=transparent_png,
2265
+ init_weight=mse_init_weight,
2266
+ vqgan_model=model_names[model],
2267
+ step_size=step_size,
2268
+ start_step_size=start_step_size,
2269
+ final_step_size=final_step_size,
2270
+ cutn=cutn,
2271
+ cut_pow=cut_pow,
2272
+ mse_cutn=mse_cutn,
2273
+ mse_cut_pow=mse_cut_pow,
2274
+ mse_step_size=mse_step_size,
2275
+ display_freq=images_interval,
2276
+ mse_display_freq=mse_images_interval,
2277
+ max_iterations=zoom_frequency,
2278
+ mse_end=0,
2279
+ seed=seed,
2280
+ folder_name=folder_name,
2281
+ save_to_drive=save_to_drive,
2282
+ mse_decay_rate=mse_decay_rate,
2283
+ mse_decay=mse_decay,
2284
+ mse_with_zeros=mse_with_zeros,
2285
+ normal_flip_optim=normal_flip_optim,
2286
+ ema_val=ema_val,
2287
+ augs=augs,
2288
+ altaugs=altaugs,
2289
+ alt_cut_pow=alt_cut_pow,
2290
+ alt_mse_cut_pow=alt_mse_cut_pow,
2291
+ is_gumbel=is_gumbel,
2292
+ gen_seed=gen_seed,
2293
+ )
2294
+
2295
+ mh = ModelHost(args)
2296
+ x = 0
2297
+
2298
+ for x in range(batch_size):
2299
+ mh.setup_model(x)
2300
+ last_iter = mh.run(x)
2301
+ x = x + 1
2302
+
2303
+ #if batch_size != 1:
2304
+ # clear_output()
2305
+ # print("===============================================================================")
2306
+ #q = 0
2307
+ #while q < batch_size:
2308
+ #display(Image("/content/" + folder_name + "/" + str(q) + ".png"))
2309
+ # print("Image" + str(q) + '.png')
2310
+ #q += 1
2311
+
2312
+ if zoom:
2313
+ files = os.listdir(steps_path)
2314
+ for index, file in enumerate(files):
2315
+ os.rename(
2316
+ os.path.join(steps_path, file),
2317
+ os.path.join(
2318
+ steps_path,
2319
+ "".join([str(index + 1 + zoom_frequency * iter), ".png"]),
2320
+ ),
2321
+ )
2322
+ index = index + 1
2323
+
2324
+ from pathlib import Path
2325
+ import shutil
2326
+
2327
+ src_path = steps_path
2328
+ trg_path = zoom_path
2329
+
2330
+ for src_file in range(1, mdf_iter):
2331
+ shutil.move(os.path.join(src_path, src_file), trg_path)
2332
+
2333
+ ##################### START GRADIO HERE ############################
2334
+ image = gr.outputs.Image(type="pil", label="Your result")
2335
+ iface = gr.Interface(
2336
+ fn=run,
2337
+ inputs=[
2338
+ gr.inputs.Textbox(label="Prompt - try adding increments to your prompt such as 'oil on canvas', 'a painting', 'a book cover'",default="chalk pastel drawing of a dog wearing a funny hat"),
2339
+ gr.inputs.Slider(label="Steps - more steps can increase quality but will take longer to generate",default=45,maximum=50,minimum=1,step=1),
2340
+ gr.inputs.Dropdown(label="Style",choices=["none","Balanced","Detailed","Consistent Creativity","Realistic","Smooth","Subtle MSE","Hyper Fast Results"]),
2341
+ gr.inputs.Radio(label="Width", choices=[32,64,128,256,512],default=256),
2342
+ gr.inputs.Radio(label="Height", choices=[32,64,128,256,512],default=256),
2343
+ ],
2344
+ outputs=[image],
2345
+ title="Generate images from text with VQGAN+CLIP",
2346
+ #description="<div>By typing a prompt and pressing submit you can generate images based on this prompt. <a href='https://github.com/CompVis/latent-diffusion' target='_blank'>Latent Diffusion</a> is a text-to-image model created by <a href='https://github.com/CompVis' target='_blank'>CompVis</a>, trained on the <a href='https://laion.ai/laion-400-open-dataset/'>LAION-400M dataset.</a><br>This UI to the model was assembled by <a style='color: rgb(245, 158, 11);font-weight:bold' href='https://twitter.com/multimodalart' target='_blank'>@multimodalart</a></div>",
2347
+ #article="<h4 style='font-size: 110%;margin-top:.5em'>Biases acknowledgment</h4><div>Despite how impressive being able to turn text into image is, beware to the fact that this model may output content that reinforces or exarcbates societal biases. According to the <a href='https://arxiv.org/abs/2112.10752' target='_blank'>Latent Diffusion paper</a>:<i> \"Deep learning modules tend to reproduce or exacerbate biases that are already present in the data\"</i>. The model was trained on an unfiltered version the LAION-400M dataset, which scrapped non-curated image-text-pairs from the internet (the exception being the the removal of illegal content) and is meant to be used for research purposes, such as this one. <a href='https://laion.ai/laion-400-open-dataset/' target='_blank'>You can read more on LAION's website</a></div><h4 style='font-size: 110%;margin-top:1em'>Who owns the images produced by this demo?</h4><div>Definetly not me! Probably you do. I say probably because the Copyright discussion about AI generated art is ongoing. So <a href='https://www.theverge.com/2022/2/21/22944335/us-copyright-office-reject-ai-generated-art-recent-entrance-to-paradise' target='_blank'>it may be the case that everything produced here falls automatically into the public domain</a>. But in any case it is either yours or is in the public domain.</div>"
2348
+ )
2349
+ iface.launch(enable_queue=True)
requirements.txt CHANGED
@@ -1 +1,20 @@
1
- torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -e git+https://github.com/CompVis/taming-transformers.git#egg=taming-transformers
2
+ ftfy
3
+ regex
4
+ pandas
5
+ omegaconf
6
+ pytorch-lightning
7
+ torch-fidelity
8
+ transformers
9
+ einops
10
+ gradio
11
+ torch
12
+ open_clip_torch
13
+ numpy
14
+ tqdm
15
+ torchvision
16
+ Pillow
17
+ autokeras
18
+ huggingface_hub
19
+ kornia
20
+ clip