# Copyright (c) Meta Platforms, Inc. and affiliates
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import PIL
import torch
import yaml
from PIL import Image

from .vqgan import VQModel


class ImageTokenizer:
    """Convert between PIL images and discrete VQGAN image tokens.

    Wraps a ``VQModel`` built from a YAML config plus a checkpoint path.
    Images are encoded into token ids with :meth:`img_tokens_from_pil` and
    decoded back into images with :meth:`pil_from_img_toks`.
    """

    def __init__(
        self,
        cfg_path: str,
        ckpt_path: str,
        device: str | torch.device | None = None,
    ):
        """Build the VQGAN model and record the device/dtype it runs with.

        Args:
            cfg_path: Path to a YAML config whose ``model.params`` mapping is
                passed as keyword arguments to ``VQModel``.
            ckpt_path: Checkpoint path forwarded to ``VQModel`` as
                ``ckpt_path``.
            device: Device to move the model to. When ``None`` the model is
                left where ``VQModel`` placed it and that device is recorded.

        Raises:
            AssertionError: If, with ``device=None``, the model's parameters
                live on more than one device, or if they use more than one
                dtype.
        """
        with open(cfg_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        params = config["model"]["params"]
        # Drop the loss config (presumably training-only) so VQModel does not
        # receive it as a constructor argument.
        params.pop("lossconfig", None)
        params["ckpt_path"] = ckpt_path

        self._vq_model = VQModel(**params)
        self._vq_model.eval()

        if device is None:
            devices = {p.device for p in self._vq_model.parameters()}
            assert len(devices) == 1, (
                f"expected all VQModel parameters on one device, got {devices}"
            )
            device = devices.pop()
        else:
            self._vq_model.to(device)
        self._device = device

        # Inputs are cast to this dtype before encoding.
        dtypes = {p.dtype for p in self._vq_model.parameters()}
        assert len(dtypes) == 1, (
            f"expected all VQModel parameters to share one dtype, got {dtypes}"
        )
        self._dtype = dtypes.pop()

    def _whiten_transparency(self, img: PIL.Image.Image) -> PIL.Image.Image:
        """Return an RGB version of ``img``, compositing transparency onto white.

        RGB images are returned unchanged (the same object). Any other mode is
        converted through RGBA:
        - if every alpha value is 255, a plain RGB conversion is returned;
        - otherwise each pixel is alpha-blended over a white background.
        """
        if img.mode == "RGB":
            return img

        vals_rgba = np.array(img.convert("RGBA"))

        # No pixel is even partially transparent: a plain conversion suffices.
        if not (vals_rgba[:, :, 3] < 255).any():
            return img.convert("RGB")

        # Alpha scaled to [0, 1], with a trailing axis so it broadcasts over
        # the three color channels.
        alpha = vals_rgba[:, :, 3, np.newaxis] / 255.0
        # out = (1 - alpha) * white + alpha * color
        vals_rgb = (1 - alpha) * 255 + alpha * vals_rgba[:, :, :3]
        # A uint8 (H, W, 3) array is inferred as mode "RGB"; the explicit mode
        # argument is deprecated in recent Pillow releases.
        return PIL.Image.fromarray(vals_rgb.astype("uint8"))

    def _vqgan_input_from(
        self, img: PIL.Image.Image, target_image_size=512
    ) -> torch.Tensor:
        """Resize, center-crop and normalize an image for the VQGAN encoder.

        Steps:
        - scale the shorter side to ``target_image_size`` (LANCZOS);
        - center-crop to a square of that size;
        - map pixel values from [0, 255] to [-1, 1].

        Returns:
            A float32 tensor with a leading batch dimension of 1, in
            (batch, channels, height, width) order. NOTE(review): assumes
            ``img`` is 3-channel RGB, e.g. the output of
            ``_whiten_transparency``.
        """
        # Resize so the shorter side equals target_image_size, keeping the
        # aspect ratio.
        scale = target_image_size / min(img.size)
        new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
        img = img.resize(new_size, PIL.Image.LANCZOS)

        # Center crop to a target_image_size x target_image_size square.
        x0 = (img.width - target_image_size) // 2
        y0 = (img.height - target_image_size) // 2
        img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))

        np_img = np.array(img) / 255.0  # [0, 255] -> [0, 1]
        np_img = np_img * 2 - 1  # [0, 1] -> [-1, 1]
        # HWC -> CHW, then add a leading batch dimension.
        tensor_img = torch.from_numpy(np_img).permute(2, 0, 1).float()
        return tensor_img.unsqueeze(0)

    @torch.no_grad()
    def img_tokens_from_pil(self, image: PIL.Image.Image) -> list[int]:
        """Encode a PIL image into VQGAN codebook token ids.

        The image is whitened (transparency removed), resized/cropped and
        normalized, then moved to the model's device and dtype before
        encoding. It runs under ``torch.no_grad()`` because tokenization
        never needs an autograd graph.

        NOTE(review): the value returned is the third element of the info
        triple from ``VQModel.encode``. It is probably a tensor of indices
        rather than a Python ``list[int]``; confirm against the VQModel
        implementation.
        """
        image = self._whiten_transparency(image)
        vqgan_input = self._vqgan_input_from(image).to(
            device=self._device, dtype=self._dtype
        )
        _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
        return img_toks

    def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image.Image:
        """Convert a (C, H, W) tensor with values in [-1, 1] to a PIL RGB image.

        Values outside [-1, 1] are clamped. Before conversion the tensor is
        detached, moved to CPU and upcast to float32. The upcast means
        low-precision model outputs such as bfloat16, which
        ``Tensor.numpy()`` rejects, are handled.
        """
        # float32 upcast: numpy has no bfloat16, so .numpy() would raise.
        cpu_tensor = chw_tensor.detach().cpu().float()

        # [-1, 1] -> [0, 1]
        normalized = (torch.clamp(cpu_tensor, -1.0, 1.0) + 1.0) / 2.0

        # CHW -> HWC numpy array, then scale to 8-bit (truncating).
        hwc_array = normalized.permute(1, 2, 0).numpy()
        image_array_uint8 = (hwc_array * 255).astype(np.uint8)

        pil_image = Image.fromarray(image_array_uint8)
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        return pil_image

    @torch.no_grad()
    def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image.Image:
        """Decode VQGAN token ids back into a PIL RGB image.

        Token ids are looked up in the quantizer codebook using the fixed
        shape argument ``(1, 32, 32, emb_dim)``. The result is decoded and
        the first item of the decoded batch is returned as an image. Runs
        under ``torch.no_grad()`` since the output is a PIL image.

        NOTE(review): the fixed shape presumably means ``img_tensor`` must
        hold exactly 32 * 32 = 1024 token ids for a single image; confirm
        with callers.
        """
        emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
        codebook_entry = self._vq_model.quantize.get_codebook_entry(
            img_tensor, (1, 32, 32, emb_dim)
        )
        pixels = self._vq_model.decode(codebook_entry)
        return self._pil_from_chw_tensor(pixels[0])