Spaces:
Sleeping
Sleeping
from typing import List, Literal | |
import gradio as gr | |
import torch | |
import numpy as np | |
import colorsys | |
from diffusers import VQModel | |
from diffusers.image_processor import VaeImageProcessor | |
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel | |
from abc import abstractmethod | |
import torch.backends | |
import torch.mps | |
from PIL import Image | |
import spaces | |
# Choose the compute device once at import time, preferring CUDA, then Apple's
# Metal backend (MPS), and finally falling back to the CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
class ImageRoundtripPipeline:
    """Interface for pipelines that encode an image and decode it back.

    Concrete subclasses override :meth:`roundtrip_image` and return a
    ``(reconstruction, token_grid, codebook_size)`` tuple.
    """

    @abstractmethod
    def roundtrip_image(self, image, output_type="pil"):
        """Encode ``image`` and decode it again; must be overridden.

        Raises:
            NotImplementedError: Always, when called on a class that did not
                override this method (instead of silently returning None).
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement roundtrip_image()"
        )
class VQImageRoundtripPipeline(ImageRoundtripPipeline):
    """Round-trip images through the aMUSEd VQ-GAN (``amused/amused-512``).

    The image is encoded, snapped to codebook indices and decoded back from
    those indices, so the reconstruction includes the quantization loss.
    """

    vqvae: VQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqvae = VQModel.from_pretrained("amused/amused-512", subfolder="vqvae")
        self.vqvae.eval()
        self.vqvae.to(device)
        # Spatial downsampling factor, assuming each block after the first
        # halves the resolution (same formula diffusers pipelines use).
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("VQ-GAN model loaded", self.vqvae)

    # Pure inference: without no_grad every call builds an autograd graph and
    # the decoded tensor would require grad when postprocess converts it.
    @torch.no_grad()
    def roundtrip_image(self, image, output_type="pil"):
        """Encode ``image`` to VQ tokens and decode it back.

        Args:
            image: Input accepted by ``VaeImageProcessor.preprocess``
                (e.g. a PIL image).
            output_type: Output format forwarded to ``postprocess``.

        Returns:
            ``(reconstruction, tokens, codebook_size)``.
            - ``reconstruction`` is the first postprocessed image.
            - ``tokens`` is a numpy array of codebook indices shaped
              ``(batch, latent_h, latent_w)``.
            - ``codebook_size`` is ``config.num_vq_embeddings``.
        """
        image = self.vqvae_processor.preprocess(image)
        model_device = self.vqvae.device
        needs_upcasting = (
            self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
        )
        batch_size, _, height, width = image.shape
        if needs_upcasting:
            self.vqvae.float()
        try:
            latents = self.vqvae.encode(
                image.to(dtype=self.vqvae.dtype, device=model_device)
            ).latents
            _, _, latents_height, latents_width = latents.shape
            # Element [2][2] of quantize()'s output holds the codebook indices;
            # reshape them into a (batch, h, w) token grid.
            tokens = self.vqvae.quantize(latents)[2][2].reshape(
                batch_size, latents_height, latents_width
            )
            # Decode straight from the indices; `shape` describes the
            # (batch, h, w, channels) layout for the codebook lookup.
            output = self.vqvae.decode(
                tokens,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
        finally:
            # Restore half precision even if encoding/decoding raised.
            if needs_upcasting:
                self.vqvae.half()
        output = self.vqvae_processor.postprocess(output, output_type)
        return output[0], tokens.cpu().numpy(), self.vqvae.config.num_vq_embeddings
class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
    """Round-trip images through the Würstchen Paella VQ-GAN (``warp-ai/wuerstchen``)."""

    vqgan: PaellaVQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqgan = PaellaVQModel.from_pretrained(
            "warp-ai/wuerstchen", subfolder="vqgan"
        )
        self.vqgan.eval()
        self.vqgan.to(device)
        # Spatial downsampling factor, hard-coded for this checkpoint.
        self.vae_scale_factor = 4
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("Paella VQ-GAN model loaded", self.vqgan)

    # Pure inference: avoid building an autograd graph on every request.
    @torch.no_grad()
    def roundtrip_image(self, image, output_type="pil"):
        """Encode ``image``, quantize it, and decode the quantized latents.

        Args:
            image: Input accepted by ``VaeImageProcessor.preprocess``
                (e.g. a PIL image).
            output_type: Output format forwarded to ``postprocess``.

        Returns:
            ``(reconstruction, tokens, codebook_size)``.
            - ``reconstruction`` is the first postprocessed image.
            - ``tokens`` is a numpy array of codebook indices shaped
              ``(batch, latent_h, latent_w)``.
            - ``codebook_size`` is ``config.num_vq_embeddings``.
        """
        image = self.vqvae_processor.preprocess(image)
        model_device = self.vqgan.device
        batch_size = image.shape[0]
        latents = self.vqgan.encode(
            image.to(dtype=self.vqgan.dtype, device=model_device)
        ).latents
        _, _, latents_height, latents_width = latents.shape
        # Quantize explicitly and decode the *quantized* latents. Decoding the
        # raw encoder output (as before) bypasses the codebook, so the
        # reconstruction would not show the VQ loss this demo is about.
        vq_out = self.vqgan.vquantizer(latents)
        quantized = vq_out[0]
        tokens = vq_out[2][2].reshape(batch_size, latents_height, latents_width)
        images = self.vqgan.decode(quantized).sample.clamp(0, 1)
        output = self.vqvae_processor.postprocess(images, output_type)
        return output[0], tokens.cpu().numpy(), self.vqgan.config.num_vq_embeddings
# Instantiate both tokenizers once at import time (Paella first, then the
# aMUSEd VQ-GAN); every request reuses these shared instances.
pipeline_paella, pipeline_vq = (
    PaellaImageRoundtripPipeline(),
    VQImageRoundtripPipeline(),
)
def generate_unique_colors_hsl(n, cycles=4):
    """Return ``n`` RGB colours for visualising codebook indices.

    The hue sweeps around the colour wheel ``cycles`` times, while lightness
    falls linearly from 0.8 towards 0.2 as the index grows. Colours are
    intended to differ but are not guaranteed distinct after truncation to
    8-bit channels.

    Args:
        n: Number of colours to produce (typically the codebook size).
        cycles: How many times the hue wraps around the wheel; must be >= 1.

    Returns:
        A list of ``n`` ``(r, g, b)`` tuples with integer channels in 0..255.

    Raises:
        ValueError: If ``cycles`` is less than 1.
    """
    if cycles < 1:
        raise ValueError(f"cycles must be >= 1, got {cycles}")
    # n // cycles is 0 when n < cycles; clamp to 1 so small palettes no longer
    # raise ZeroDivisionError (results for n >= cycles are unchanged).
    period = max(n // cycles, 1)
    colors = []
    for i in range(n):
        hue = i / period  # colorsys wraps hues >= 1 back onto the wheel
        lightness = 0.8 - (i / n) * 0.6
        saturation = 1.0
        rgb = colorsys.hls_to_rgb(hue, lightness, saturation)
        colors.append(tuple(int(255 * channel) for channel in rgb))
    return colors
def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
    """Render a grid of codebook indices as a colour-coded PIL image.

    Each index is painted with its colour from ``generate_unique_colors_hsl``.
    The grid is then enlarged with nearest-neighbour resampling.

    Args:
        tokens: Indexable whose first element is a grid of integer token ids.
            Only ``tokens[0]`` is drawn; presumably a ``(batch, h, w)`` array,
            as produced by the pipelines in this file.
        codebook_size: Number of codebook entries; sets the palette length.
        downscale_factor: Integer factor each side of the grid is scaled by.

    Returns:
        An RGB ``PIL.Image`` of size ``(w * downscale_factor, h * downscale_factor)``.
    """
    palette = np.array(generate_unique_colors_hsl(codebook_size), dtype=np.uint8)
    # Fancy-index the palette with the first grid: (h, w) ids -> (h, w, 3) RGB.
    rgb_grid = palette[tokens[0]]
    token_img = Image.fromarray(rgb_grid, "RGB")
    grid_width = rgb_grid.shape[1]
    grid_height = rgb_grid.shape[0]
    return token_img.resize(
        (grid_width * downscale_factor, grid_height * downscale_factor),
        Image.NEAREST,
    )
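# This Gradio Space round-trips an image through an encoder-decoder pair (currently the aMUSEd VQ-GAN and the Würstchen Paella VQ-GAN; other pairs such as SDXL's VAE are not wired up yet) so you can inspect the reconstruction quality and the resulting token grid.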
# def image_grid_to_string(image_grid): | |
# """Convert a latent vq index "image" grid to a string, input shape is (1, height, width)""" | |
# return "\n".join( | |
# [" ".join([str(int(x)) for x in row]) for row in image_grid.squeeze()] | |
# ) | |
def describe_shape(shape):
    """Summarise a shape tuple and the total number of elements it implies.

    Args:
        shape: A shape tuple such as ``latents.shape``.

    Returns:
        A string like ``"Shape: (1, 32, 32) num elements: 1024"``.
    """
    element_count = np.prod(shape)
    return f"Shape: {shape} num elements: {element_count}"
def roundtrip_image(
    image,
    model: Literal["vqgan", "paella"],
    size: Literal["256x256", "512x512", "1024x1024"],
    output_type="pil",
):
    """Resize ``image``, round-trip it through a VQ model and visualise its tokens.

    Args:
        image: Input PIL image (from ``gr.Image(type="pil")``); ``None`` when
            the user submitted without an image.
        model: ``"vqgan"`` (aMUSEd VQ-GAN) or ``"paella"`` (Würstchen Paella).
        size: Square resolution the image is resized to before encoding.
        output_type: Forwarded to the pipeline's postprocessor.

    Returns:
        ``(reconstruction, token_image, shape_description)``, matching the three
        Gradio outputs.

    Raises:
        ValueError: If ``size`` or ``model`` is unknown, or ``image`` is None.
    """
    sizes = {
        "256x256": (256, 256),
        "512x512": (512, 512),
        "1024x1024": (1024, 1024),
    }
    if size not in sizes:
        raise ValueError(f"Unknown size {size}")
    if image is None:
        raise ValueError("No input image provided")
    image = image.resize(sizes[size])

    if model == "vqgan":
        pipeline = pipeline_vq
    elif model == "paella":
        pipeline = pipeline_paella
    else:
        raise ValueError(f"Unknown model {model}")

    reconstruction, latents, codebook_size = pipeline.roundtrip_image(
        image, output_type
    )
    # Upscale the token grid by the *selected* pipeline's downsampling factor
    # so it lines up with the reconstruction (previously the Paella branch
    # wrongly used the VQ-GAN's factor instead of its own).
    token_image = vqgan_tokens_to_image(
        latents, codebook_size, downscale_factor=pipeline.vae_scale_factor
    )
    return reconstruction, token_image, describe_shape(latents.shape)
# Gradio UI: an input image plus model/size dropdowns feed roundtrip_image,
# whose three return values fill the reconstruction, token map and shape text.
demo = gr.Interface(
    roundtrip_image,
    [
        gr.Image(type="pil"),
        gr.Dropdown(["vqgan", "paella"], value="vqgan", label="Model"),
        gr.Dropdown(
            ["256x256", "512x512", "1024x1024"], value="512x512", label="Size"
        ),
    ],
    [
        gr.Image(label="Reconstructed"),
        gr.Image(label="Tokens"),
        gr.Text(label="VQ Shape"),
    ],
    title="Image Tokenizer Playground",
    description="Round-trip an image through an encode-decoder pair to see the quality loss from the VQ-GAN for image generation, etc.",
)
demo.launch()