# image-tokens / app.py
# Author: darknoon
# Commit message: Allow non-512x512 with chameleon tokenizer
# Commit: f9661fe
import colorsys
import functools
from typing import Literal, get_args

import gradio as gr
import torch
import numpy as np
import yaml
from huggingface_hub import hf_hub_download
from diffusers import VQModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.wuerstchen.modeling_paella_vq_model import PaellaVQModel
from chameleon.image_tokenizer import ImageTokenizer
import torch.backends
import torch.mps
from PIL import Image
import spaces
# Identifiers of the supported tokenizers. The dropdown choices are derived
# from the Literal so the type and the UI list can never drift apart.
Model = Literal["vqgan", "paella", "chameleon"]
models = list(get_args(Model))

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
class ImageRoundtripPipeline:
    """Interface for encode -> quantize -> decode round-trip pipelines.

    Subclasses set ``vae_scale_factor`` (pixels per latent token along each
    spatial axis) and implement :meth:`roundtrip_image`.
    """

    vae_scale_factor: int

    def roundtrip_image(self, image, output_type="pil"):
        """Round-trip *image* through the tokenizer.

        Subclasses return a ``(reconstruction, tokens, codebook_size)`` tuple,
        where ``tokens`` is a NumPy array of codebook indices.

        Raises:
            NotImplementedError: always; the base class has no model. (An
                empty body would return ``None`` and fail later, far from
                the real cause, when the caller unpacks the result.)
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement roundtrip_image()"
        )
class VQImageRoundtripPipeline(ImageRoundtripPipeline):
    """Round-trip images through the aMUSEd 512 VQ model (``amused/amused-512``)."""

    vqvae: VQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqvae = VQModel.from_pretrained("amused/amused-512", subfolder="vqvae")
        self.vqvae.eval()
        self.vqvae.to(device)
        # Assumes one 2x down-sampling between consecutive block_out_channels
        # entries (the usual diffusers VQModel layout) -- TODO confirm.
        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("VQ-GAN model loaded", self.vqvae)

    def roundtrip_image(self, image, output_type="pil"):
        """Encode, quantize and decode *image*.

        Args:
            image: input accepted by ``VaeImageProcessor.preprocess``.
            output_type: forwarded to ``VaeImageProcessor.postprocess``.

        Returns:
            ``(reconstruction, tokens, codebook_size)``. The reconstruction is
            the first post-processed output. The tokens are a
            ``(batch, latent_h, latent_w)`` NumPy array of codebook indices.
        """
        pixels = self.vqvae_processor.preprocess(image)
        model_device = self.vqvae.device
        batch_size, _, height, width = pixels.shape

        # Temporarily run fp16 weights in fp32 when the config asks for it
        # (force_upcast). The original precision is restored in ``finally``
        # so a failure mid-way does not leave the shared model upcast.
        needs_upcasting = (
            self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast
        )
        if needs_upcasting:
            self.vqvae.float()
        try:
            latents = self.vqvae.encode(
                pixels.to(dtype=self.vqvae.dtype, device=model_device)
            ).latents
            _, _, latents_height, latents_width = latents.shape
            # In diffusers' VectorQuantizer output, [2][2] is the tensor of
            # codebook indices; lay it out as a spatial grid.
            tokens = self.vqvae.quantize(latents)[2][2].reshape(
                batch_size, latents_height, latents_width
            )
            # Decode from the indices themselves: force_not_quantize skips a
            # second quantization, and ``shape`` (batch, h, w, channels) is
            # presumably used to lay out the looked-up codebook embeddings.
            decoded = self.vqvae.decode(
                tokens,
                force_not_quantize=True,
                shape=(
                    batch_size,
                    height // self.vae_scale_factor,
                    width // self.vae_scale_factor,
                    self.vqvae.config.latent_channels,
                ),
            ).sample.clip(0, 1)
            output = self.vqvae_processor.postprocess(decoded, output_type)
        finally:
            if needs_upcasting:
                self.vqvae.half()
        return output[0], tokens.cpu().numpy(), self.vqvae.config.num_vq_embeddings
class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
    """Round-trip images through Chameleon's VQGAN image tokenizer."""

    tokenizer: ImageTokenizer
    n_embed: int
    vae_scale_factor: int

    def __init__(self):
        vqgan_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.ckpt"
        )
        vqgan_config_path = hf_hub_download(
            "darknoon/chameleon-tokenizer", "tokenizer/vqgan.yaml"
        )
        self.tokenizer = ImageTokenizer(
            cfg_path=vqgan_config_path, ckpt_path=vqgan_path, device=device
        )
        # Read the codebook size (n_embed) straight from the model config.
        with open(vqgan_config_path, encoding="utf-8") as f:
            vq_config = yaml.safe_load(f)
        self.n_embed = vq_config["model"]["params"]["n_embed"]
        # One token per 16x16 pixel patch (used to derive the token grid shape).
        self.vae_scale_factor = 16
        print("Chameleon VQGan model loaded", self.tokenizer._vq_model, self.n_embed)

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL image into a ``(1, 3, H, W)`` float tensor in [-1, 1].

        The normalization is copied from ``ImageTokenizer._vqgan_input_from``.
        Non-RGB images (RGBA, grayscale, palette) are converted to RGB first.
        Otherwise the channel axis would not have 3 entries, and the
        ``permute`` would fail or feed the encoder the wrong channels.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        pixels = np.array(image) / 255.0  # uint8 -> [0, 1]
        pixels = pixels * 2 - 1  # [0, 1] -> [-1, 1]
        # (H, W, C) -> (C, H, W), then add a leading batch dimension.
        return torch.from_numpy(pixels).permute(2, 0, 1).float().unsqueeze(0)

    def roundtrip_image(self, image, output_type="pil"):
        """Tokenize *image* with the Chameleon VQGAN and decode it back.

        ``output_type`` is accepted for interface compatibility but is not
        used. The reconstruction is whatever ``pil_from_img_toks`` returns.

        Returns:
            ``(reconstruction, tokens, codebook_size)`` where ``tokens`` is a
            ``(1, H // 16, W // 16)`` NumPy array of codebook indices.
        """
        pixels = self.preprocess(image).to(device)
        _, _, im_height, im_width = pixels.shape
        # encode() returns a 3-tuple; the third entry of its last element holds
        # the token indices.
        _, _, [_, _, tokens] = self.tokenizer._vq_model.encode(pixels)
        scale = self.vae_scale_factor
        grid_shape = (1, im_height // scale, im_width // scale)
        output = self.tokenizer.pil_from_img_toks(tokens, shape=grid_shape)
        # Lay the tokens out as a spatial grid for the token visualization.
        tokens = tokens.reshape(*grid_shape)
        return output, tokens.cpu().numpy(), self.n_embed
class PaellaImageRoundtripPipeline(ImageRoundtripPipeline):
    """Round-trip images through the Wuerstchen Paella VQ-GAN (``warp-ai/wuerstchen``)."""

    vqgan: PaellaVQModel
    vae_scale_factor: int
    vqvae_processor: VaeImageProcessor

    def __init__(self):
        self.vqgan = PaellaVQModel.from_pretrained(
            "warp-ai/wuerstchen", subfolder="vqgan"
        )
        self.vqgan.eval()
        self.vqgan.to(device)
        # One latent token per 4x4 pixel patch.
        self.vae_scale_factor = 4
        self.vqvae_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False
        )
        print("Paella VQ-GAN model loaded", self.vqgan)

    def roundtrip_image(self, image, output_type="pil"):
        """Encode, quantize and decode *image*.

        The image is decoded from the *quantized* latents. The reconstruction
        and its PSNR therefore reflect the same tokens that are returned.

        Args:
            image: input accepted by ``VaeImageProcessor.preprocess``.
            output_type: forwarded to ``VaeImageProcessor.postprocess``.

        Returns:
            ``(reconstruction, tokens, codebook_size)`` where ``tokens`` is a
            ``(batch, latent_h, latent_w)`` NumPy array of codebook indices.
        """
        pixels = self.vqvae_processor.preprocess(image)
        model_device = self.vqgan.device
        batch_size = pixels.shape[0]
        latents = self.vqgan.encode(
            pixels.to(dtype=self.vqgan.dtype, device=model_device)
        ).latents
        _, _, latents_height, latents_width = latents.shape

        # Quantize explicitly so the token indices can be inspected.
        # Element [0] is the quantized latent; [2][2] holds the indices.
        vq_output = self.vqgan.vquantizer(latents)
        quantized = vq_output[0]
        tokens = vq_output[2][2].reshape(batch_size, latents_height, latents_width)

        # PaellaVQModel.decode skips quantization by default, so decoding the
        # raw encoder output would hide the codebook loss. Decode the
        # quantized latents instead (force_not_quantize avoids re-quantizing).
        images = self.vqgan.decode(quantized, force_not_quantize=True).sample.clamp(0, 1)
        output = self.vqvae_processor.postprocess(images, output_type)
        return output[0], tokens.cpu().numpy(), self.vqgan.config.num_vq_embeddings
# Instantiate every tokenizer once at import time. Each constructor downloads
# or loads its weights and places the model on the module-level ``device``.
# roundtrip_image() then picks one of these shared instances per request.
pipeline_paella = PaellaImageRoundtripPipeline()
pipeline_vq = VQImageRoundtripPipeline()
pipeline_vq_chameleon = ChameleonVQImageRoundtripPipeline()
def generate_unique_colors_hsl(n):
    """Return ``n`` RGB colors, one per codebook index.

    Hues sweep the color wheel four times while lightness falls linearly from
    0.8 (index 0) towards 0.2. Nearby indices therefore get different hues,
    and far-apart indices differ in brightness. Channels are 0-255 ints. The
    colors are distinct in HLS space but may collide after int truncation.

    Args:
        n: number of colors (the codebook size); ``0`` yields ``[]``.

    Returns:
        A list of ``n`` ``(r, g, b)`` tuples.
    """
    colors = []
    for i in range(n):
        # Identical to i / (n // 4) when n is a multiple of 4. It also works
        # for n < 4 (no division by zero) and spaces hues evenly otherwise.
        # colorsys wraps hues outside [0, 1) itself.
        hue = i * 4 / n
        lightness = 0.8 - (i / n) * 0.6
        rgb = colorsys.hls_to_rgb(hue, lightness, 1.0)
        colors.append(tuple(int(255 * channel) for channel in rgb))
    return colors
@functools.lru_cache(maxsize=8)
def _token_color_table(codebook_size):
    """Return a read-only ``(codebook_size, 3)`` uint8 palette for token ids.

    Cached because building it calls ``colorsys`` once per codebook entry, and
    each pipeline always reports the same codebook size.
    """
    table = np.array(generate_unique_colors_hsl(codebook_size), dtype=np.uint8)
    # The cached array is shared across calls; forbid accidental mutation.
    table.flags.writeable = False
    return table


def vqgan_tokens_to_image(tokens, codebook_size, downscale_factor):
    """Render a token grid as an RGB image, one color per codebook index.

    Args:
        tokens: integer array whose first entry ``tokens[0]`` is an
            ``(h, w)`` grid of codebook indices (the batch entry 0 is used).
        codebook_size: number of codebook entries; indices must be below it.
        downscale_factor: pixels per token along each axis; the grid is
            enlarged by this factor with nearest-neighbour resampling.

    Returns:
        A PIL RGB image of size ``(w * downscale_factor, h * downscale_factor)``.
    """
    lookup_table = _token_color_table(codebook_size)
    token_grid = tokens[0]
    # Fancy indexing maps every index to its RGB triple: (h, w) -> (h, w, 3).
    # The result is a fresh array, so the read-only table is never touched.
    color_image = lookup_table[token_grid]
    height, width = color_image.shape[:2]
    # A (h, w, 3) uint8 array is inferred as RGB; passing ``mode`` is deprecated.
    img = Image.fromarray(color_image)
    return img.resize(
        (width * downscale_factor, height * downscale_factor),
        Image.NEAREST,
    )
def describe_shape(shape):
    """Format *shape* alongside the total number of elements it describes."""
    element_count = np.prod(shape)
    return "Shape: {} num elements: {}".format(shape, element_count)
def calc_psnr(img1: "Image.Image", img2: "Image.Image"):
    """Return the peak signal-to-noise ratio (dB) between two 8-bit images.

    PSNR = 20 * log10(255 / sqrt(MSE)), with the MSE taken over every pixel
    and channel.

    Args:
        img1, img2: images (or array-likes) with identical size and channels.

    Returns:
        The PSNR in decibels, or ``inf`` when the images are identical.

    Raises:
        ValueError: if the sizes or the channel layouts differ.
    """
    if img1.size != img2.size:
        raise ValueError("Images must have the same dimensions")
    # Promote to float before subtracting: uint8 arithmetic wraps around
    # (e.g. 0 - 1 == 255), which would grossly inflate the error.
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    if a.shape != b.shape:
        # e.g. RGB vs RGBA: same size but a different channel count.
        raise ValueError("Images must have the same shape (size and channels)")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")
    return 20 * np.log10(255.0 / np.sqrt(mse))
# Square working resolutions offered in the UI, keyed by their dropdown label.
_SIZES = {
    "256x256": (256, 256),
    "512x512": (512, 512),
    "1024x1024": (1024, 1024),
}


@spaces.GPU(duration=32)
@torch.no_grad()
def roundtrip_image(
    image,
    model: Model,
    size: Literal["256x256", "512x512", "1024x1024"],
    output_type="pil",
):
    """Resize *image*, round-trip it through *model*, and summarize the result.

    Returns:
        One value per Gradio output: the reconstruction, the colorized token
        grid, a description of the token array shape, and the PSNR between
        the resized input and the reconstruction (formatted to 2 decimals).

    Raises:
        gr.Error: if no image was provided.
        ValueError: for an unknown *size* or *model*.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    try:
        target_size = _SIZES[size]
    except KeyError:
        raise ValueError(f"Unknown size {size}") from None
    # Normalize to 3-channel RGB so every pipeline and the PSNR comparison see
    # the same layout, even for RGBA, grayscale or palette uploads.
    source = image.convert("RGB").resize(target_size)

    pipelines = {
        "vqgan": pipeline_vq,
        "paella": pipeline_paella,
        "chameleon": pipeline_vq_chameleon,
    }
    pipeline = pipelines.get(model)
    if pipeline is None:
        raise ValueError(f"Unknown model {model}")

    reconstruction, tokens, codebook_size = pipeline.roundtrip_image(
        source, output_type
    )
    return (
        reconstruction,
        vqgan_tokens_to_image(
            tokens, codebook_size, downscale_factor=pipeline.vae_scale_factor
        ),
        describe_shape(tokens.shape),
        f"{calc_psnr(source, reconstruction):.2f}",
    )
# Outputs are listed in the same order as the tuple roundtrip_image returns.
# PNG output avoids adding lossy-compression artifacts on top of the
# tokenizer's own reconstruction error.
demo = gr.Interface(
    fn=roundtrip_image,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(models, label="Model", value="vqgan"),
        gr.Dropdown(["256x256", "512x512", "1024x1024"], label="Size", value="512x512"),
    ],
    outputs=[
        gr.Image(label="Reconstructed", format="png"),
        gr.Image(label="Tokens", format="png"),
        gr.Text(label="VQ Shape"),
        gr.Text(label="PSNR"),
    ],
    title="Image Tokenizer Playground",
    description=(
        "Round-trip an image through an encoder-decoder pair to see the quality "
        "loss from the VQ-GAN for image generation, etc."
    ),
)
demo.launch()