"""Image Tokenizer Playground: round-trip images through discrete image tokenizers.

Encodes an image to VQ codebook indices with one of three tokenizers (the
amused VQ-GAN, the Würstchen/Paella VQ-GAN, or the Chameleon tokenizer),
decodes it back, and visualizes the token grid and the reconstruction quality.
"""

import colorsys
from typing import Literal

import gradio as gr
import numpy as np
import spaces
import torch
import yaml
from chameleon.image_tokenizer import ImageTokenizer
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
from huggingface_hub import hf_hub_download
from PIL import Image

Model = Literal["vqgan", "paella", "chameleon"]
models = ["vqgan", "paella", "chameleon"]

# Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


class ImageRoundtripPipeline:
    """Shared interface: returns (reconstructed image, token index grid as a
    numpy array of shape (batch, height, width), codebook size)."""

    def roundtrip_image(self, image, output_type="pil"): ...


class VQImageRoundtripPipeline(ImageRoundtripPipeline):
    vqvae: VQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqvae = VQModel.from_pretrained("amused/amused-512", subfolder="vqvae")
        self.vqvae.eval()
        self.vqvae.to(device)
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("VQ-GAN model loaded", self.vqvae)

    def roundtrip_image(self, image, output_type="pil"):
        image = self.vqvae_processor.preprocess(image)
        device = self.vqvae.device
        needs_upcasting = (
            self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
        )
        batch_size, im_channels, height, width = image.shape
        if needs_upcasting:
            self.vqvae.float()
        latents = self.vqvae.encode(
            image.to(dtype=self.vqvae.dtype, device=device)
        ).latents
        latents_batch_size, latent_channels, latents_height, latents_width = (
            latents.shape
        )
        # Quantize manually so we can inspect the codebook indices.
        latents = self.vqvae.quantize(latents)[2][2].reshape(
            batch_size, latents_height, latents_width
        )
        # Replace 20% of latents with random values:
        # random_latents = torch.randint(
        #     0, self.vqvae.config.num_vq_embeddings, latents.shape, device=device
        # )
        # random_mask = torch.rand(latents.shape, device=device) < 0.2
        # latents = torch.where(random_mask, random_latents, latents)
        output = self.vqvae.decode(
            latents,
            force_not_quantize=True,
            shape=(
                batch_size,
                height // self.vae_scale_factor,
                width // self.vae_scale_factor,
                self.vqvae.config.latent_channels,
            ),
        ).sample.clip(0, 1)
        output = self.vqvae_processor.postprocess(output, output_type)
        if needs_upcasting:
            self.vqvae.half()
        return output[0], latents.cpu().numpy(), self.vqvae.config.num_vq_embeddings
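

# A minimal sketch of the encode -> quantize -> decode round trip that each
# pipeline here implements (illustrative comments only; `pipe` stands in for
# any loaded VQModel and `x` for a preprocessed image tensor):
#
#   latents = pipe.encode(x).latents                # continuous latents
#   _, _, (_, _, indices) = pipe.quantize(latents)  # nearest codebook entries
#   recon = pipe.decode(indices, force_not_quantize=True, shape=...).sample
#
# The (quantized, loss, (perplexity, one_hot, indices)) return value matches
# diffusers' VectorQuantizer at the time of writing, hence the `[2][2]`
# indexing used above and in the Paella pipeline below.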


class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
    tokenizer: ImageTokenizer
    n_embed: int
    vae_scale_factor: int

    def __init__(self):
        vqgan_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.ckpt"
        )
        vqgan_config_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.yaml"
        )
        self.tokenizer = ImageTokenizer(
            cfg_path=vqgan_config_path, ckpt_path=vqgan_path, device=device
        )
        with open(vqgan_config_path) as f:
            vq_config = yaml.safe_load(f)
        self.n_embed = vq_config["model"]["params"]["n_embed"]
        self.vae_scale_factor = 16
        print("Chameleon VQ-GAN model loaded", self.tokenizer._vq_model, self.n_embed)

    def preprocess(self, image: Image.Image):
        # Copied from ImageTokenizer._vqgan_input_from
        np_img = np.array(image) / 255.0  # Normalize to [0, 1]
        np_img = np_img * 2 - 1  # Scale to [-1, 1]
        tensor_img = (
            torch.from_numpy(np_img).permute(2, 0, 1).float()
        )  # (Channels, Height, Width) format.
        # Add batch dimension.
        return tensor_img.unsqueeze(0)

    def roundtrip_image(self, image, output_type="pil"):
        # image = self.tokenizer._vqgan_input_from(image).to(device)
        image = self.preprocess(image).to(device)
        _, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
        # emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
        output = self.tokenizer.pil_from_img_toks(latents)
        # We actually do want this to be a grid, sorry!
        # (32x32 assumes a 512x512 input at vae_scale_factor 16.)
        latents = latents.reshape(1, 32, 32)
        return (
            output,
            latents.cpu().numpy(),
            self.n_embed,
        )


class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
    vqgan: PaellaVQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqgan = PaellaVQModel.from_pretrained(
            "warp-ai/wuerstchen", subfolder="vqgan"
        )
        self.vqgan.eval()
        self.vqgan.to(device)
        self.vae_scale_factor = 4
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("Paella VQ-GAN model loaded", self.vqgan)

    def roundtrip_image(self, image, output_type="pil"):
        image = self.vqvae_processor.preprocess(image)
        device = self.vqgan.device
        batch_size, im_channels, height, width = image.shape
        latents = self.vqgan.encode(
            image.to(dtype=self.vqgan.dtype, device=device)
        ).latents
        latents_batch_size, latent_channels, latents_height, latents_width = (
            latents.shape
        )
        # latents = latents * self.vqgan.config.scale_factor
        # Quantize manually so we can inspect the codebook indices.
        latents_q = self.vqgan.vquantizer(latents)[2][2].reshape(
            batch_size, latents_height, latents_width
        )
        print("latents after quantize", (latents_q.shape, latents_q.dtype))
        images = self.vqgan.decode(latents).sample.clamp(0, 1)
        output = self.vqvae_processor.postprocess(images, output_type)
        # if needs_upcasting:
        #     self.vqgan.half()
        return output[0], latents_q.cpu().numpy(), self.vqgan.config.num_vq_embeddings


pipeline_paella = PaellaImageRoundtripPipeline()
pipeline_vq = VQImageRoundtripPipeline()
pipeline_vq_chameleon = ChameleonVQImageRoundtripPipeline()


# Function to generate a list of n visually distinct colors
def generate_unique_colors_hsl(n):
    colors = []
    for i in range(n):
        # Distribute hues around the color wheel 4 times (colorsys wraps hue
        # modulo 1.0); max(..., 1) guards against codebooks smaller than 4.
        hue = i / max(n // 4, 1)
        lightness = 0.8 - (i / n) * 0.6  # Decrease lightness from 0.8 to 0.2
        saturation = 1.0
        rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
        rgb = tuple(int(255 * x) for x in rgb)
        colors.append(rgb)
    return colors


# Function to create an image visualizing a grid of VQ-GAN tokens
def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
    # Generate unique colors for each token in the codebook
    colors = generate_unique_colors_hsl(codebook_size)
    # Create a lookup table
    lookup_table = np.array(colors, dtype=np.uint8)
    # Extract the token array (remove the batch dimension)
    token_array = tokens[0]
    # Map tokens to their RGB colors using the lookup table
    color_image = lookup_table[token_array]
    # Create a PIL image from the numpy array
    img = Image.fromarray(color_image, "RGB")
    # Upscale with nearest-neighbor interpolation so each token stays a crisp square
    img = img.resize(
        (
            color_image.shape[1] * downscale_factor,
            color_image.shape[0] * downscale_factor,
        ),
        Image.NEAREST,
    )
    return img


def describe_shape(shape):
    return f"Shape: {shape} num elements: {np.prod(shape)}"


def calc_psnr(img1: Image.Image, img2: Image.Image):
    if img1.size != img2.size:
        raise ValueError("Images must have the same dimensions")
    # Cast to float before subtracting; uint8 arithmetic would wrap around.
    img1 = np.array(img1, dtype=np.float64)
    img2 = np.array(img2, dtype=np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float("inf")
    # PSNR = 20 * log10(MAX / sqrt(MSE)), with MAX = 255 for 8-bit images.
    return 20 * np.log10(255.0 / np.sqrt(mse))
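

# Example of driving a pipeline directly, without the Gradio UI (a sketch;
# "photo.png" is a placeholder path for any RGB image):
#
#   img = Image.open("photo.png").convert("RGB").resize((512, 512))
#   recon, tokens, codebook_size = pipeline_vq.roundtrip_image(img)
#   print(describe_shape(tokens.shape), f"PSNR: {calc_psnr(img, recon):.2f}")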


@spaces.GPU(duration=32)
@torch.no_grad()
def roundtrip_image(
    image,
    model: Model,
    size: Literal["256x256", "512x512", "1024x1024"],
    output_type="pil",
):
    if size == "256x256":
        image = image.resize((256, 256))
    elif size == "512x512":
        image = image.resize((512, 512))
    elif size == "1024x1024":
        image = image.resize((1024, 1024))
    else:
        raise ValueError(f"Unknown size {size}")
    image_orig = image
    if model == "vqgan":
        pipeline = pipeline_vq
    elif model == "paella":
        pipeline = pipeline_paella
    elif model == "chameleon":
        pipeline = pipeline_vq_chameleon
    else:
        raise ValueError(f"Unknown model {model}")
    image, latents, codebook_size = pipeline.roundtrip_image(image, output_type)
    return (
        image,
        vqgan_tokens_to_image(
            latents, codebook_size, downscale_factor=pipeline.vae_scale_factor
        ),
        describe_shape(latents.shape),
        f"{calc_psnr(image_orig, image):.2f}",
    )


demo = gr.Interface(
    fn=roundtrip_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(models, label="Model", value="vqgan"),
        gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
    ],
    outputs=[
        gr.Image(label="Reconstructed", format="png"),
        gr.Image(label="Tokens", format="png"),
        gr.Text(label="VQ Shape"),
        gr.Text(label="PSNR"),
    ],
    title="Image Tokenizer Playground",
    description="Round-trip an image through an encoder-decoder pair to see how much quality is lost by the VQ-GAN tokenizers used for image generation.",
)

demo.launch()
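
# When running outside a Hugging Face Space, standard Gradio options apply,
# e.g. demo.launch(server_name="0.0.0.0", share=True) to expose the app
# beyond localhost.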