darknoon committed
Commit f9661fe
Parent: 77819e0

Allow non-512x512 images with the Chameleon tokenizer

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +5 -3
  3. chameleon/image_tokenizer.py +3 -2
.gitignore ADDED
@@ -0,0 +1 @@
+__pycache__
app.py CHANGED
@@ -130,11 +130,13 @@ class ChameleonVQImageRoundtripPipeline(ImageRoundtripPipeline):
     def roundtrip_image(self, image, output_type="pil"):
         # image = self.tokenizer._vqgan_input_from(image).to(device)
         image = self.preprocess(image).to(device)
+        _, _, im_height, im_width = image.shape
         _, _, [_, _, latents] = self.tokenizer._vq_model.encode(image)
-        # emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
-        output = self.tokenizer.pil_from_img_toks(latents)
+        scale = self.vae_scale_factor
+        shape = (1, im_height // scale, im_width // scale)
+        output = self.tokenizer.pil_from_img_toks(latents, shape=shape)
         # we actually do want this to be a grid, sorry!
-        latents = latents.reshape(1, 32, 32)
+        latents = latents.reshape(*shape)
 
         return (
             output,
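For context on the arithmetic above: the old hard-coded (1, 32, 32) grid is just 512x512 input divided by the VQGAN's spatial downsampling factor of 16, so the fix derives the grid from the actual input size instead. A minimal sketch of that computation (the latent_shape helper is illustrative, not part of this repo; it assumes vae_scale_factor == 16):

def latent_shape(im_height: int, im_width: int, scale: int = 16) -> tuple:
    # One token per scale-by-scale pixel patch, plus a batch dimension of 1.
    return (1, im_height // scale, im_width // scale)

assert latent_shape(512, 512) == (1, 32, 32)  # the previously hard-coded case
assert latent_shape(768, 512) == (1, 48, 32)  # non-square inputs now round-trip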
chameleon/image_tokenizer.py CHANGED
@@ -115,10 +115,11 @@ class ImageTokenizer:
 
         return pil_image
 
-    def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image:
+    # darknoon: added shape parameter
+    def pil_from_img_toks(self, img_tensor: torch.Tensor, shape=(1, 32, 32)) -> PIL.Image:
         emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
         codebook_entry = self._vq_model.quantize.get_codebook_entry(
-            img_tensor, (1, 32, 32, emb_dim)
+            img_tensor, (*shape, emb_dim)
         )
         pixels = self._vq_model.decode(codebook_entry)
         return self._pil_from_chw_tensor(pixels[0])
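With the new shape parameter the decoder can reconstruct non-square token grids, while the default (1, 32, 32) preserves the old 512x512 behavior for existing callers. A hypothetical call (the tokenizer and latents names stand in for an ImageTokenizer instance and its encoded token tensor):

# Hypothetical usage: decode the tokens of a 768x512 image (downsampling factor 16).
pil_image = tokenizer.pil_from_img_toks(latents, shape=(1, 768 // 16, 512 // 16))
pil_image.save("roundtrip.png")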