Spaces:
Sleeping
Sleeping
Add model from Chameleon
Browse files- app.py +100 -37
- chameleon/LICENSE +51 -0
- chameleon/image_tokenizer.py +124 -0
- chameleon/vqgan.py +675 -0
app.py
CHANGED
@@ -1,18 +1,27 @@
|
|
1 |
-
from typing import
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import colorsys
|
|
|
6 |
|
|
|
7 |
from diffusers import VQModel
|
8 |
from diffusers.image_processor import VaeImageProcessor
|
9 |
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
10 |
-
|
|
|
|
|
11 |
import torch.backends
|
12 |
import torch.mps
|
13 |
from PIL import Image
|
|
|
14 |
import spaces
|
15 |
|
|
|
|
|
|
|
|
|
16 |
if torch.cuda.is_available():
|
17 |
device = torch.device("cuda")
|
18 |
elif torch.backends.mps.is_available():
|
@@ -21,9 +30,7 @@ else:
|
|
21 |
device = torch.device("cpu")
|
22 |
|
23 |
|
24 |
-
# abstract class VQImageRoundtripPipeline:
|
25 |
class ImageRoundtripPipeline:
|
26 |
-
@abstractmethod
|
27 |
def roundtrip_image(self, image, output_type="pil"): ...
|
28 |
|
29 |
|
@@ -63,6 +70,12 @@ class VQImageRoundtripPipeline(ImageRoundtripPipeline):
|
|
63 |
latents = self.vqvae.quantize(latents)[2][2].reshape(
|
64 |
batch_size, latents_height, latents_width
|
65 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
output = self.vqvae.decode(
|
67 |
latents,
|
68 |
force_not_quantize=True,
|
@@ -81,6 +94,55 @@ class VQImageRoundtripPipeline(ImageRoundtripPipeline):
|
|
81 |
return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings
|
82 |
|
83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
|
85 |
vqgan: PaellaVQModel
|
86 |
vae_scale_factor: int
|
@@ -127,6 +189,7 @@ class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
|
|
127 |
|
128 |
pipeline_paella = PaellaImageRoundtripPipeline()
|
129 |
pipeline_vq = VQImageRoundtripPipeline()
|
|
|
130 |
|
131 |
|
132 |
# Function to generate a list of unique colors
|
@@ -171,26 +234,27 @@ def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
|
|
171 |
return img
|
172 |
|
173 |
|
174 |
-
# This is a gradio space that lets you encode an image with various encoder-decoder pairs, eg VQ-GAN, SDXL's VAE, etc and check the image quality
|
175 |
-
|
176 |
-
|
177 |
-
# def image_grid_to_string(image_grid):
|
178 |
-
# """Convert a latent vq index "image" grid to a string, input shape is (1, height, width)"""
|
179 |
-
# return "\n".join(
|
180 |
-
# [" ".join([str(int(x)) for x in row]) for row in image_grid.squeeze()]
|
181 |
-
# )
|
182 |
-
|
183 |
-
|
184 |
def describe_shape(shape):
|
185 |
return f"Shape: {shape} num elements: {np.prod(shape)}"
|
186 |
|
187 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
@spaces.GPU(duration=32)
|
189 |
@torch.no_grad()
|
190 |
def roundtrip_image(
|
191 |
image,
|
192 |
-
model:
|
193 |
-
size:
|
194 |
output_type="pil",
|
195 |
):
|
196 |
if size == "256x256":
|
@@ -202,41 +266,40 @@ def roundtrip_image(
|
|
202 |
else:
|
203 |
raise ValueError(f"Unknown size {size}")
|
204 |
|
|
|
205 |
if model == "vqgan":
|
206 |
-
|
207 |
-
return (
|
208 |
-
image,
|
209 |
-
vqgan_tokens_to_image(
|
210 |
-
latents, codebook_size, downscale_factor=pipeline_vq.vae_scale_factor
|
211 |
-
),
|
212 |
-
describe_shape(latents.shape),
|
213 |
-
)
|
214 |
elif model == "paella":
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
return (
|
219 |
-
image,
|
220 |
-
vqgan_tokens_to_image(
|
221 |
-
latents, codebook_size, downscale_factor=pipeline_vq.vae_scale_factor
|
222 |
-
),
|
223 |
-
describe_shape(latents.shape),
|
224 |
-
)
|
225 |
else:
|
226 |
raise ValueError(f"Unknown model {model}")
|
227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
demo = gr.Interface(
|
230 |
fn=roundtrip_image,
|
231 |
inputs=[
|
232 |
gr.Image(type="pil"),
|
233 |
-
gr.Dropdown(
|
234 |
gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
|
235 |
],
|
236 |
outputs=[
|
237 |
-
gr.Image(label="Reconstructed"),
|
238 |
-
gr.Image(label="Tokens"),
|
239 |
gr.Text(label="VQ Shape"),
|
|
|
240 |
],
|
241 |
title="Image Tokenizer Playground",
|
242 |
description="Round-trip an image through an encode-decoder pair to see the quality loss from the VQ-GAN for image generation, etc.",
|
|
|
1 |
+
from typing import Literal
|
2 |
import gradio as gr
|
3 |
import torch
|
4 |
import numpy as np
|
5 |
import colorsys
|
6 |
+
import yaml
|
7 |
|
8 |
+
from huggingface_hub import hf_hub_download
|
9 |
from diffusers import VQModel
|
10 |
from diffusers.image_processor import VaeImageProcessor
|
11 |
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
|
12 |
+
|
13 |
+
from chameleon.image_tokenizer import ImageTokenizer
|
14 |
+
|
15 |
import torch.backends
|
16 |
import torch.mps
|
17 |
from PIL import Image
|
18 |
+
|
19 |
import spaces
|
20 |
|
21 |
+
|
22 |
+
Model = Literal["vqgan", "paella", "chameleon"]
|
23 |
+
models = ["vqgan", "paella", "chameleon"]
|
24 |
+
|
25 |
if torch.cuda.is_available():
|
26 |
device = torch.device("cuda")
|
27 |
elif torch.backends.mps.is_available():
|
|
|
30 |
device = torch.device("cpu")
|
31 |
|
32 |
|
|
|
33 |
class ImageRoundtripPipeline:
|
|
|
34 |
def roundtrip_image(self, image, output_type="pil"): ...
|
35 |
|
36 |
|
|
|
70 |
latents = self.vqvae.quantize(latents)[2][2].reshape(
|
71 |
batch_size, latents_height, latents_width
|
72 |
)
|
73 |
+
# replace 20% of latents with random values
|
74 |
+
# random_latents = torch.randint(
|
75 |
+
# 0, self.vqvae.config.num_vq_embeddings, latents.shape, device=device
|
76 |
+
# )
|
77 |
+
# random_mask = torch.rand(latents.shape, device=device) < 0.2
|
78 |
+
# latents = torch.where(random_mask, random_latents, latents)
|
79 |
output = self.vqvae.decode(
|
80 |
latents,
|
81 |
force_not_quantize=True,
|
|
|
94 |
return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings
|
95 |
|
96 |
|
97 |
+
class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
|
98 |
+
tokenizer: ImageTokenizer
|
99 |
+
n_embed: int
|
100 |
+
vae_scale_factor: int
|
101 |
+
|
102 |
+
def __init__(self):
|
103 |
+
vqgan_path = hf_hub_download(
|
104 |
+
"darknoon/chameleon-tokenizer", "tokenizer/vqgan.ckpt"
|
105 |
+
)
|
106 |
+
vqgan_config_path = hf_hub_download(
|
107 |
+
"darknoon/chameleon-tokenizer", "tokenizer/vqgan.yaml"
|
108 |
+
)
|
109 |
+
self.tokenizer = ImageTokenizer(
|
110 |
+
cfg_path=vqgan_config_path, ckpt_path=vqgan_path, device=device
|
111 |
+
)
|
112 |
+
with open(vqgan_config_path) as f:
|
113 |
+
vq_config = yaml.safe_load(f)
|
114 |
+
|
115 |
+
self.n_embed = vq_config["model"]["params"]["n_embed"]
|
116 |
+
self.vae_scale_factor = 16
|
117 |
+
print("Chameleon VQGan model loaded", self.tokenizer._vq_model, self.n_embed)
|
118 |
+
|
119 |
+
def preprocess(self, image: Image):
|
120 |
+
# copied from _vqgan_input_from
|
121 |
+
np_img = np.array(image) / 255.0 # Normalize to [0, 1]
|
122 |
+
np_img = np_img * 2 - 1 # Scale to [-1, 1]
|
123 |
+
tensor_img = (
|
124 |
+
torch.from_numpy(np_img).permute(2, 0, 1).float()
|
125 |
+
) # (Channels, Height, Width) format.
|
126 |
+
|
127 |
+
# Add batch dimension.
|
128 |
+
return tensor_img.unsqueeze(0)
|
129 |
+
|
130 |
+
def roundtrip_image(self, image, output_type="pil"):
|
131 |
+
# image = self.tokenizer._vqgan_input_from(image).to(device)
|
132 |
+
image = self.preprocess(image).to(device)
|
133 |
+
_, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
|
134 |
+
# emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
|
135 |
+
output = self.tokenizer.pil_from_img_toks(latents)
|
136 |
+
# we actually do want this to be a grid, sorry!
|
137 |
+
latents = latents.reshape(1, 32, 32)
|
138 |
+
|
139 |
+
return (
|
140 |
+
output,
|
141 |
+
latents.cpu().numpy(),
|
142 |
+
self.n_embed,
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
|
147 |
vqgan: PaellaVQModel
|
148 |
vae_scale_factor: int
|
|
|
189 |
|
190 |
pipeline_paella = PaellaImageRoundtripPipeline()
|
191 |
pipeline_vq = VQImageRoundtripPipeline()
|
192 |
+
pipeline_vq_chameleon = ChameleonVQImageRoundtripPipeline()
|
193 |
|
194 |
|
195 |
# Function to generate a list of unique colors
|
|
|
234 |
return img
|
235 |
|
236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
def describe_shape(shape):
|
238 |
return f"Shape: {shape} num elements: {np.prod(shape)}"
|
239 |
|
240 |
|
241 |
+
def calc_psnr(img1: Image, img2: Image):
|
242 |
+
if img1.size != img2.size:
|
243 |
+
raise ValueError("Images must have the same dimensions")
|
244 |
+
img1 = np.array(img1)
|
245 |
+
img2 = np.array(img2)
|
246 |
+
mse = np.mean((img1 - img2) ** 2)
|
247 |
+
if mse == 0:
|
248 |
+
return float("inf")
|
249 |
+
return 2 * 10 * np.log10(255.0 / np.sqrt(mse))
|
250 |
+
|
251 |
+
|
252 |
@spaces.GPU(duration=32)
|
253 |
@torch.no_grad()
|
254 |
def roundtrip_image(
|
255 |
image,
|
256 |
+
model: Model,
|
257 |
+
size: Literal["256x256", "512x512", "1024x1024"],
|
258 |
output_type="pil",
|
259 |
):
|
260 |
if size == "256x256":
|
|
|
266 |
else:
|
267 |
raise ValueError(f"Unknown size {size}")
|
268 |
|
269 |
+
image_orig = image
|
270 |
if model == "vqgan":
|
271 |
+
pipeline = pipeline_vq
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
272 |
elif model == "paella":
|
273 |
+
pipeline = pipeline_paella
|
274 |
+
elif model == "chameleon":
|
275 |
+
pipeline = pipeline_vq_chameleon
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
else:
|
277 |
raise ValueError(f"Unknown model {model}")
|
278 |
|
279 |
+
image, latents, codebook_size = pipeline.roundtrip_image(image, output_type)
|
280 |
+
|
281 |
+
return (
|
282 |
+
image,
|
283 |
+
vqgan_tokens_to_image(
|
284 |
+
latents, codebook_size, downscale_factor=pipeline.vae_scale_factor
|
285 |
+
),
|
286 |
+
describe_shape(latents.shape),
|
287 |
+
f"{calc_psnr(image_orig, image):.2f}",
|
288 |
+
)
|
289 |
+
|
290 |
|
291 |
demo = gr.Interface(
|
292 |
fn=roundtrip_image,
|
293 |
inputs=[
|
294 |
gr.Image(type="pil"),
|
295 |
+
gr.Dropdown(models, label="Model", value="vqgan"),
|
296 |
gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
|
297 |
],
|
298 |
outputs=[
|
299 |
+
gr.Image(label="Reconstructed", format="png"),
|
300 |
+
gr.Image(label="Tokens", format="png"),
|
301 |
gr.Text(label="VQ Shape"),
|
302 |
+
gr.Text(label="PSNR"),
|
303 |
],
|
304 |
title="Image Tokenizer Playground",
|
305 |
description="Round-trip an image through an encode-decoder pair to see the quality loss from the VQ-GAN for image generation, etc.",
|
chameleon/LICENSE
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Chameleon Research License
|
2 |
+
Chameleon Version Release Date: June 18, 2024
|
3 |
+
|
4 |
+
This Chameleon Research License ("Agreement") contains the terms and conditions that govern your access and use of the Chameleon Materials (as defined below). You may not use the Chameleon Materials if you do not accept this Agreement. By clicking "I Accept" to accept, or accessing, using, or distributing any portion or element of the Chameleon Materials you hereby agree to be bound by the terms of this Agreement. If you are agreeing to be bound by the Agreement on behalf of your employer or other entity, you represent and warrant to Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland) ("Meta") that you have full legal authority to bind your employer or such entity to this Agreement. If you do not have requisite authority, you may not accept the Agreement or access the Chameleon Materials on behalf of your employer or other entity.
|
5 |
+
|
6 |
+
This Agreement is effective upon the earlier of the date that you first access the Chameleon Materials or accept this Agreement ("Effective Date"), and is entered into by and between Meta, and you, or if you are entering into this Agreement on behalf of your employer or other entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules, or regulations to provide legal consent and, your employer or other entity and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf ("Licensee" or "You").
|
7 |
+
|
8 |
+
1. Definitions.
|
9 |
+
1. "Documentation" means the specifications, manuals and documentation accompanying Chameleon distributed by Meta at https://github.com/facebookresearch/chameleon and https://ai.meta.com/resources/models-and-libraries/chameleon-downloads/.
|
10 |
+
|
11 |
+
|
12 |
+
2. "Noncommercial Research Uses" means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
|
13 |
+
|
14 |
+
|
15 |
+
3. "Chameleon" means the models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta at [INSERT RESOURCE HYPERLINK].
|
16 |
+
|
17 |
+
|
18 |
+
4. "Chameleon Materials" means, collectively, Meta's proprietary Chameleon and Documentation (and any portion thereof) made available under this Agreement.
|
19 |
+
|
20 |
+
|
21 |
+
5. "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
22 |
+
|
23 |
+
|
24 |
+
6. "Acceptable Use Policy" means the Acceptable Use Policy applicable to Chameleon Materials ([INSERT Chameleon AUP HYPERLINK]) that is incorporated into this Agreement.
|
25 |
+
|
26 |
+
|
27 |
+
2. License Rights and Redistribution. Subject to Your compliance with the terms and conditions of this Agreement, Meta hereby grants you the following:
|
28 |
+
1. Grant of Rights. You are hereby granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta's intellectual property or other rights owned by Meta embodied in the Chameleon Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Chameleon Materials solely for Noncommercial Research Uses.
|
29 |
+
2. Redistribution and Use.
|
30 |
+
1. Distribution of Chameleon Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Chameleon Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
|
31 |
+
2. If you submit for publication the results of research you perform on, using, or otherwise in connection with Chameleon Materials, you must acknowledge the use of Chameleon Materials in your publication as follows (or an equivalent acknowledgement of your choosing): "This material is based on work supported by the Chameleon Research License, Copyright (c) Meta Platforms, Inc. All Rights Reserved."
|
32 |
+
|
33 |
+
3. You must retain in all copies of the Chameleon Materials that you distribute and include the following attribution notice within a "Notice" text file distributed as a part of such copies: "Chameleon is licensed under the Chameleon Research License, Copyright (c) Meta Platforms, Inc. All Rights Reserved."
|
34 |
+
4. Your use of the Chameleon Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy for the Chameleon Materials (https://ai.meta.com/resources/models-and-libraries/chameleon-use-policy/) which is hereby incorporated by reference into this Agreement.
|
35 |
+
3. Restrictions. You will not, and will not permit, assist or cause any third party to:
|
36 |
+
1. use the Chameleon Materials or any outputs or results of the Chameleon Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
|
37 |
+
2. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Meta in connection with the Chameleon Materials, or to circumvent or remove any usage restrictions or other safety measures, or to enable functionality disabled by Meta;
|
38 |
+
3. disguise your or their location through IP proxying or other methods;
|
39 |
+
4. use or download Chameleon Materials if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) will use Chameleon Materials for any purpose prohibited by Trade Control Laws; or
|
40 |
+
5. directly or indirectly export, re-export, provide, or otherwise transfer Chameleon Materials: (a) to any individual, entity, or country prohibited by Trade Control Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Trade Control Laws, including nuclear, chemical or biological weapons, or missile technology applications.
|
41 |
+
4. User Support. Your Noncommercial Research Use of the Chameleon Materials is done at your own discretion; Meta does not provide any service in relation to such use. Meta is under no obligation to provide any support services for the Chameleon Materials. Any support provided is "as is", "with all faults", and without warranty of any kind.
|
42 |
+
5. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE Chameleon MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE Chameleon MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE Chameleon MATERIALS AND ANY OUTPUT AND RESULTS.
|
43 |
+
6. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
44 |
+
7. Intellectual Property.
|
45 |
+
1. No trademark licenses are granted under this Agreement, and in connection with the Chameleon Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Chameleon Materials.
|
46 |
+
2. Subject to Meta's ownership of Chameleon Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Chameleon Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
|
47 |
+
3. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Chameleon Materials or Chameleon outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses and rights granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Chameleon Materials.
|
48 |
+
8. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Chameleon Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Chameleon Materials. Sections 3, 4, 5, 6(c), 7, 8 and 9 shall survive the termination of this Agreement.
|
49 |
+
9. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
|
50 |
+
10. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at https://ai.meta.com/resources/models-and-libraries/chameleon-license/
|
51 |
+
11. ; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Chameleon Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no other modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
|
chameleon/image_tokenizer.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import PIL
|
8 |
+
import torch
|
9 |
+
import yaml
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
from .vqgan import VQModel
|
13 |
+
|
14 |
+
|
15 |
+
class ImageTokenizer:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
cfg_path: str,
|
19 |
+
ckpt_path: str,
|
20 |
+
device: str | torch.device | None = None,
|
21 |
+
):
|
22 |
+
with open(cfg_path) as f:
|
23 |
+
config = yaml.safe_load(f)
|
24 |
+
|
25 |
+
params = config["model"]["params"]
|
26 |
+
if "lossconfig" in params:
|
27 |
+
del params["lossconfig"]
|
28 |
+
params["ckpt_path"] = ckpt_path
|
29 |
+
|
30 |
+
self._vq_model = VQModel(**params)
|
31 |
+
self._vq_model.eval()
|
32 |
+
|
33 |
+
if device is None:
|
34 |
+
devices = {p.device for p in self._vq_model.parameters()}
|
35 |
+
assert len(devices) == 1
|
36 |
+
device = devices.pop()
|
37 |
+
else:
|
38 |
+
self._vq_model.to(device)
|
39 |
+
self._device = device
|
40 |
+
|
41 |
+
dtypes = {p.dtype for p in self._vq_model.parameters()}
|
42 |
+
assert len(dtypes) == 1
|
43 |
+
self._dtype = dtypes.pop()
|
44 |
+
|
45 |
+
def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
|
46 |
+
# Check if it's already in RGB format.
|
47 |
+
if img.mode == "RGB":
|
48 |
+
return img
|
49 |
+
|
50 |
+
vals_rgba = np.array(img.convert("RGBA"))
|
51 |
+
|
52 |
+
# If there is no transparency layer, simple convert and return.
|
53 |
+
if not (vals_rgba[:, :, 3] < 255).any():
|
54 |
+
return img.convert("RGB")
|
55 |
+
|
56 |
+
# There is a transparency layer, blend it with a white background.
|
57 |
+
|
58 |
+
# Calculate the alpha proportion for blending.
|
59 |
+
alpha = vals_rgba[:, :, 3] / 255.0
|
60 |
+
# Blend with white background.
|
61 |
+
vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
|
62 |
+
:, :, np.newaxis
|
63 |
+
] * vals_rgba[:, :, :3]
|
64 |
+
return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
|
65 |
+
|
66 |
+
def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
|
67 |
+
# Resize with aspect ratio preservation.
|
68 |
+
s = min(img.size)
|
69 |
+
scale = target_image_size / s
|
70 |
+
new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
|
71 |
+
img = img.resize(new_size, PIL.Image.LANCZOS)
|
72 |
+
|
73 |
+
# Center crop.
|
74 |
+
x0 = (img.width - target_image_size) // 2
|
75 |
+
y0 = (img.height - target_image_size) // 2
|
76 |
+
img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
|
77 |
+
|
78 |
+
# Convert to tensor.
|
79 |
+
np_img = np.array(img) / 255.0 # Normalize to [0, 1]
|
80 |
+
np_img = np_img * 2 - 1 # Scale to [-1, 1]
|
81 |
+
tensor_img = (
|
82 |
+
torch.from_numpy(np_img).permute(2, 0, 1).float()
|
83 |
+
) # (Channels, Height, Width) format.
|
84 |
+
|
85 |
+
# Add batch dimension.
|
86 |
+
return tensor_img.unsqueeze(0)
|
87 |
+
|
88 |
+
def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
|
89 |
+
image = self._whiten_transparency(image)
|
90 |
+
vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
|
91 |
+
_, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
|
92 |
+
return img_toks
|
93 |
+
|
94 |
+
def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
|
95 |
+
# Ensure detachment and move tensor to CPU.
|
96 |
+
detached_chw_tensor = chw_tensor.detach().cpu()
|
97 |
+
|
98 |
+
# Normalize tensor to [0, 1] range from [-1, 1] range.
|
99 |
+
normalized_chw_tensor = (
|
100 |
+
torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
|
101 |
+
) / 2.0
|
102 |
+
|
103 |
+
# Permute CHW tensor to HWC format and convert to NumPy array.
|
104 |
+
hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
|
105 |
+
|
106 |
+
# Convert to an 8-bit unsigned integer format.
|
107 |
+
image_array_uint8 = (hwc_array * 255).astype(np.uint8)
|
108 |
+
|
109 |
+
# Convert NumPy array to PIL Image.
|
110 |
+
pil_image = Image.fromarray(image_array_uint8)
|
111 |
+
|
112 |
+
# Convert image to RGB if it is not already.
|
113 |
+
if pil_image.mode != "RGB":
|
114 |
+
pil_image = pil_image.convert("RGB")
|
115 |
+
|
116 |
+
return pil_image
|
117 |
+
|
118 |
+
def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image:
|
119 |
+
emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
|
120 |
+
codebook_entry = self._vq_model.quantize.get_codebook_entry(
|
121 |
+
img_tensor, (1, 32, 32, emb_dim)
|
122 |
+
)
|
123 |
+
pixels = self._vq_model.decode(codebook_entry)
|
124 |
+
return self._pil_from_chw_tensor(pixels[0])
|
chameleon/vqgan.py
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
|
8 |
+
[with minimal dependencies]
|
9 |
+
|
10 |
+
This implementation is inference-only -- training steps and optimizer components
|
11 |
+
introduce significant additional dependencies
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
class VectorQuantizer2(nn.Module):
|
21 |
+
"""
|
22 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
23 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
24 |
+
"""
|
25 |
+
|
26 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
27 |
+
# backwards compatibility we use the buggy version by default, but you can
|
28 |
+
# specify legacy=False to fix it.
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
n_e,
|
32 |
+
e_dim,
|
33 |
+
beta,
|
34 |
+
remap=None,
|
35 |
+
unknown_index="random",
|
36 |
+
sane_index_shape=False,
|
37 |
+
legacy=True,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.n_e = n_e
|
41 |
+
self.e_dim = e_dim
|
42 |
+
self.beta = beta
|
43 |
+
self.legacy = legacy
|
44 |
+
|
45 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
46 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
47 |
+
|
48 |
+
self.remap = remap
|
49 |
+
if self.remap is not None:
|
50 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
51 |
+
self.re_embed = self.used.shape[0]
|
52 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
53 |
+
if self.unknown_index == "extra":
|
54 |
+
self.unknown_index = self.re_embed
|
55 |
+
self.re_embed = self.re_embed + 1
|
56 |
+
print(
|
57 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
58 |
+
f"Using {self.unknown_index} for unknown indices."
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
self.re_embed = n_e
|
62 |
+
|
63 |
+
self.sane_index_shape = sane_index_shape
|
64 |
+
|
65 |
+
def remap_to_used(self, inds):
|
66 |
+
ishape = inds.shape
|
67 |
+
assert len(ishape) > 1
|
68 |
+
inds = inds.reshape(ishape[0], -1)
|
69 |
+
used = self.used.to(inds)
|
70 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
71 |
+
new = match.argmax(-1)
|
72 |
+
unknown = match.sum(2) < 1
|
73 |
+
if self.unknown_index == "random":
|
74 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
75 |
+
device=new.device
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
new[unknown] = self.unknown_index
|
79 |
+
return new.reshape(ishape)
|
80 |
+
|
81 |
+
def unmap_to_all(self, inds):
|
82 |
+
ishape = inds.shape
|
83 |
+
assert len(ishape) > 1
|
84 |
+
inds = inds.reshape(ishape[0], -1)
|
85 |
+
used = self.used.to(inds)
|
86 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
87 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
88 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
89 |
+
return back.reshape(ishape)
|
90 |
+
|
91 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
92 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
93 |
+
assert rescale_logits is False, "Only for interface compatible with Gumbel"
|
94 |
+
assert return_logits is False, "Only for interface compatible with Gumbel"
|
95 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
96 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
97 |
+
z_flattened = z.view(-1, self.e_dim)
|
98 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
99 |
+
|
100 |
+
d = (
|
101 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
102 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
103 |
+
- 2
|
104 |
+
* torch.einsum(
|
105 |
+
"bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
110 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
111 |
+
perplexity = None
|
112 |
+
min_encodings = None
|
113 |
+
|
114 |
+
# compute loss for embedding
|
115 |
+
if not self.legacy:
|
116 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
117 |
+
(z_q - z.detach()) ** 2
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
121 |
+
(z_q - z.detach()) ** 2
|
122 |
+
)
|
123 |
+
|
124 |
+
# preserve gradients
|
125 |
+
z_q = z + (z_q - z).detach()
|
126 |
+
|
127 |
+
# reshape back to match original input shape
|
128 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
129 |
+
|
130 |
+
if self.remap is not None:
|
131 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
132 |
+
z.shape[0], -1
|
133 |
+
) # add batch axis
|
134 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
135 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
136 |
+
|
137 |
+
if self.sane_index_shape:
|
138 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
139 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
140 |
+
)
|
141 |
+
|
142 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
143 |
+
|
144 |
+
def get_codebook_entry(self, indices, shape):
|
145 |
+
# shape specifying (batch, height, width, channel)
|
146 |
+
if self.remap is not None:
|
147 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
148 |
+
indices = self.unmap_to_all(indices)
|
149 |
+
indices = indices.reshape(-1) # flatten again
|
150 |
+
|
151 |
+
# get quantized latent vectors
|
152 |
+
z_q = self.embedding(indices)
|
153 |
+
|
154 |
+
if shape is not None:
|
155 |
+
z_q = z_q.view(shape)
|
156 |
+
# reshape back to match original input shape
|
157 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
158 |
+
|
159 |
+
return z_q
|
160 |
+
|
161 |
+
|
162 |
+
# Alias
|
163 |
+
VectorQuantizer = VectorQuantizer2
|
164 |
+
|
165 |
+
|
166 |
+
def nonlinearity(x):
|
167 |
+
# swish
|
168 |
+
return x * torch.sigmoid(x)
|
169 |
+
|
170 |
+
|
171 |
+
def Normalize(in_channels, num_groups=32):
|
172 |
+
return torch.nn.GroupNorm(
|
173 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
class Upsample(nn.Module):
|
178 |
+
def __init__(self, in_channels, with_conv):
|
179 |
+
super().__init__()
|
180 |
+
self.with_conv = with_conv
|
181 |
+
if self.with_conv:
|
182 |
+
self.conv = torch.nn.Conv2d(
|
183 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
184 |
+
)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
188 |
+
if self.with_conv:
|
189 |
+
x = self.conv(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
class Downsample(nn.Module):
|
194 |
+
def __init__(self, in_channels, with_conv):
|
195 |
+
super().__init__()
|
196 |
+
self.with_conv = with_conv
|
197 |
+
if self.with_conv:
|
198 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
199 |
+
self.conv = torch.nn.Conv2d(
|
200 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
if self.with_conv:
|
205 |
+
pad = (0, 1, 0, 1)
|
206 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
207 |
+
x = self.conv(x)
|
208 |
+
else:
|
209 |
+
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
class ResnetBlock(nn.Module):
|
214 |
+
def __init__(
|
215 |
+
self,
|
216 |
+
*,
|
217 |
+
in_channels,
|
218 |
+
out_channels=None,
|
219 |
+
conv_shortcut=False,
|
220 |
+
dropout,
|
221 |
+
temb_channels=512,
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.in_channels = in_channels
|
225 |
+
out_channels = in_channels if out_channels is None else out_channels
|
226 |
+
self.out_channels = out_channels
|
227 |
+
self.use_conv_shortcut = conv_shortcut
|
228 |
+
|
229 |
+
self.norm1 = Normalize(in_channels)
|
230 |
+
self.conv1 = torch.nn.Conv2d(
|
231 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
232 |
+
)
|
233 |
+
if temb_channels > 0:
|
234 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
235 |
+
self.norm2 = Normalize(out_channels)
|
236 |
+
self.dropout = torch.nn.Dropout(dropout)
|
237 |
+
self.conv2 = torch.nn.Conv2d(
|
238 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
239 |
+
)
|
240 |
+
if self.in_channels != self.out_channels:
|
241 |
+
if self.use_conv_shortcut:
|
242 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
243 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
247 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
248 |
+
)
|
249 |
+
|
250 |
+
def forward(self, x, temb):
|
251 |
+
h = x
|
252 |
+
h = self.norm1(h)
|
253 |
+
h = nonlinearity(h)
|
254 |
+
h = self.conv1(h)
|
255 |
+
|
256 |
+
if temb is not None:
|
257 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
258 |
+
|
259 |
+
h = self.norm2(h)
|
260 |
+
h = nonlinearity(h)
|
261 |
+
h = self.dropout(h)
|
262 |
+
h = self.conv2(h)
|
263 |
+
|
264 |
+
if self.in_channels != self.out_channels:
|
265 |
+
if self.use_conv_shortcut:
|
266 |
+
x = self.conv_shortcut(x)
|
267 |
+
else:
|
268 |
+
x = self.nin_shortcut(x)
|
269 |
+
|
270 |
+
return x + h
|
271 |
+
|
272 |
+
|
273 |
+
class AttnBlock(nn.Module):
|
274 |
+
def __init__(self, in_channels):
|
275 |
+
super().__init__()
|
276 |
+
self.in_channels = in_channels
|
277 |
+
|
278 |
+
self.norm = Normalize(in_channels)
|
279 |
+
self.q = torch.nn.Conv2d(
|
280 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
281 |
+
)
|
282 |
+
self.k = torch.nn.Conv2d(
|
283 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
284 |
+
)
|
285 |
+
self.v = torch.nn.Conv2d(
|
286 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
287 |
+
)
|
288 |
+
self.proj_out = torch.nn.Conv2d(
|
289 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x):
|
293 |
+
h_ = x
|
294 |
+
h_ = self.norm(h_)
|
295 |
+
q = self.q(h_)
|
296 |
+
k = self.k(h_)
|
297 |
+
v = self.v(h_)
|
298 |
+
|
299 |
+
# compute attention
|
300 |
+
b, c, h, w = q.shape
|
301 |
+
q = q.reshape(b, c, h * w)
|
302 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
303 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
304 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
305 |
+
w_ = w_ * (int(c) ** (-0.5))
|
306 |
+
w_ = F.softmax(w_, dim=2)
|
307 |
+
|
308 |
+
# attend to values
|
309 |
+
v = v.reshape(b, c, h * w)
|
310 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
311 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
312 |
+
h_ = h_.reshape(b, c, h, w)
|
313 |
+
|
314 |
+
h_ = self.proj_out(h_)
|
315 |
+
|
316 |
+
return x + h_
|
317 |
+
|
318 |
+
|
319 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
320 |
+
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
|
321 |
+
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
322 |
+
if attn_type == "vanilla":
|
323 |
+
return AttnBlock(in_channels)
|
324 |
+
elif attn_type == "none":
|
325 |
+
return nn.Identity(in_channels)
|
326 |
+
else:
|
327 |
+
raise ValueError("Unexpected attention type")
|
328 |
+
|
329 |
+
|
330 |
+
class Encoder(nn.Module):
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
*,
|
334 |
+
ch,
|
335 |
+
out_ch,
|
336 |
+
ch_mult=(1, 2, 4, 8),
|
337 |
+
num_res_blocks,
|
338 |
+
attn_resolutions,
|
339 |
+
dropout=0.0,
|
340 |
+
resamp_with_conv=True,
|
341 |
+
in_channels,
|
342 |
+
resolution,
|
343 |
+
z_channels,
|
344 |
+
double_z=True,
|
345 |
+
use_linear_attn=False,
|
346 |
+
attn_type="vanilla",
|
347 |
+
**ignore_kwargs,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
if use_linear_attn:
|
351 |
+
attn_type = "linear"
|
352 |
+
self.ch = ch
|
353 |
+
self.temb_ch = 0
|
354 |
+
self.num_resolutions = len(ch_mult)
|
355 |
+
self.num_res_blocks = num_res_blocks
|
356 |
+
self.resolution = resolution
|
357 |
+
self.in_channels = in_channels
|
358 |
+
|
359 |
+
# downsampling
|
360 |
+
self.conv_in = torch.nn.Conv2d(
|
361 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
362 |
+
)
|
363 |
+
|
364 |
+
curr_res = resolution
|
365 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
366 |
+
self.in_ch_mult = in_ch_mult
|
367 |
+
self.down = nn.ModuleList()
|
368 |
+
for i_level in range(self.num_resolutions):
|
369 |
+
block = nn.ModuleList()
|
370 |
+
attn = nn.ModuleList()
|
371 |
+
block_in = ch * in_ch_mult[i_level]
|
372 |
+
block_out = ch * ch_mult[i_level]
|
373 |
+
for i_block in range(self.num_res_blocks):
|
374 |
+
block.append(
|
375 |
+
ResnetBlock(
|
376 |
+
in_channels=block_in,
|
377 |
+
out_channels=block_out,
|
378 |
+
temb_channels=self.temb_ch,
|
379 |
+
dropout=dropout,
|
380 |
+
)
|
381 |
+
)
|
382 |
+
block_in = block_out
|
383 |
+
if curr_res in attn_resolutions:
|
384 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
385 |
+
down = nn.Module()
|
386 |
+
down.block = block
|
387 |
+
down.attn = attn
|
388 |
+
if i_level != self.num_resolutions - 1:
|
389 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
390 |
+
curr_res = curr_res // 2
|
391 |
+
self.down.append(down)
|
392 |
+
|
393 |
+
# middle
|
394 |
+
self.mid = nn.Module()
|
395 |
+
self.mid.block_1 = ResnetBlock(
|
396 |
+
in_channels=block_in,
|
397 |
+
out_channels=block_in,
|
398 |
+
temb_channels=self.temb_ch,
|
399 |
+
dropout=dropout,
|
400 |
+
)
|
401 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
402 |
+
self.mid.block_2 = ResnetBlock(
|
403 |
+
in_channels=block_in,
|
404 |
+
out_channels=block_in,
|
405 |
+
temb_channels=self.temb_ch,
|
406 |
+
dropout=dropout,
|
407 |
+
)
|
408 |
+
|
409 |
+
# end
|
410 |
+
self.norm_out = Normalize(block_in)
|
411 |
+
self.conv_out = torch.nn.Conv2d(
|
412 |
+
block_in,
|
413 |
+
2 * z_channels if double_z else z_channels,
|
414 |
+
kernel_size=3,
|
415 |
+
stride=1,
|
416 |
+
padding=1,
|
417 |
+
)
|
418 |
+
|
419 |
+
def forward(self, x):
|
420 |
+
# timestep embedding
|
421 |
+
temb = None
|
422 |
+
|
423 |
+
# downsampling
|
424 |
+
hs = [self.conv_in(x)]
|
425 |
+
for i_level in range(self.num_resolutions):
|
426 |
+
for i_block in range(self.num_res_blocks):
|
427 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
428 |
+
if len(self.down[i_level].attn) > 0:
|
429 |
+
h = self.down[i_level].attn[i_block](h)
|
430 |
+
hs.append(h)
|
431 |
+
if i_level != self.num_resolutions - 1:
|
432 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
433 |
+
|
434 |
+
# middle
|
435 |
+
h = hs[-1]
|
436 |
+
h = self.mid.block_1(h, temb)
|
437 |
+
h = self.mid.attn_1(h)
|
438 |
+
h = self.mid.block_2(h, temb)
|
439 |
+
|
440 |
+
# end
|
441 |
+
h = self.norm_out(h)
|
442 |
+
h = nonlinearity(h)
|
443 |
+
h = self.conv_out(h)
|
444 |
+
return h
|
445 |
+
|
446 |
+
|
447 |
+
class Decoder(nn.Module):
|
448 |
+
def __init__(
|
449 |
+
self,
|
450 |
+
*,
|
451 |
+
ch,
|
452 |
+
out_ch,
|
453 |
+
ch_mult=(1, 2, 4, 8),
|
454 |
+
num_res_blocks,
|
455 |
+
attn_resolutions,
|
456 |
+
dropout=0.0,
|
457 |
+
resamp_with_conv=True,
|
458 |
+
in_channels,
|
459 |
+
resolution,
|
460 |
+
z_channels,
|
461 |
+
give_pre_end=False,
|
462 |
+
tanh_out=False,
|
463 |
+
use_linear_attn=False,
|
464 |
+
attn_type="vanilla",
|
465 |
+
**ignorekwargs,
|
466 |
+
):
|
467 |
+
super().__init__()
|
468 |
+
if use_linear_attn:
|
469 |
+
attn_type = "linear"
|
470 |
+
self.ch = ch
|
471 |
+
self.temb_ch = 0
|
472 |
+
self.num_resolutions = len(ch_mult)
|
473 |
+
self.num_res_blocks = num_res_blocks
|
474 |
+
self.resolution = resolution
|
475 |
+
self.in_channels = in_channels
|
476 |
+
self.give_pre_end = give_pre_end
|
477 |
+
self.tanh_out = tanh_out
|
478 |
+
|
479 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
480 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
481 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
482 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
483 |
+
|
484 |
+
# z to block_in
|
485 |
+
self.conv_in = torch.nn.Conv2d(
|
486 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
487 |
+
)
|
488 |
+
|
489 |
+
# middle
|
490 |
+
self.mid = nn.Module()
|
491 |
+
self.mid.block_1 = ResnetBlock(
|
492 |
+
in_channels=block_in,
|
493 |
+
out_channels=block_in,
|
494 |
+
temb_channels=self.temb_ch,
|
495 |
+
dropout=dropout,
|
496 |
+
)
|
497 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
498 |
+
self.mid.block_2 = ResnetBlock(
|
499 |
+
in_channels=block_in,
|
500 |
+
out_channels=block_in,
|
501 |
+
temb_channels=self.temb_ch,
|
502 |
+
dropout=dropout,
|
503 |
+
)
|
504 |
+
|
505 |
+
# upsampling
|
506 |
+
self.up = nn.ModuleList()
|
507 |
+
for i_level in reversed(range(self.num_resolutions)):
|
508 |
+
block = nn.ModuleList()
|
509 |
+
attn = nn.ModuleList()
|
510 |
+
block_out = ch * ch_mult[i_level]
|
511 |
+
for i_block in range(self.num_res_blocks + 1):
|
512 |
+
block.append(
|
513 |
+
ResnetBlock(
|
514 |
+
in_channels=block_in,
|
515 |
+
out_channels=block_out,
|
516 |
+
temb_channels=self.temb_ch,
|
517 |
+
dropout=dropout,
|
518 |
+
)
|
519 |
+
)
|
520 |
+
block_in = block_out
|
521 |
+
if curr_res in attn_resolutions:
|
522 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
523 |
+
up = nn.Module()
|
524 |
+
up.block = block
|
525 |
+
up.attn = attn
|
526 |
+
if i_level != 0:
|
527 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
528 |
+
curr_res = curr_res * 2
|
529 |
+
self.up.insert(0, up) # prepend to get consistent order
|
530 |
+
|
531 |
+
# end
|
532 |
+
self.norm_out = Normalize(block_in)
|
533 |
+
self.conv_out = torch.nn.Conv2d(
|
534 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
535 |
+
)
|
536 |
+
|
537 |
+
def forward(self, z):
|
538 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
539 |
+
self.last_z_shape = z.shape
|
540 |
+
|
541 |
+
# timestep embedding
|
542 |
+
temb = None
|
543 |
+
|
544 |
+
# z to block_in
|
545 |
+
h = self.conv_in(z)
|
546 |
+
|
547 |
+
# middle
|
548 |
+
h = self.mid.block_1(h, temb)
|
549 |
+
h = self.mid.attn_1(h)
|
550 |
+
h = self.mid.block_2(h, temb)
|
551 |
+
|
552 |
+
# upsampling
|
553 |
+
for i_level in reversed(range(self.num_resolutions)):
|
554 |
+
for i_block in range(self.num_res_blocks + 1):
|
555 |
+
h = self.up[i_level].block[i_block](h, temb)
|
556 |
+
if len(self.up[i_level].attn) > 0:
|
557 |
+
h = self.up[i_level].attn[i_block](h)
|
558 |
+
if i_level != 0:
|
559 |
+
h = self.up[i_level].upsample(h)
|
560 |
+
|
561 |
+
# end
|
562 |
+
if self.give_pre_end:
|
563 |
+
return h
|
564 |
+
|
565 |
+
h = self.norm_out(h)
|
566 |
+
h = nonlinearity(h)
|
567 |
+
h = self.conv_out(h)
|
568 |
+
if self.tanh_out:
|
569 |
+
h = torch.tanh(h)
|
570 |
+
return h
|
571 |
+
|
572 |
+
|
573 |
+
class VQModel(nn.Module):
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
ddconfig,
|
577 |
+
n_embed,
|
578 |
+
embed_dim,
|
579 |
+
ckpt_path=None,
|
580 |
+
ignore_keys=[],
|
581 |
+
image_key="image",
|
582 |
+
colorize_nlabels=None,
|
583 |
+
monitor=None,
|
584 |
+
scheduler_config=None,
|
585 |
+
lr_g_factor=1.0,
|
586 |
+
remap=None,
|
587 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
self.image_key = image_key
|
591 |
+
self.encoder = Encoder(**ddconfig)
|
592 |
+
self.decoder = Decoder(**ddconfig)
|
593 |
+
self.quantize = VectorQuantizer(
|
594 |
+
n_embed,
|
595 |
+
embed_dim,
|
596 |
+
beta=0.25,
|
597 |
+
remap=remap,
|
598 |
+
sane_index_shape=sane_index_shape,
|
599 |
+
)
|
600 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
601 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
602 |
+
if ckpt_path is not None:
|
603 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
604 |
+
self.image_key = image_key
|
605 |
+
if colorize_nlabels is not None:
|
606 |
+
assert isinstance(colorize_nlabels, int)
|
607 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
608 |
+
if monitor is not None:
|
609 |
+
self.monitor = monitor
|
610 |
+
self.scheduler_config = scheduler_config
|
611 |
+
self.lr_g_factor = lr_g_factor
|
612 |
+
|
613 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
614 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
615 |
+
keys = list(sd.keys())
|
616 |
+
for k in keys:
|
617 |
+
for ik in ignore_keys:
|
618 |
+
if k.startswith(ik):
|
619 |
+
print("Deleting key {} from state_dict.".format(k))
|
620 |
+
del sd[k]
|
621 |
+
self.load_state_dict(sd, strict=False)
|
622 |
+
print(f"VQModel loaded from {path}")
|
623 |
+
|
624 |
+
def encode(self, x):
|
625 |
+
h = self.encoder(x)
|
626 |
+
h = self.quant_conv(h)
|
627 |
+
quant, emb_loss, info = self.quantize(h)
|
628 |
+
return quant, emb_loss, info
|
629 |
+
|
630 |
+
def decode(self, quant):
|
631 |
+
quant = self.post_quant_conv(quant)
|
632 |
+
dec = self.decoder(quant)
|
633 |
+
return dec
|
634 |
+
|
635 |
+
def decode_code(self, code_b):
|
636 |
+
quant_b = self.quantize.embed_code(code_b)
|
637 |
+
dec = self.decode(quant_b)
|
638 |
+
return dec
|
639 |
+
|
640 |
+
def forward(self, input):
|
641 |
+
quant, diff, _ = self.encode(input)
|
642 |
+
dec = self.decode(quant)
|
643 |
+
return dec, diff
|
644 |
+
|
645 |
+
def get_input(self, batch, k):
|
646 |
+
x = batch[k]
|
647 |
+
if len(x.shape) == 3:
|
648 |
+
x = x[..., None]
|
649 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
650 |
+
return x.float()
|
651 |
+
|
652 |
+
def get_last_layer(self):
|
653 |
+
return self.decoder.conv_out.weight
|
654 |
+
|
655 |
+
def log_images(self, batch, **kwargs):
|
656 |
+
log = dict()
|
657 |
+
x = self.get_input(batch, self.image_key)
|
658 |
+
x = x.to(self.device)
|
659 |
+
xrec, _ = self(x)
|
660 |
+
if x.shape[1] > 3:
|
661 |
+
# colorize with random projection
|
662 |
+
assert xrec.shape[1] > 3
|
663 |
+
x = self.to_rgb(x)
|
664 |
+
xrec = self.to_rgb(xrec)
|
665 |
+
log["inputs"] = x
|
666 |
+
log["reconstructions"] = xrec
|
667 |
+
return log
|
668 |
+
|
669 |
+
def to_rgb(self, x):
|
670 |
+
assert self.image_key == "segmentation"
|
671 |
+
if not hasattr(self, "colorize"):
|
672 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
673 |
+
x = F.conv2d(x, weight=self.colorize)
|
674 |
+
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
675 |
+
return x
|