darknoon committed on
Commit
77819e0
1 Parent(s): 674d65b

Add model from Chameleon

Files changed (4)
  1. app.py +100 -37
  2. chameleon/LICENSE +51 -0
  3. chameleon/image_tokenizer.py +124 -0
  4. chameleon/vqgan.py +675 -0
app.py CHANGED
@@ -1,18 +1,27 @@
-from typing import List, Literal
+from typing import Literal
 import gradio as gr
 import torch
 import numpy as np
 import colorsys
+import yaml

+from huggingface_hub import hf_hub_download
 from diffusers import VQModel
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
-from abc import abstractmethod
+
+from chameleon.image_tokenizer import ImageTokenizer
+
 import torch.backends
 import torch.mps
 from PIL import Image
+
 import spaces

+
+Model = Literal["vqgan", "paella", "chameleon"]
+models = ["vqgan", "paella", "chameleon"]
+
 if torch.cuda.is_available():
     device = torch.device("cuda")
 elif torch.backends.mps.is_available():
@@ -21,9 +30,7 @@ else:
     device = torch.device("cpu")


-# abstract class VQImageRoundtripPipeline:
 class ImageRoundtripPipeline:
-    @abstractmethod
     def roundtrip_image(self, image, output_type="pil"): ...


@@ -63,6 +70,12 @@ class VQImageRoundtripPipeline(ImageRoundtripPipeline):
         latents = self.vqvae.quantize(latents)[2][2].reshape(
             batch_size, latents_height, latents_width
         )
+        # replace 20% of latents with random values
+        # random_latents = torch.randint(
+        #     0, self.vqvae.config.num_vq_embeddings, latents.shape, device=device
+        # )
+        # random_mask = torch.rand(latents.shape, device=device) < 0.2
+        # latents = torch.where(random_mask, random_latents, latents)
         output = self.vqvae.decode(
             latents,
             force_not_quantize=True,
@@ -81,6 +94,55 @@ class VQImageRoundtripPipeline(ImageRoundtripPipeline):
         return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings


+class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
+    tokenizer: ImageTokenizer
+    n_embed: int
+    vae_scale_factor: int
+
+    def __init__(self):
+        vqgan_path = hf_hub_download(
+            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.ckpt"
+        )
+        vqgan_config_path = hf_hub_download(
+            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.yaml"
+        )
+        self.tokenizer = ImageTokenizer(
+            cfg_path=vqgan_config_path, ckpt_path=vqgan_path, device=device
+        )
+        with open(vqgan_config_path) as f:
+            vq_config = yaml.safe_load(f)
+
+        self.n_embed = vq_config["model"]["params"]["n_embed"]
+        self.vae_scale_factor = 16
+        print("Chameleon VQGan model loaded", self.tokenizer._vq_model, self.n_embed)
+
+    def preprocess(self, image: Image):
+        # copied from _vqgan_input_from
+        np_img = np.array(image) / 255.0  # Normalize to [0, 1]
+        np_img = np_img * 2 - 1  # Scale to [-1, 1]
+        tensor_img = (
+            torch.from_numpy(np_img).permute(2, 0, 1).float()
+        )  # (Channels, Height, Width) format.
+
+        # Add batch dimension.
+        return tensor_img.unsqueeze(0)
+
+    def roundtrip_image(self, image, output_type="pil"):
+        # image = self.tokenizer._vqgan_input_from(image).to(device)
+        image = self.preprocess(image).to(device)
+        _, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
+        # emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
+        output = self.tokenizer.pil_from_img_toks(latents)
+        # we actually do want this to be a grid, sorry!
+        latents = latents.reshape(1, 32, 32)
+
+        return (
+            output,
+            latents.cpu().numpy(),
+            self.n_embed,
+        )
+
+
 class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
     vqgan: PaellaVQModel
     vae_scale_factor: int
@@ -127,6 +189,7 @@ class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):

 pipeline_paella = PaellaImageRoundtripPipeline()
 pipeline_vq = VQImageRoundtripPipeline()
+pipeline_vq_chameleon = ChameleonVQImageRoundtripPipeline()


 # Function to generate a list of unique colors
@@ -171,26 +234,27 @@ def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
     return img


-# This is a gradio space that lets you encode an image with various encoder-decoder pairs, eg VQ-GAN, SDXL's VAE, etc and check the image quality
-
-
-# def image_grid_to_string(image_grid):
-#     """Convert a latent vq index "image" grid to a string, input shape is (1, height, width)"""
-#     return "\n".join(
-#         [" ".join([str(int(x)) for x in row]) for row in image_grid.squeeze()]
-#     )
-
-
 def describe_shape(shape):
     return f"Shape: {shape} num elements: {np.prod(shape)}"


+def calc_psnr(img1: Image, img2: Image):
+    if img1.size != img2.size:
+        raise ValueError("Images must have the same dimensions")
+    img1 = np.array(img1)
+    img2 = np.array(img2)
+    mse = np.mean((img1 - img2) ** 2)
+    if mse == 0:
+        return float("inf")
+    return 2 * 10 * np.log10(255.0 / np.sqrt(mse))
+
+
 @spaces.GPU(duration=32)
 @torch.no_grad()
 def roundtrip_image(
     image,
-    model: List[Literal["vqgan", Literal["paella"]]],
-    size: List[Literal["256x256", "512x512", "1024x1024"]],
+    model: Model,
+    size: Literal["256x256", "512x512", "1024x1024"],
     output_type="pil",
 ):
     if size == "256x256":
@@ -202,41 +266,40 @@ def roundtrip_image(
     else:
         raise ValueError(f"Unknown size {size}")

+    image_orig = image
     if model == "vqgan":
-        image, latents, codebook_size = pipeline_vq.roundtrip_image(image, output_type)
-        return (
-            image,
-            vqgan_tokens_to_image(
-                latents, codebook_size, downscale_factor=pipeline_vq.vae_scale_factor
-            ),
-            describe_shape(latents.shape),
-        )
+        pipeline = pipeline_vq
     elif model == "paella":
-        image, latents, codebook_size = pipeline_paella.roundtrip_image(
-            image, output_type
-        )
-        return (
-            image,
-            vqgan_tokens_to_image(
-                latents, codebook_size, downscale_factor=pipeline_vq.vae_scale_factor
-            ),
-            describe_shape(latents.shape),
-        )
+        pipeline = pipeline_paella
+    elif model == "chameleon":
+        pipeline = pipeline_vq_chameleon
     else:
         raise ValueError(f"Unknown model {model}")

+    image, latents, codebook_size = pipeline.roundtrip_image(image, output_type)
+
+    return (
+        image,
+        vqgan_tokens_to_image(
+            latents, codebook_size, downscale_factor=pipeline.vae_scale_factor
+        ),
+        describe_shape(latents.shape),
+        f"{calc_psnr(image_orig, image):.2f}",
+    )
+

 demo = gr.Interface(
     fn=roundtrip_image,
     inputs=[
         gr.Image(type="pil"),
-        gr.Dropdown(["vqgan", "paella"], label="Model", value="vqgan"),
+        gr.Dropdown(models, label="Model", value="vqgan"),
         gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
     ],
     outputs=[
-        gr.Image(label="Reconstructed"),
-        gr.Image(label="Tokens"),
+        gr.Image(label="Reconstructed", format="png"),
+        gr.Image(label="Tokens", format="png"),
         gr.Text(label="VQ Shape"),
+        gr.Text(label="PSNR"),
     ],
     title="Image Tokenizer Playground",
     description="Round-trip an image through an encode-decoder pair to see the quality loss from the VQ-GAN for image generation, etc.",
chameleon/LICENSE ADDED
@@ -0,0 +1,51 @@
1
+ Chameleon Research License
2
+ Chameleon Version Release Date: June 18, 2024
3
+
4
+ This Chameleon Research License ("Agreement") contains the terms and conditions that govern your access and use of the Chameleon Materials (as defined below). You may not use the Chameleon Materials if you do not accept this Agreement. By clicking "I Accept" to accept, or accessing, using, or distributing any portion or element of the Chameleon Materials you hereby agree to be bound by the terms of this Agreement. If you are agreeing to be bound by the Agreement on behalf of your employer or other entity, you represent and warrant to Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) and Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland) ("Meta") that you have full legal authority to bind your employer or such entity to this Agreement. If you do not have requisite authority, you may not accept the Agreement or access the Chameleon Materials on behalf of your employer or other entity.
5
+
6
+ This Agreement is effective upon the earlier of the date that you first access the Chameleon Materials or accept this Agreement ("Effective Date"), and is entered into by and between Meta, and you, or if you are entering into this Agreement on behalf of your employer or other entity (if you are entering into this Agreement on such person or entity's behalf), of the age required under applicable laws, rules, or regulations to provide legal consent and, your employer or other entity and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf ("Licensee" or "You").
7
+
8
+ 1. Definitions.
9
+ 1. "Documentation" means the specifications, manuals and documentation accompanying Chameleon distributed by Meta at https://github.com/facebookresearch/chameleon and https://ai.meta.com/resources/models-and-libraries/chameleon-downloads/.
10
+
11
+
12
+ 2. "Noncommercial Research Uses" means noncommercial research use cases related to research, development, education, processing, or analysis and in each case, is not primarily intended for commercial advantage or monetary compensation to you or others.
13
+
14
+
15
+ 3. "Chameleon" means the models and software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, demonstration materials and other elements of the foregoing distributed by Meta at [INSERT RESOURCE HYPERLINK].
16
+
17
+
18
+ 4. "Chameleon Materials" means, collectively, Meta's proprietary Chameleon and Documentation (and any portion thereof) made available under this Agreement.
19
+
20
+
21
+ 5. "Trade Control Laws" means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
22
+
23
+
24
+ 6. "Acceptable Use Policy" means the Acceptable Use Policy applicable to Chameleon Materials ([INSERT Chameleon AUP HYPERLINK]) that is incorporated into this Agreement.
25
+
26
+
27
+ 2. License Rights and Redistribution. Subject to Your compliance with the terms and conditions of this Agreement, Meta hereby grants you the following:
28
+ 1. Grant of Rights. You are hereby granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta's intellectual property or other rights owned by Meta embodied in the Chameleon Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Chameleon Materials solely for Noncommercial Research Uses.
29
+ 2. Redistribution and Use.
30
+ 1. Distribution of Chameleon Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the Chameleon Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement. You shall also provide a copy of this Agreement to such third party.
31
+ 2. If you submit for publication the results of research you perform on, using, or otherwise in connection with Chameleon Materials, you must acknowledge the use of Chameleon Materials in your publication as follows (or an equivalent acknowledgement of your choosing): "This material is based on work supported by the Chameleon Research License, Copyright (c) Meta Platforms, Inc. All Rights Reserved."
32
+
33
+ 3. You must retain in all copies of the Chameleon Materials that you distribute and include the following attribution notice within a "Notice" text file distributed as a part of such copies: "Chameleon is licensed under the Chameleon Research License, Copyright (c) Meta Platforms, Inc. All Rights Reserved."
34
+ 4. Your use of the Chameleon Materials must comply with applicable laws and regulations (including Trade Control Laws) and adhere to the Acceptable Use Policy for the Chameleon Materials (https://ai.meta.com/resources/models-and-libraries/chameleon-use-policy/) which is hereby incorporated by reference into this Agreement.
35
+ 3. Restrictions. You will not, and will not permit, assist or cause any third party to:
36
+ 1. use the Chameleon Materials or any outputs or results of the Chameleon Materials in connection with any commercial uses or for any uses other than Noncommercial Research Uses;
37
+ 2. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Meta in connection with the Chameleon Materials, or to circumvent or remove any usage restrictions or other safety measures, or to enable functionality disabled by Meta;
38
+ 3. disguise your or their location through IP proxying or other methods;
39
+ 4. use or download Chameleon Materials if you or they are: (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) will use Chameleon Materials for any purpose prohibited by Trade Control Laws; or
40
+ 5. directly or indirectly export, re-export, provide, or otherwise transfer Chameleon Materials: (a) to any individual, entity, or country prohibited by Trade Control Laws; (b) to anyone on U.S. or non-U.S. government restricted parties lists; or (c) for any purpose prohibited by Trade Control Laws, including nuclear, chemical or biological weapons, or missile technology applications.
41
+ 4. User Support. Your Noncommercial Research Use of the Chameleon Materials is done at your own discretion; Meta does not provide any service in relation to such use. Meta is under no obligation to provide any support services for the Chameleon Materials. Any support provided is "as is", "with all faults", and without warranty of any kind.
42
+ 5. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE Chameleon MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE Chameleon MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE Chameleon MATERIALS AND ANY OUTPUT AND RESULTS.
43
+ 6. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44
+ 7. Intellectual Property.
45
+ 1. No trademark licenses are granted under this Agreement, and in connection with the Chameleon Materials, neither Meta nor Licensee may use any name or mark owned by or associated with the other or any of its affiliates, except as required for reasonable and customary use in describing and redistributing the Chameleon Materials.
46
+ 2. Subject to Meta's ownership of Chameleon Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the Chameleon Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications.
47
+ 3. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Chameleon Materials or Chameleon outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses and rights granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the Chameleon Materials.
48
+ 8. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the Chameleon Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the Chameleon Materials. Sections 3, 4, 5, 6(c), 7, 8 and 9 shall survive the termination of this Agreement.
49
+ 9. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement.
50
+ 10. Modifications and Amendments. Meta may modify this Agreement from time to time by posting a revised version at https://ai.meta.com/resources/models-and-libraries/chameleon-license/
51
+ 11. ; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the Chameleon Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no other modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta.
chameleon/image_tokenizer.py ADDED
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import PIL
+import torch
+import yaml
+from PIL import Image
+
+from .vqgan import VQModel
+
+
+class ImageTokenizer:
+    def __init__(
+        self,
+        cfg_path: str,
+        ckpt_path: str,
+        device: str | torch.device | None = None,
+    ):
+        with open(cfg_path) as f:
+            config = yaml.safe_load(f)
+
+        params = config["model"]["params"]
+        if "lossconfig" in params:
+            del params["lossconfig"]
+        params["ckpt_path"] = ckpt_path
+
+        self._vq_model = VQModel(**params)
+        self._vq_model.eval()
+
+        if device is None:
+            devices = {p.device for p in self._vq_model.parameters()}
+            assert len(devices) == 1
+            device = devices.pop()
+        else:
+            self._vq_model.to(device)
+        self._device = device
+
+        dtypes = {p.dtype for p in self._vq_model.parameters()}
+        assert len(dtypes) == 1
+        self._dtype = dtypes.pop()
+
+    def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
+        # Check if it's already in RGB format.
+        if img.mode == "RGB":
+            return img
+
+        vals_rgba = np.array(img.convert("RGBA"))
+
+        # If there is no transparency layer, simple convert and return.
+        if not (vals_rgba[:, :, 3] < 255).any():
+            return img.convert("RGB")
+
+        # There is a transparency layer, blend it with a white background.
+
+        # Calculate the alpha proportion for blending.
+        alpha = vals_rgba[:, :, 3] / 255.0
+        # Blend with white background.
+        vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
+            :, :, np.newaxis
+        ] * vals_rgba[:, :, :3]
+        return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
+
+    def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
+        # Resize with aspect ratio preservation.
+        s = min(img.size)
+        scale = target_image_size / s
+        new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
+        img = img.resize(new_size, PIL.Image.LANCZOS)
+
+        # Center crop.
+        x0 = (img.width - target_image_size) // 2
+        y0 = (img.height - target_image_size) // 2
+        img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
+
+        # Convert to tensor.
+        np_img = np.array(img) / 255.0  # Normalize to [0, 1]
+        np_img = np_img * 2 - 1  # Scale to [-1, 1]
+        tensor_img = (
+            torch.from_numpy(np_img).permute(2, 0, 1).float()
+        )  # (Channels, Height, Width) format.
+
+        # Add batch dimension.
+        return tensor_img.unsqueeze(0)
+
+    def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
+        image = self._whiten_transparency(image)
+        vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
+        _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
+        return img_toks
+
+    def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
+        # Ensure detachment and move tensor to CPU.
+        detached_chw_tensor = chw_tensor.detach().cpu()
+
+        # Normalize tensor to [0, 1] range from [-1, 1] range.
+        normalized_chw_tensor = (
+            torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
+        ) / 2.0
+
+        # Permute CHW tensor to HWC format and convert to NumPy array.
+        hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
+
+        # Convert to an 8-bit unsigned integer format.
+        image_array_uint8 = (hwc_array * 255).astype(np.uint8)
+
+        # Convert NumPy array to PIL Image.
+        pil_image = Image.fromarray(image_array_uint8)
+
+        # Convert image to RGB if it is not already.
+        if pil_image.mode != "RGB":
+            pil_image = pil_image.convert("RGB")
+
+        return pil_image
+
+    def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image:
+        emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
+        codebook_entry = self._vq_model.quantize.get_codebook_entry(
+            img_tensor, (1, 32, 32, emb_dim)
+        )
+        pixels = self._vq_model.decode(codebook_entry)
+        return self._pil_from_chw_tensor(pixels[0])
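A note on the unpacking used in img_tokens_from_pil above: `_, _, [_, _, img_toks]` relies on VQModel.encode (defined in chameleon/vqgan.py below) returning (quant, emb_loss, info), where info is (perplexity, min_encodings, min_encoding_indices) and the first two entries are None at inference time. With a 512x512 input and the 16x downsampling encoder, the indices always form a 32x32 grid, which is why pil_from_img_toks hard-codes the (1, 32, 32, emb_dim) shape. A more explicit, hypothetical helper spelling out the same unpacking (names ours, not part of the committed file):

# Illustrative only -- equivalent to the destructuring in img_tokens_from_pil.
import torch
from chameleon.image_tokenizer import ImageTokenizer

def encode_to_indices(tokenizer: ImageTokenizer, vqgan_input: torch.Tensor) -> torch.Tensor:
    # vqgan_input: (1, 3, 512, 512) tensor scaled to [-1, 1], as produced by _vqgan_input_from
    quant, emb_loss, info = tokenizer._vq_model.encode(vqgan_input)
    perplexity, min_encodings, indices = info  # both stats are None at inference time
    return indices  # flat tensor of 32 * 32 = 1024 codebook ids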
chameleon/vqgan.py ADDED
@@ -0,0 +1,675 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
8
+ [with minimal dependencies]
9
+
10
+ This implementation is inference-only -- training steps and optimizer components
11
+ introduce significant additional dependencies
12
+ """
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class VectorQuantizer2(nn.Module):
21
+ """
22
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
23
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
24
+ """
25
+
26
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
27
+ # backwards compatibility we use the buggy version by default, but you can
28
+ # specify legacy=False to fix it.
29
+ def __init__(
30
+ self,
31
+ n_e,
32
+ e_dim,
33
+ beta,
34
+ remap=None,
35
+ unknown_index="random",
36
+ sane_index_shape=False,
37
+ legacy=True,
38
+ ):
39
+ super().__init__()
40
+ self.n_e = n_e
41
+ self.e_dim = e_dim
42
+ self.beta = beta
43
+ self.legacy = legacy
44
+
45
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
46
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
47
+
48
+ self.remap = remap
49
+ if self.remap is not None:
50
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
51
+ self.re_embed = self.used.shape[0]
52
+ self.unknown_index = unknown_index # "random" or "extra" or integer
53
+ if self.unknown_index == "extra":
54
+ self.unknown_index = self.re_embed
55
+ self.re_embed = self.re_embed + 1
56
+ print(
57
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
58
+ f"Using {self.unknown_index} for unknown indices."
59
+ )
60
+ else:
61
+ self.re_embed = n_e
62
+
63
+ self.sane_index_shape = sane_index_shape
64
+
65
+ def remap_to_used(self, inds):
66
+ ishape = inds.shape
67
+ assert len(ishape) > 1
68
+ inds = inds.reshape(ishape[0], -1)
69
+ used = self.used.to(inds)
70
+ match = (inds[:, :, None] == used[None, None, ...]).long()
71
+ new = match.argmax(-1)
72
+ unknown = match.sum(2) < 1
73
+ if self.unknown_index == "random":
74
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
75
+ device=new.device
76
+ )
77
+ else:
78
+ new[unknown] = self.unknown_index
79
+ return new.reshape(ishape)
80
+
81
+ def unmap_to_all(self, inds):
82
+ ishape = inds.shape
83
+ assert len(ishape) > 1
84
+ inds = inds.reshape(ishape[0], -1)
85
+ used = self.used.to(inds)
86
+ if self.re_embed > self.used.shape[0]: # extra token
87
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
88
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
89
+ return back.reshape(ishape)
90
+
91
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
92
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
93
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
94
+ assert return_logits is False, "Only for interface compatible with Gumbel"
95
+ # reshape z -> (batch, height, width, channel) and flatten
96
+ z = z.permute(0, 2, 3, 1).contiguous()
97
+ z_flattened = z.view(-1, self.e_dim)
98
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
99
+
100
+ d = (
101
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
102
+ + torch.sum(self.embedding.weight**2, dim=1)
103
+ - 2
104
+ * torch.einsum(
105
+ "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
106
+ )
107
+ )
108
+
109
+ min_encoding_indices = torch.argmin(d, dim=1)
110
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
111
+ perplexity = None
112
+ min_encodings = None
113
+
114
+ # compute loss for embedding
115
+ if not self.legacy:
116
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
117
+ (z_q - z.detach()) ** 2
118
+ )
119
+ else:
120
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
121
+ (z_q - z.detach()) ** 2
122
+ )
123
+
124
+ # preserve gradients
125
+ z_q = z + (z_q - z).detach()
126
+
127
+ # reshape back to match original input shape
128
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
129
+
130
+ if self.remap is not None:
131
+ min_encoding_indices = min_encoding_indices.reshape(
132
+ z.shape[0], -1
133
+ ) # add batch axis
134
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
135
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
136
+
137
+ if self.sane_index_shape:
138
+ min_encoding_indices = min_encoding_indices.reshape(
139
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
140
+ )
141
+
142
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
143
+
144
+ def get_codebook_entry(self, indices, shape):
145
+ # shape specifying (batch, height, width, channel)
146
+ if self.remap is not None:
147
+ indices = indices.reshape(shape[0], -1) # add batch axis
148
+ indices = self.unmap_to_all(indices)
149
+ indices = indices.reshape(-1) # flatten again
150
+
151
+ # get quantized latent vectors
152
+ z_q = self.embedding(indices)
153
+
154
+ if shape is not None:
155
+ z_q = z_q.view(shape)
156
+ # reshape back to match original input shape
157
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
158
+
159
+ return z_q
160
+
161
+
162
+ # Alias
163
+ VectorQuantizer = VectorQuantizer2
164
+
165
+
166
+ def nonlinearity(x):
167
+ # swish
168
+ return x * torch.sigmoid(x)
169
+
170
+
171
+ def Normalize(in_channels, num_groups=32):
172
+ return torch.nn.GroupNorm(
173
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
174
+ )
175
+
176
+
177
+ class Upsample(nn.Module):
178
+ def __init__(self, in_channels, with_conv):
179
+ super().__init__()
180
+ self.with_conv = with_conv
181
+ if self.with_conv:
182
+ self.conv = torch.nn.Conv2d(
183
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
184
+ )
185
+
186
+ def forward(self, x):
187
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
188
+ if self.with_conv:
189
+ x = self.conv(x)
190
+ return x
191
+
192
+
193
+ class Downsample(nn.Module):
194
+ def __init__(self, in_channels, with_conv):
195
+ super().__init__()
196
+ self.with_conv = with_conv
197
+ if self.with_conv:
198
+ # no asymmetric padding in torch conv, must do it ourselves
199
+ self.conv = torch.nn.Conv2d(
200
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.with_conv:
205
+ pad = (0, 1, 0, 1)
206
+ x = F.pad(x, pad, mode="constant", value=0)
207
+ x = self.conv(x)
208
+ else:
209
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
210
+ return x
211
+
212
+
213
+ class ResnetBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ *,
217
+ in_channels,
218
+ out_channels=None,
219
+ conv_shortcut=False,
220
+ dropout,
221
+ temb_channels=512,
222
+ ):
223
+ super().__init__()
224
+ self.in_channels = in_channels
225
+ out_channels = in_channels if out_channels is None else out_channels
226
+ self.out_channels = out_channels
227
+ self.use_conv_shortcut = conv_shortcut
228
+
229
+ self.norm1 = Normalize(in_channels)
230
+ self.conv1 = torch.nn.Conv2d(
231
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
232
+ )
233
+ if temb_channels > 0:
234
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
235
+ self.norm2 = Normalize(out_channels)
236
+ self.dropout = torch.nn.Dropout(dropout)
237
+ self.conv2 = torch.nn.Conv2d(
238
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
239
+ )
240
+ if self.in_channels != self.out_channels:
241
+ if self.use_conv_shortcut:
242
+ self.conv_shortcut = torch.nn.Conv2d(
243
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
244
+ )
245
+ else:
246
+ self.nin_shortcut = torch.nn.Conv2d(
247
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
248
+ )
249
+
250
+ def forward(self, x, temb):
251
+ h = x
252
+ h = self.norm1(h)
253
+ h = nonlinearity(h)
254
+ h = self.conv1(h)
255
+
256
+ if temb is not None:
257
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
258
+
259
+ h = self.norm2(h)
260
+ h = nonlinearity(h)
261
+ h = self.dropout(h)
262
+ h = self.conv2(h)
263
+
264
+ if self.in_channels != self.out_channels:
265
+ if self.use_conv_shortcut:
266
+ x = self.conv_shortcut(x)
267
+ else:
268
+ x = self.nin_shortcut(x)
269
+
270
+ return x + h
271
+
272
+
273
+ class AttnBlock(nn.Module):
274
+ def __init__(self, in_channels):
275
+ super().__init__()
276
+ self.in_channels = in_channels
277
+
278
+ self.norm = Normalize(in_channels)
279
+ self.q = torch.nn.Conv2d(
280
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
281
+ )
282
+ self.k = torch.nn.Conv2d(
283
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
284
+ )
285
+ self.v = torch.nn.Conv2d(
286
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
287
+ )
288
+ self.proj_out = torch.nn.Conv2d(
289
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
290
+ )
291
+
292
+ def forward(self, x):
293
+ h_ = x
294
+ h_ = self.norm(h_)
295
+ q = self.q(h_)
296
+ k = self.k(h_)
297
+ v = self.v(h_)
298
+
299
+ # compute attention
300
+ b, c, h, w = q.shape
301
+ q = q.reshape(b, c, h * w)
302
+ q = q.permute(0, 2, 1) # b,hw,c
303
+ k = k.reshape(b, c, h * w) # b,c,hw
304
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
305
+ w_ = w_ * (int(c) ** (-0.5))
306
+ w_ = F.softmax(w_, dim=2)
307
+
308
+ # attend to values
309
+ v = v.reshape(b, c, h * w)
310
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
311
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
312
+ h_ = h_.reshape(b, c, h, w)
313
+
314
+ h_ = self.proj_out(h_)
315
+
316
+ return x + h_
317
+
318
+
319
+ def make_attn(in_channels, attn_type="vanilla"):
320
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
321
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
322
+ if attn_type == "vanilla":
323
+ return AttnBlock(in_channels)
324
+ elif attn_type == "none":
325
+ return nn.Identity(in_channels)
326
+ else:
327
+ raise ValueError("Unexpected attention type")
328
+
329
+
330
+ class Encoder(nn.Module):
331
+ def __init__(
332
+ self,
333
+ *,
334
+ ch,
335
+ out_ch,
336
+ ch_mult=(1, 2, 4, 8),
337
+ num_res_blocks,
338
+ attn_resolutions,
339
+ dropout=0.0,
340
+ resamp_with_conv=True,
341
+ in_channels,
342
+ resolution,
343
+ z_channels,
344
+ double_z=True,
345
+ use_linear_attn=False,
346
+ attn_type="vanilla",
347
+ **ignore_kwargs,
348
+ ):
349
+ super().__init__()
350
+ if use_linear_attn:
351
+ attn_type = "linear"
352
+ self.ch = ch
353
+ self.temb_ch = 0
354
+ self.num_resolutions = len(ch_mult)
355
+ self.num_res_blocks = num_res_blocks
356
+ self.resolution = resolution
357
+ self.in_channels = in_channels
358
+
359
+ # downsampling
360
+ self.conv_in = torch.nn.Conv2d(
361
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
362
+ )
363
+
364
+ curr_res = resolution
365
+ in_ch_mult = (1,) + tuple(ch_mult)
366
+ self.in_ch_mult = in_ch_mult
367
+ self.down = nn.ModuleList()
368
+ for i_level in range(self.num_resolutions):
369
+ block = nn.ModuleList()
370
+ attn = nn.ModuleList()
371
+ block_in = ch * in_ch_mult[i_level]
372
+ block_out = ch * ch_mult[i_level]
373
+ for i_block in range(self.num_res_blocks):
374
+ block.append(
375
+ ResnetBlock(
376
+ in_channels=block_in,
377
+ out_channels=block_out,
378
+ temb_channels=self.temb_ch,
379
+ dropout=dropout,
380
+ )
381
+ )
382
+ block_in = block_out
383
+ if curr_res in attn_resolutions:
384
+ attn.append(make_attn(block_in, attn_type=attn_type))
385
+ down = nn.Module()
386
+ down.block = block
387
+ down.attn = attn
388
+ if i_level != self.num_resolutions - 1:
389
+ down.downsample = Downsample(block_in, resamp_with_conv)
390
+ curr_res = curr_res // 2
391
+ self.down.append(down)
392
+
393
+ # middle
394
+ self.mid = nn.Module()
395
+ self.mid.block_1 = ResnetBlock(
396
+ in_channels=block_in,
397
+ out_channels=block_in,
398
+ temb_channels=self.temb_ch,
399
+ dropout=dropout,
400
+ )
401
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
402
+ self.mid.block_2 = ResnetBlock(
403
+ in_channels=block_in,
404
+ out_channels=block_in,
405
+ temb_channels=self.temb_ch,
406
+ dropout=dropout,
407
+ )
408
+
409
+ # end
410
+ self.norm_out = Normalize(block_in)
411
+ self.conv_out = torch.nn.Conv2d(
412
+ block_in,
413
+ 2 * z_channels if double_z else z_channels,
414
+ kernel_size=3,
415
+ stride=1,
416
+ padding=1,
417
+ )
418
+
419
+ def forward(self, x):
420
+ # timestep embedding
421
+ temb = None
422
+
423
+ # downsampling
424
+ hs = [self.conv_in(x)]
425
+ for i_level in range(self.num_resolutions):
426
+ for i_block in range(self.num_res_blocks):
427
+ h = self.down[i_level].block[i_block](hs[-1], temb)
428
+ if len(self.down[i_level].attn) > 0:
429
+ h = self.down[i_level].attn[i_block](h)
430
+ hs.append(h)
431
+ if i_level != self.num_resolutions - 1:
432
+ hs.append(self.down[i_level].downsample(hs[-1]))
433
+
434
+ # middle
435
+ h = hs[-1]
436
+ h = self.mid.block_1(h, temb)
437
+ h = self.mid.attn_1(h)
438
+ h = self.mid.block_2(h, temb)
439
+
440
+ # end
441
+ h = self.norm_out(h)
442
+ h = nonlinearity(h)
443
+ h = self.conv_out(h)
444
+ return h
445
+
446
+
447
+ class Decoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ *,
451
+ ch,
452
+ out_ch,
453
+ ch_mult=(1, 2, 4, 8),
454
+ num_res_blocks,
455
+ attn_resolutions,
456
+ dropout=0.0,
457
+ resamp_with_conv=True,
458
+ in_channels,
459
+ resolution,
460
+ z_channels,
461
+ give_pre_end=False,
462
+ tanh_out=False,
463
+ use_linear_attn=False,
464
+ attn_type="vanilla",
465
+ **ignorekwargs,
466
+ ):
467
+ super().__init__()
468
+ if use_linear_attn:
469
+ attn_type = "linear"
470
+ self.ch = ch
471
+ self.temb_ch = 0
472
+ self.num_resolutions = len(ch_mult)
473
+ self.num_res_blocks = num_res_blocks
474
+ self.resolution = resolution
475
+ self.in_channels = in_channels
476
+ self.give_pre_end = give_pre_end
477
+ self.tanh_out = tanh_out
478
+
479
+ # compute in_ch_mult, block_in and curr_res at lowest res
480
+ block_in = ch * ch_mult[self.num_resolutions - 1]
481
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
482
+ self.z_shape = (1, z_channels, curr_res, curr_res)
483
+
484
+ # z to block_in
485
+ self.conv_in = torch.nn.Conv2d(
486
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
487
+ )
488
+
489
+ # middle
490
+ self.mid = nn.Module()
491
+ self.mid.block_1 = ResnetBlock(
492
+ in_channels=block_in,
493
+ out_channels=block_in,
494
+ temb_channels=self.temb_ch,
495
+ dropout=dropout,
496
+ )
497
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
498
+ self.mid.block_2 = ResnetBlock(
499
+ in_channels=block_in,
500
+ out_channels=block_in,
501
+ temb_channels=self.temb_ch,
502
+ dropout=dropout,
503
+ )
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch * ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks + 1):
512
+ block.append(
513
+ ResnetBlock(
514
+ in_channels=block_in,
515
+ out_channels=block_out,
516
+ temb_channels=self.temb_ch,
517
+ dropout=dropout,
518
+ )
519
+ )
520
+ block_in = block_out
521
+ if curr_res in attn_resolutions:
522
+ attn.append(make_attn(block_in, attn_type=attn_type))
523
+ up = nn.Module()
524
+ up.block = block
525
+ up.attn = attn
526
+ if i_level != 0:
527
+ up.upsample = Upsample(block_in, resamp_with_conv)
528
+ curr_res = curr_res * 2
529
+ self.up.insert(0, up) # prepend to get consistent order
530
+
531
+ # end
532
+ self.norm_out = Normalize(block_in)
533
+ self.conv_out = torch.nn.Conv2d(
534
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
535
+ )
536
+
537
+ def forward(self, z):
538
+ # assert z.shape[1:] == self.z_shape[1:]
539
+ self.last_z_shape = z.shape
540
+
541
+ # timestep embedding
542
+ temb = None
543
+
544
+ # z to block_in
545
+ h = self.conv_in(z)
546
+
547
+ # middle
548
+ h = self.mid.block_1(h, temb)
549
+ h = self.mid.attn_1(h)
550
+ h = self.mid.block_2(h, temb)
551
+
552
+ # upsampling
553
+ for i_level in reversed(range(self.num_resolutions)):
554
+ for i_block in range(self.num_res_blocks + 1):
555
+ h = self.up[i_level].block[i_block](h, temb)
556
+ if len(self.up[i_level].attn) > 0:
557
+ h = self.up[i_level].attn[i_block](h)
558
+ if i_level != 0:
559
+ h = self.up[i_level].upsample(h)
560
+
561
+ # end
562
+ if self.give_pre_end:
563
+ return h
564
+
565
+ h = self.norm_out(h)
566
+ h = nonlinearity(h)
567
+ h = self.conv_out(h)
568
+ if self.tanh_out:
569
+ h = torch.tanh(h)
570
+ return h
571
+
572
+
573
+ class VQModel(nn.Module):
574
+ def __init__(
575
+ self,
576
+ ddconfig,
577
+ n_embed,
578
+ embed_dim,
579
+ ckpt_path=None,
580
+ ignore_keys=[],
581
+ image_key="image",
582
+ colorize_nlabels=None,
583
+ monitor=None,
584
+ scheduler_config=None,
585
+ lr_g_factor=1.0,
586
+ remap=None,
587
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
588
+ ):
589
+ super().__init__()
590
+ self.image_key = image_key
591
+ self.encoder = Encoder(**ddconfig)
592
+ self.decoder = Decoder(**ddconfig)
593
+ self.quantize = VectorQuantizer(
594
+ n_embed,
595
+ embed_dim,
596
+ beta=0.25,
597
+ remap=remap,
598
+ sane_index_shape=sane_index_shape,
599
+ )
600
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
601
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
602
+ if ckpt_path is not None:
603
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
604
+ self.image_key = image_key
605
+ if colorize_nlabels is not None:
606
+ assert isinstance(colorize_nlabels, int)
607
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
608
+ if monitor is not None:
609
+ self.monitor = monitor
610
+ self.scheduler_config = scheduler_config
611
+ self.lr_g_factor = lr_g_factor
612
+
613
+ def init_from_ckpt(self, path, ignore_keys=list()):
614
+ sd = torch.load(path, map_location="cpu")["state_dict"]
615
+ keys = list(sd.keys())
616
+ for k in keys:
617
+ for ik in ignore_keys:
618
+ if k.startswith(ik):
619
+ print("Deleting key {} from state_dict.".format(k))
620
+ del sd[k]
621
+ self.load_state_dict(sd, strict=False)
622
+ print(f"VQModel loaded from {path}")
623
+
624
+ def encode(self, x):
625
+ h = self.encoder(x)
626
+ h = self.quant_conv(h)
627
+ quant, emb_loss, info = self.quantize(h)
628
+ return quant, emb_loss, info
629
+
630
+ def decode(self, quant):
631
+ quant = self.post_quant_conv(quant)
632
+ dec = self.decoder(quant)
633
+ return dec
634
+
635
+ def decode_code(self, code_b):
636
+ quant_b = self.quantize.embed_code(code_b)
637
+ dec = self.decode(quant_b)
638
+ return dec
639
+
640
+ def forward(self, input):
641
+ quant, diff, _ = self.encode(input)
642
+ dec = self.decode(quant)
643
+ return dec, diff
644
+
645
+ def get_input(self, batch, k):
646
+ x = batch[k]
647
+ if len(x.shape) == 3:
648
+ x = x[..., None]
649
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
650
+ return x.float()
651
+
652
+ def get_last_layer(self):
653
+ return self.decoder.conv_out.weight
654
+
655
+ def log_images(self, batch, **kwargs):
656
+ log = dict()
657
+ x = self.get_input(batch, self.image_key)
658
+ x = x.to(self.device)
659
+ xrec, _ = self(x)
660
+ if x.shape[1] > 3:
661
+ # colorize with random projection
662
+ assert xrec.shape[1] > 3
663
+ x = self.to_rgb(x)
664
+ xrec = self.to_rgb(xrec)
665
+ log["inputs"] = x
666
+ log["reconstructions"] = xrec
667
+ return log
668
+
669
+ def to_rgb(self, x):
670
+ assert self.image_key == "segmentation"
671
+ if not hasattr(self, "colorize"):
672
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
673
+ x = F.conv2d(x, weight=self.colorize)
674
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
675
+ return x
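Reviewer note (not part of the commit): the inference-only VQModel above is driven entirely by the checkpoint's YAML config; the sketch below mirrors how ImageTokenizer.__init__ wires it up. The paths are placeholders, and the config keys ("model" -> "params", "lossconfig") are the ones image_tokenizer.py already assumes.

# Sketch: instantiate the VQModel defined above from a taming-transformers-style config.
import yaml
from chameleon.vqgan import VQModel

with open("tokenizer/vqgan.yaml") as f:  # placeholder path
    config = yaml.safe_load(f)

params = config["model"]["params"]            # contains ddconfig, n_embed, embed_dim, ...
params.pop("lossconfig", None)                # training-only; dropped for inference
params["ckpt_path"] = "tokenizer/vqgan.ckpt"  # triggers init_from_ckpt in __init__

model = VQModel(**params).eval()
print(model.quantize.n_e, "codebook entries,", model.quantize.e_dim, "dims each")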