xuefengli commited on
Commit
7362797
1 Parent(s): e0974a9
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +3 -3
  2. app.py +77 -0
  3. chameleon/__init__.py +4 -0
  4. chameleon/download_data.py +88 -0
  5. chameleon/inference/__init__.py +4 -0
  6. chameleon/inference/alignment.py +79 -0
  7. chameleon/inference/chameleon.py +689 -0
  8. chameleon/inference/cudagraph.py +85 -0
  9. chameleon/inference/generation.py +162 -0
  10. chameleon/inference/image_tokenizer.py +125 -0
  11. chameleon/inference/loader.py +71 -0
  12. chameleon/inference/logits_processor.py +336 -0
  13. chameleon/inference/model_adapter.py +118 -0
  14. chameleon/inference/stopping_criteria.py +55 -0
  15. chameleon/inference/token_selector.py +47 -0
  16. chameleon/inference/transformer.py +421 -0
  17. chameleon/inference/utils.py +34 -0
  18. chameleon/inference/vocab.py +123 -0
  19. chameleon/inference/vqgan.py +675 -0
  20. chameleon/miniviewer/__init__.py +4 -0
  21. chameleon/miniviewer/__main__.py +9 -0
  22. chameleon/miniviewer/miniviewer.html +409 -0
  23. chameleon/miniviewer/miniviewer.py +254 -0
  24. chameleon/viewer/backend/__init__.py +4 -0
  25. chameleon/viewer/backend/data_types.py +90 -0
  26. chameleon/viewer/backend/model_viewer.py +66 -0
  27. chameleon/viewer/backend/models/__init__.py +4 -0
  28. chameleon/viewer/backend/models/abstract_model.py +67 -0
  29. chameleon/viewer/backend/models/chameleon_distributed.py +827 -0
  30. chameleon/viewer/backend/models/chameleon_local.py +642 -0
  31. chameleon/viewer/backend/models/service.py +300 -0
  32. chameleon/viewer/backend/requirements.txt +35 -0
  33. chameleon/viewer/backend/utils.py +28 -0
  34. chameleon/viewer/frontend/README.md +11 -0
  35. chameleon/viewer/frontend/index.html +17 -0
  36. chameleon/viewer/frontend/package-lock.json +0 -0
  37. chameleon/viewer/frontend/package.json +62 -0
  38. chameleon/viewer/frontend/postcss.config.cjs +13 -0
  39. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_DisplayVF_W_Wght.woff2 +0 -0
  40. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff +0 -0
  41. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff2 +0 -0
  42. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff +0 -0
  43. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff2 +0 -0
  44. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff +0 -0
  45. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff2 +0 -0
  46. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_TextVF_W_Wght.woff2 +0 -0
  47. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff +0 -0
  48. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff2 +0 -0
  49. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff +0 -0
  50. chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff2 +0 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: Anole
3
- emoji:
4
- colorFrom: red
5
  colorTo: red
6
  sdk: gradio
7
- sdk_version: 4.38.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: Anole
3
+ emoji: 🏆
4
+ colorFrom: green
5
  colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.37.2
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import subprocess
3
+ import shutil
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from huggingface_hub import snapshot_download
7
+ import json
8
+ import os
9
+
10
+ # Specify the repository ID
11
+ repo_id = "GAIR/Anole-7b-v0.1"
12
+
13
+ if not os.path.exists("./Anole-7b-v0.1"):
14
+ os.system("git lfs install")
15
+ os.system("git clone https://huggingface.co/GAIR/Anole-7b-v0.1")
16
+
17
+ subprocess.run(["/bin/bash", "install.sh"], capture_output=True, text=True)
18
+ result = subprocess.run(["/bin/bash", "install.sh"], capture_output=True, text=True)
19
+
20
+ @spaces.GPU(duration=90)
21
+ def text_to_image(instruction):
22
+ result = subprocess.run(["python", "text2image.py", "-i", instruction, "-b", "1"], capture_output=True, text=True)
23
+ if result.returncode == 0:
24
+ return gr.update(value="Image Generated. Check the display below.", visible=True), "outputs/text2image/1.png"
25
+ else:
26
+ return "Error: " + result.stderr, None
27
+
28
+ @spaces.GPU(duration=150)
29
+ def text_to_interleaved(instruction):
30
+ result = subprocess.run(["python", "interleaved_generation.py", "-i", instruction], capture_output=True, text=True)
31
+ if result.returncode == 0:
32
+ outputs = [None for i in range(7)]
33
+ box_index = 0
34
+
35
+ # Read the segments.jsonl file
36
+ with open('./segments.jsonl', 'r') as file:
37
+ for line in file:
38
+ line_dict = json.loads(line.strip())
39
+ if line_dict['type'] == 'text':
40
+ if box_index % 2 != 0:
41
+ box_index += 1
42
+ outputs[box_index] = line_dict['content']
43
+ elif line_dict['type'] == 'image':
44
+ if box_index % 2 == 0:
45
+ box_index += 1
46
+ outputs[box_index] = Image.open(line_dict['content'])
47
+ box_index += 1
48
+
49
+ return outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5], outputs[6]
50
+ else:
51
+ return ("Error: " + result.stderr, ) * 7
52
+
53
+ # Use Blocks to organize the interfaces side by side
54
+ with gr.Blocks() as demo:
55
+ # Create a row to place columns side by side
56
+ with gr.Row():
57
+ # First column for Text-to-Image Interface
58
+ with gr.Column():
59
+ gr.Interface(
60
+ fn=text_to_image, # Function to generate cat images
61
+ inputs=gr.Textbox(label="Enter Instruction for Image Generation"), # Input textbox for user instructions
62
+ outputs=[gr.Text(label="Status"), gr.Image(label="Generated Image")], # Outputs: status message and generated image
63
+ title="Anole: Text-to-Image", # Title of the interface
64
+ description="Generate images based on text instructions. Check https://github.com/GAIR-NLP/anole for more information. Model can be downloaded at: https://huggingface.co/GAIR/Anole-7b-v0.1."
65
+ )
66
+ # Second column for Text-to-Interleaved Image-Text Interface
67
+ with gr.Column():
68
+ gr.Interface(
69
+ fn=text_to_interleaved,
70
+ inputs=gr.Textbox(label="Enter Instruction for Interleaved Content"),
71
+ outputs=[gr.Text(label="Text Output 1"), gr.Image(label="Image Output 1"), gr.Text(label="Text Output 2"), gr.Image(label="Image Output 2"), gr.Text(label="Text Output 3"), gr.Image(label="Image Output 3"), gr.Text(label="Text Output 4")],
72
+ title="Anole: Text-to-Interleaved", # Title of the interface
73
+ description="Generate interleaved text and images based on text instructions. Check https://github.com/GAIR-NLP/anole for more information. Model can be downloaded at: https://huggingface.co/GAIR/Anole-7b-v0.1."
74
+ )
75
+
76
+ # Launch the entire Blocks interface
77
+ demo.launch()
chameleon/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
chameleon/download_data.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # This software may be used and distributed according to the terms of the Chameleon License Agreement.
3
+
4
+ import hashlib
5
+ import subprocess
6
+ import sys
7
+ from pathlib import Path
8
+
9
+
10
+ def download_file(url: str, output_path: Path):
11
+ print(f"Downloading {output_path}")
12
+ subprocess.check_call(["wget", "--continue", url, "-O", str(output_path)])
13
+
14
+
15
+ def validate_checksum(folder: Path):
16
+ chks_parts = (folder / "checklist.chk").read_text().split()
17
+ for expected_checksum, file in zip(chks_parts[::2], chks_parts[1::2]):
18
+ file_path = folder / file
19
+ checksum = hashlib.md5(file_path.read_bytes()).hexdigest()
20
+ if checksum != expected_checksum:
21
+ print(f"Checksum mismatch for {file_path}")
22
+ sys.exit(1)
23
+
24
+
25
+ def download_tokenizer(presigned_url: str, target_folder: Path):
26
+ tokenizer_folder = target_folder / "tokenizer"
27
+ tokenizer_folder.mkdir(parents=True, exist_ok=True)
28
+
29
+ for filename in [
30
+ "text_tokenizer.json",
31
+ "vqgan.ckpt",
32
+ "vqgan.yaml",
33
+ "checklist.chk",
34
+ ]:
35
+ download_file(
36
+ presigned_url.replace("*", f"tokenizer/{filename}"),
37
+ tokenizer_folder / filename,
38
+ )
39
+
40
+ validate_checksum(tokenizer_folder)
41
+
42
+
43
+ def download_model(presigned_url: str, target_folder: Path, model: str):
44
+ model_folder = target_folder / "models" / model
45
+ model_folder.mkdir(parents=True, exist_ok=True)
46
+
47
+ download_filenames = ["params.json", "consolidate_params.json", "checklist.chk"]
48
+
49
+ if model == "7b":
50
+ download_filenames += ["consolidated.pth"]
51
+ elif model == "30b":
52
+ download_filenames += [f"consolidated.{i:02}.pth" for i in range(4)]
53
+ else:
54
+ print(f"Unknown model: {model}")
55
+ sys.exit(1)
56
+
57
+ for filename in download_filenames:
58
+ download_file(
59
+ presigned_url.replace("*", f"{model}/{filename}"),
60
+ model_folder / filename,
61
+ )
62
+
63
+ validate_checksum(model_folder)
64
+
65
+
66
+ def main():
67
+ presigned_url = (
68
+ sys.argv[1] if len(sys.argv) > 1 else input("Enter the URL from email: ")
69
+ )
70
+
71
+ target_folder = Path("./data")
72
+ target_folder.mkdir(parents=True, exist_ok=True)
73
+
74
+ download_tokenizer(presigned_url, target_folder)
75
+
76
+ model_size = input(
77
+ "Enter the list of models to download without spaces (7B,30B), or press Enter for all: "
78
+ )
79
+ if not model_size:
80
+ model_size = "7B,30B"
81
+
82
+ for model in model_size.split(","):
83
+ model = model.strip().lower()
84
+ download_model(presigned_url, target_folder, model)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()
chameleon/inference/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
chameleon/inference/alignment.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import torch
9
+
10
+
11
+ class PromptAlignment(ABC):
12
+ @abstractmethod
13
+ def start_index(self, input_ids: list[list[int]]) -> int:
14
+ ...
15
+
16
+ @abstractmethod
17
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
18
+ ...
19
+
20
+ @abstractmethod
21
+ def postprocess_inputs(
22
+ self, inputs: torch.Tensor, original_inputs: torch.Tensor
23
+ ) -> torch.Tensor:
24
+ ...
25
+
26
+
27
+ class AlignPromptRight(PromptAlignment):
28
+ def __init__(self, pad_id: int):
29
+ self.pad_id = pad_id
30
+
31
+ def start_index(self, input_ids: list[list[int]]) -> int:
32
+ return max(len(sublist) for sublist in input_ids)
33
+
34
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
35
+ max_length = max(len(sublist) for sublist in input_ids)
36
+ return torch.tensor(
37
+ [
38
+ ([self.pad_id] * (max_length - len(sublist))) + sublist
39
+ for sublist in input_ids
40
+ ],
41
+ requires_grad=False,
42
+ )
43
+
44
+ def postprocess_inputs(
45
+ self,
46
+ inputs: torch.Tensor,
47
+ original_inputs: torch.Tensor,
48
+ ) -> torch.Tensor:
49
+ return inputs
50
+
51
+
52
+ class AlignPromptLeft(PromptAlignment):
53
+ def __init__(self, pad_id: int = -1):
54
+ self.pad_id = pad_id
55
+
56
+ def start_index(self, input_ids: list[list[int]]) -> int:
57
+ return min(len(sublist) for sublist in input_ids)
58
+
59
+ def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
60
+ max_length = max(len(sublist) for sublist in input_ids)
61
+ return torch.tensor(
62
+ [
63
+ sublist + ([self.pad_id] * (max_length - len(sublist)))
64
+ for sublist in input_ids
65
+ ],
66
+ requires_grad=False,
67
+ )
68
+
69
+ def postprocess_inputs(
70
+ self,
71
+ inputs: torch.Tensor,
72
+ original_inputs: torch.Tensor,
73
+ ) -> torch.Tensor:
74
+ max_init_len = original_inputs.shape[1]
75
+ if inputs.shape[1] <= max_init_len:
76
+ original_inputs_limited = original_inputs[:, : inputs.shape[1]]
77
+ mask = original_inputs_limited != self.pad_id
78
+ inputs[mask] = original_inputs_limited[mask]
79
+ return inputs
chameleon/inference/chameleon.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import base64
7
+ import io
8
+ import json
9
+ import math
10
+ import queue
11
+ import threading
12
+ from dataclasses import dataclass, field
13
+ from tqdm import tqdm
14
+ from enum import Enum
15
+ from multiprocessing import managers, queues, synchronize
16
+ from typing import Literal, Union
17
+
18
+ import PIL
19
+ import torch
20
+ import torch.distributed as dist
21
+ import torch.multiprocessing as mp
22
+ from PIL.Image import Image
23
+ from tokenizers import Tokenizer
24
+ from transformers import (
25
+ LogitsProcessor,
26
+ RepetitionPenaltyLogitsProcessor,
27
+ TemperatureLogitsWarper,
28
+ TopPLogitsWarper,
29
+ enable_full_determinism,
30
+ )
31
+
32
+ from chameleon.inference import loader
33
+ from chameleon.inference.alignment import AlignPromptRight
34
+ from chameleon.inference.generation import ChameleonGenerator
35
+ from chameleon.inference.image_tokenizer import ImageTokenizer
36
+ from chameleon.inference.logits_processor import (
37
+ AllowOnlyTokensLogitsProcessor,
38
+ DisallowTokensAtOrAfterIndexLogitsProcessor,
39
+ InBatchInstructCFGLogitsProcessor,
40
+ )
41
+ from chameleon.inference.model_adapter import ChameleonModelAdapter
42
+ from chameleon.inference.stopping_criteria import (
43
+ MaxLengthCriteria,
44
+ StopOnEOSAfterBatchIndex,
45
+ )
46
+ from chameleon.inference.token_selector import (
47
+ ArgmaxTokenSelector,
48
+ MultinomialTokenSelector,
49
+ ReplicatedInputTokenSelector,
50
+ )
51
+ from chameleon.inference.transformer import Transformer
52
+ from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
53
+ from chameleon.inference.vocab import VocabInfo, VocabTranslation
54
+
55
+
56
+ @dataclass
57
+ class Options:
58
+ @dataclass
59
+ class Text:
60
+ repetition_penalty: float = 1.2
61
+ temp: float = 1.0
62
+ top_p: float = 0.9
63
+ greedy: bool = False
64
+
65
+ @dataclass
66
+ class Image:
67
+ @dataclass
68
+ class CFG:
69
+ guidance_scale_text: float = 3.0
70
+ guidance_scale_image: float = 1.2
71
+
72
+ cfg: CFG = field(default_factory=CFG)
73
+ temp: float = 0.7
74
+ top_p: float = 0.9
75
+ greedy: bool = False
76
+
77
+ max_seq_len: int = 4096
78
+ max_gen_len: int = 4096
79
+ seed: int | None = None
80
+ txt: Text | bool = True
81
+ img: Image | bool = True
82
+ extra_eos_tokens: list[int | str] = field(default_factory=lambda: [])
83
+
84
+ def __post_init__(self):
85
+ if self.txt is True:
86
+ self.txt = Options.Text()
87
+ if self.img is True:
88
+ self.img = Options.Image()
89
+
90
+
91
+ class TokenManager:
92
+ def __init__(
93
+ self,
94
+ tokenizer_path: str,
95
+ vqgan_cfg_path: str,
96
+ vqgan_ckpt_path: str,
97
+ device: str | None = None,
98
+ ):
99
+ self.tokenizer = Tokenizer.from_file(tokenizer_path)
100
+ self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
101
+ self.translation = VocabTranslation(self.vocab, device=device)
102
+ self.image_tokenizer = ImageTokenizer(
103
+ cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
104
+ )
105
+
106
+ def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image:
107
+ image_tensor = self.translation.convert_bpe2img(bpe_tokens)
108
+ if image_tensor.shape[0] < 1024:
109
+ padding = (
110
+ torch.ones(
111
+ [1024 - image_tensor.shape[0]],
112
+ dtype=int,
113
+ device=image_tensor.device,
114
+ )
115
+ * image_tensor[0]
116
+ )
117
+ image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
118
+
119
+ return self.image_tokenizer.pil_from_img_toks(image_tensor)
120
+
121
+ def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
122
+ pil = self.pil_from_bpe_tokens(bpe_tokens)
123
+ img_io = io.BytesIO()
124
+ pil.save(img_io, format="PNG")
125
+ return img_io.getvalue()
126
+
127
+ def tokenize_text(self, text: str) -> list[int]:
128
+ return self.tokenizer.encode(text).ids
129
+
130
+ def tokenize_image(self, img: Image) -> list[int]:
131
+ return (
132
+ [self.vocab.begin_image]
133
+ + self.translation.convert_img2bp2(
134
+ self.image_tokenizer.img_tokens_from_pil(img) # [0 : 8191], vqgan codebook ids
135
+ ).tolist()
136
+ + [self.vocab.end_image]
137
+ )
138
+
139
+ def tokenize_b64img(self, b64img: str) -> list[int]:
140
+ image_data = base64.b64decode(b64img)
141
+ image_file = io.BytesIO(image_data)
142
+ return self.tokenize_image(PIL.Image.open(image_file))
143
+
144
+ def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
145
+ tokens = [self.vocab.bos_id]
146
+ for input_ in inputs:
147
+ if input_["type"] == "text":
148
+ tokens += self.tokenize_text(input_["value"])
149
+ elif input_["type"] == "image":
150
+ if isinstance(input_["value"], str):
151
+ if input_["value"].startswith("data:"):
152
+ # Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
153
+ tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
154
+ elif input_["value"].startswith("file:"):
155
+ tokens += self.tokenize_image(
156
+ PIL.Image.open(input_["value"].split(":", 1)[1])
157
+ )
158
+ else:
159
+ raise ValueError("Unknown image format.")
160
+ elif isinstance(input_["value"], Image):
161
+ tokens += self.tokenize_image(input_["value"])
162
+ else:
163
+ raise ValueError("Unknown image type.")
164
+ elif input_["type"] == "sentinel":
165
+ tokens += [
166
+ {
167
+ "<START-OF-IMAGE>": self.vocab.begin_image,
168
+ "<END-OF-TURN>": self.vocab.eot_id,
169
+ }[input_["value"]]
170
+ ]
171
+ elif input_["type"] == "ids":
172
+ tokens += input_["value"]
173
+ else:
174
+ raise ValueError("Unknown input type.")
175
+ return tokens
176
+
177
+ def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
178
+ if isinstance(ids, torch.Tensor):
179
+ ids = ids.tolist()
180
+
181
+ for row, values in enumerate(ids):
182
+ try:
183
+ ids[row] = values[: values.index(self.vocab.eos_id)]
184
+ except ValueError:
185
+ pass
186
+
187
+ return self.tokenizer.decode_batch(ids)
188
+
189
+ def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
190
+ return [self.pil_from_bpe_tokens(sample) for sample in ids]
191
+
192
+
193
+ @dataclass
194
+ class DecodePiece:
195
+ token: ChameleonGenerator.Token
196
+ next_decoder: type["Decoder"] | None
197
+
198
+
199
+ class Decoder:
200
+ def __init__(
201
+ self,
202
+ model: Transformer,
203
+ vocab: VocabInfo,
204
+ options: Options,
205
+ input_ids: list[int],
206
+ ): ...
207
+
208
+ def __next__(self) -> DecodePiece: ...
209
+
210
+
211
+ class TextDecoder(Decoder):
212
+ def __init__(
213
+ self,
214
+ model: Transformer,
215
+ vocab: VocabInfo,
216
+ options: Options,
217
+ input_ids: list[list[int]],
218
+ ):
219
+ self.vocab = vocab
220
+ self.options = options
221
+ assert vocab.eos_id is not None
222
+
223
+ prompt_lens = [len(inp) for inp in input_ids]
224
+ max_prompt_len = max(prompt_lens)
225
+ max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)
226
+
227
+ self.eos_ids = [vocab.eos_id]
228
+ for extra_eos_token in options.extra_eos_tokens:
229
+ if isinstance(extra_eos_token, str):
230
+ extra_eos_token = vocab.name2val[extra_eos_token]
231
+ assert isinstance(extra_eos_token, int)
232
+ self.eos_ids.append(extra_eos_token)
233
+
234
+ stopping_criteria = [
235
+ MaxLengthCriteria(max_seq_len),
236
+ ] + [StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens)) for eos_id in self.eos_ids]
237
+
238
+ self.gen = ChameleonGenerator(
239
+ model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
240
+ input_ids=input_ids,
241
+ stopping_criteria=stopping_criteria,
242
+ logits_processors=self._logits_processors(),
243
+ alignment=AlignPromptRight(vocab.pad_id),
244
+ token_selector=(
245
+ ArgmaxTokenSelector()
246
+ if options.txt.greedy
247
+ else MultinomialTokenSelector()
248
+ ),
249
+ )
250
+ advance(self.gen, max_prompt_len)
251
+
252
+ def _allowed_tokens(self) -> list[int]:
253
+ allowed_tokens = [self.vocab.eos_id]
254
+ if self.options.txt:
255
+ allowed_tokens += self.vocab.text_tokens
256
+ if self.options.img:
257
+ allowed_tokens += [self.vocab.begin_image]
258
+ return allowed_tokens
259
+
260
+ def _logits_processors(self) -> list[LogitsProcessor]:
261
+ logits_processors = [
262
+ AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
263
+ ]
264
+ if isinstance(self.options.img, Options.Image):
265
+ logits_processors += [
266
+ DisallowTokensAtOrAfterIndexLogitsProcessor(
267
+ [self.vocab.begin_image],
268
+ self.options.max_seq_len - 1026,
269
+ ),
270
+ ]
271
+ if isinstance(self.options.txt, Options.Text):
272
+ logits_processors += [
273
+ RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
274
+ TemperatureLogitsWarper(self.options.txt.temp),
275
+ TopPLogitsWarper(self.options.txt.top_p),
276
+ ]
277
+ return logits_processors
278
+
279
+ def __next__(self) -> DecodePiece:
280
+ tok = next(self.gen)
281
+ next_decoder = None
282
+ if (
283
+ self.vocab.begin_image not in self.eos_ids
284
+ and (tok.id == self.vocab.begin_image).all()
285
+ ):
286
+ next_decoder = ImageDecoder
287
+ return DecodePiece(tok, next_decoder)
288
+
289
+
290
+ class ImageDecoder(Decoder):
291
+ def __init__(
292
+ self,
293
+ model: Transformer,
294
+ vocab: VocabInfo,
295
+ options: Options,
296
+ input_ids: list[list[int]],
297
+ ):
298
+ assert isinstance(options.img, Options.Image)
299
+ self.vocab = vocab
300
+ self.options = options
301
+ self.batch_size = len(input_ids)
302
+ logits_processors = [
303
+ InBatchInstructCFGLogitsProcessor(
304
+ options.img.cfg.guidance_scale_text,
305
+ options.img.cfg.guidance_scale_image,
306
+ ),
307
+ AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
308
+ TemperatureLogitsWarper(options.img.temp),
309
+ TopPLogitsWarper(options.img.top_p),
310
+ ]
311
+
312
+ for inp in input_ids:
313
+ if inp[-1] != self.vocab.begin_image:
314
+ inp.append(self.vocab.begin_image)
315
+
316
+ max_prompt_len = max(len(inp) for inp in input_ids)
317
+ self.gen = ChameleonGenerator(
318
+ model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
319
+ input_ids=self._split_inputs_for_cfg(input_ids),
320
+ logits_processors=logits_processors,
321
+ alignment=AlignPromptRight(vocab.pad_id),
322
+ token_selector=ReplicatedInputTokenSelector(
323
+ (
324
+ ArgmaxTokenSelector()
325
+ if options.img.greedy
326
+ else MultinomialTokenSelector()
327
+ ),
328
+ n=3,
329
+ ),
330
+ )
331
+ advance(self.gen, max_prompt_len)
332
+ self.gen_count = 0
333
+
334
+ def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
335
+ image_conditioned_allowed = set(self.vocab.image_tokens) | {
336
+ self.vocab.bos_id,
337
+ self.vocab.begin_image,
338
+ self.vocab.end_image,
339
+ }
340
+
341
+ full_conditioned = input_ids
342
+
343
+ image_conditioned = [
344
+ [id for id in sample if id in image_conditioned_allowed]
345
+ for sample in input_ids
346
+ ]
347
+
348
+ unconditioned = [
349
+ [
350
+ self.vocab.bos_id,
351
+ self.vocab.begin_image,
352
+ ]
353
+ ] * self.batch_size
354
+
355
+ return full_conditioned + image_conditioned + unconditioned
356
+
357
+ def __next__(self) -> DecodePiece:
358
+ if self.gen_count == 1024:
359
+ id = torch.tensor([self.vocab.end_image] * self.batch_size)
360
+ logits = torch.full(
361
+ (self.batch_size, len(self.vocab.all_tokens)), -math.inf
362
+ )
363
+ logits[:, self.vocab.end_image] = 0
364
+ return DecodePiece(
365
+ ChameleonGenerator.Token(id=id, logits=logits),
366
+ TextDecoder,
367
+ )
368
+
369
+ tok = next(self.gen)
370
+ tok.id = tok.id.chunk(3)[0]
371
+ self.gen_count += 1
372
+ return DecodePiece(tok, None)
373
+
374
+
375
+ class Generator(Decoder):
376
+ def __init__(
377
+ self,
378
+ model: Transformer,
379
+ vocab: VocabInfo,
380
+ options: Options,
381
+ input_ids: list[list[int]],
382
+ ):
383
+ if options.seed is not None:
384
+ enable_full_determinism(options.seed, warn_only=True)
385
+
386
+ self.model = model
387
+ self.vocab = vocab
388
+ self.input_ids = input_ids[:]
389
+ self.generated_token_ids: list[torch.LongTensor] = []
390
+ self.options = options
391
+ if not self.options.txt:
392
+ self.dyngen = DynamicGenerator(
393
+ ImageDecoder(model, vocab, options, input_ids)
394
+ )
395
+ else:
396
+ self.dyngen = DynamicGenerator(
397
+ TextDecoder(model, vocab, options, input_ids)
398
+ )
399
+
400
+ def __iter__(self):
401
+ return self
402
+
403
+ def __next__(self) -> ChameleonGenerator.Token:
404
+ piece = next(self.dyngen)
405
+ self.generated_token_ids.append(piece.token.id)
406
+ if piece.next_decoder is not None:
407
+ if not self.options.txt:
408
+ raise StopIteration
409
+
410
+ self.input_ids = [
411
+ old_list + generated
412
+ for old_list, generated in zip(
413
+ self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
414
+ )
415
+ ]
416
+ self.generated_token_ids = []
417
+ self.dyngen.gen = piece.next_decoder(
418
+ self.model,
419
+ self.vocab,
420
+ self.options,
421
+ self.input_ids,
422
+ )
423
+ return piece.token
424
+
425
+
426
+ class DistributedMode(Enum):
427
+ AUTO = 0
428
+ THREAD = 1
429
+ PROCESS = 2
430
+
431
+
432
+ @dataclass
433
+ class _DistributedContext:
434
+ req_q: Union[queue.Queue, queues.Queue]
435
+ res_q: Union[queue.Queue, queues.Queue]
436
+ active_key: Union[dict[int, Literal[True]], managers.DictProxy]
437
+ active_key_lock: Union[threading.Lock, synchronize.Lock]
438
+ ready_barrier: Union[threading.Barrier, synchronize.Barrier]
439
+ worker_launcher: Union[type[threading.Thread], type[mp.Process]]
440
+
441
+ @staticmethod
442
+ def make_for_threading(world_size: int):
443
+ return _DistributedContext(
444
+ req_q=queue.Queue(),
445
+ res_q=queue.Queue(),
446
+ active_key={},
447
+ active_key_lock=threading.Lock(),
448
+ ready_barrier=threading.Barrier(world_size + 1),
449
+ worker_launcher=threading.Thread,
450
+ )
451
+
452
+ @staticmethod
453
+ def make_for_multiprocessing(world_size: int):
454
+ local_mp = mp.get_context("spawn")
455
+ return _DistributedContext(
456
+ req_q=local_mp.Queue(),
457
+ res_q=local_mp.Queue(),
458
+ active_key=local_mp.Manager().dict(),
459
+ active_key_lock=local_mp.Lock(),
460
+ ready_barrier=local_mp.Barrier(world_size + 1),
461
+ worker_launcher=local_mp.Process,
462
+ )
463
+
464
+ @staticmethod
465
+ def make(mode: DistributedMode, world_size: int):
466
+ if mode == DistributedMode.AUTO:
467
+ mode = DistributedMode.PROCESS
468
+
469
+ if mode == DistributedMode.THREAD:
470
+ return _DistributedContext.make_for_threading(world_size)
471
+ elif mode == DistributedMode.PROCESS:
472
+ return _DistributedContext.make_for_multiprocessing(world_size)
473
+ else:
474
+ raise ValueError("Unknown DistributedMode")
475
+
476
+
477
+ def _worker_impl(
478
+ init_method: str,
479
+ model: Transformer | str,
480
+ world_size: int,
481
+ rank: int,
482
+ vocab: VocabInfo,
483
+ dctx: _DistributedContext,
484
+ ):
485
+ dist.init_process_group(
486
+ "nccl",
487
+ init_method=init_method,
488
+ world_size=world_size,
489
+ rank=rank,
490
+ )
491
+
492
+ torch.set_default_device(f"cuda:{rank}")
493
+ torch.cuda.set_device(rank)
494
+ if isinstance(model, str):
495
+ model = loader.load_model(model, rank=rank)
496
+ dctx.ready_barrier.wait()
497
+
498
+ is_coord = rank == 0
499
+
500
+ while True:
501
+ req = [Options(), [], 0, False]
502
+ if is_coord:
503
+ req = dctx.req_q.get()
504
+
505
+ dist.broadcast_object_list(req, src=0)
506
+ options, input_ids, key, shutdown = req
507
+ if shutdown:
508
+ break
509
+
510
+ for token in Generator(
511
+ model=model,
512
+ vocab=vocab,
513
+ options=options,
514
+ input_ids=input_ids,
515
+ ):
516
+ if is_coord:
517
+ dctx.res_q.put((key, token))
518
+
519
+ to_continue = [True]
520
+ if is_coord:
521
+ with dctx.active_key_lock:
522
+ to_continue = [key in dctx.active_key]
523
+ dist.broadcast_object_list(to_continue, src=0)
524
+ if not to_continue[0]:
525
+ break
526
+
527
+ if is_coord:
528
+ dctx.res_q.put((key, None))
529
+
530
+
531
+ class ChameleonInferenceModel:
532
+ def __init__(
533
+ self,
534
+ model: Transformer | str,
535
+ tokenizer_path: str,
536
+ vqgan_cfg_path: str,
537
+ vqgan_ckpt_path: str,
538
+ *,
539
+ options: Options | None = None,
540
+ distributed_mode: DistributedMode = DistributedMode.AUTO,
541
+ ):
542
+ self.options = options or Options()
543
+ self.next_key = 0
544
+
545
+ self.token_manager = TokenManager(
546
+ tokenizer_path=tokenizer_path,
547
+ vqgan_cfg_path=vqgan_cfg_path,
548
+ vqgan_ckpt_path=vqgan_ckpt_path,
549
+ device="cuda",
550
+ )
551
+ self.vocab = self.token_manager.vocab
552
+
553
+ world_size = 1
554
+ if isinstance(model, str):
555
+ world_size = loader.detect_shard_count(model)
556
+ self.dctx = _DistributedContext.make(distributed_mode, world_size)
557
+
558
+ init_method = f"tcp://0.0.0.0:{random_unused_port()}"
559
+ self.workers = [
560
+ self.dctx.worker_launcher(
561
+ target=_worker_impl,
562
+ args=(init_method, model, world_size, i, self.vocab, self.dctx),
563
+ daemon=True,
564
+ )
565
+ for i in range(world_size)
566
+ ]
567
+ for w in self.workers:
568
+ w.start()
569
+ self.dctx.ready_barrier.wait()
570
+
571
+ def __del__(self):
572
+ try:
573
+ with self.dctx.active_key_lock:
574
+ self.dctx.active_key.clear()
575
+ self.dctx.req_q.put([None, None, None, True])
576
+ for w in self.workers:
577
+ w.join()
578
+ except FileNotFoundError:
579
+ pass
580
+
581
+ def stream(
582
+ self,
583
+ *,
584
+ input_ids: list[int] | None = None,
585
+ prompt_text: str | None = None,
586
+ prompt_ui: list[dict] | None = None,
587
+ batch_input_ids: list[list[int]] | None = None,
588
+ batch_prompt_text: list[str] | None = None,
589
+ batch_prompt_ui: list[list[dict]] | None = None,
590
+ options: Options | None = None,
591
+ ):
592
+ # NOTE: Not thread-safe! Only one instance of generate may be run at a time.
593
+
594
+ if (
595
+ sum(
596
+ x is not None
597
+ for x in [
598
+ input_ids,
599
+ prompt_text,
600
+ prompt_ui,
601
+ batch_input_ids,
602
+ batch_prompt_text,
603
+ batch_prompt_ui,
604
+ ]
605
+ )
606
+ != 1
607
+ ):
608
+ raise ValueError(
609
+ "Must specify exactly one of: input_ids, prompt_text, prompt_ui, batch_input_ids, batch_prompt_text, batch_prompt_ui"
610
+ )
611
+
612
+ options = options or self.options
613
+
614
+ if prompt_text is not None:
615
+ batch_prompt_text = [prompt_text]
616
+ if prompt_ui is not None:
617
+ batch_prompt_ui = [prompt_ui]
618
+ if input_ids is not None:
619
+ batch_input_ids = [input_ids]
620
+ if batch_prompt_text is not None:
621
+ batch_prompt_ui = [
622
+ [{"type": "text", "value": prompt_text}]
623
+ for prompt_text in batch_prompt_text
624
+ ]
625
+ if batch_prompt_ui is not None:
626
+ batch_input_ids = [
627
+ self.token_manager.tokens_from_ui(prompt_ui)
628
+ for prompt_ui in batch_prompt_ui
629
+ ]
630
+
631
+ assert batch_input_ids
632
+
633
+ if not options.txt and not options.img:
634
+ raise ValueError("Must specify at least one modality.")
635
+ if options.txt and options.img and len(batch_input_ids) > 1:
636
+ raise ValueError(
637
+ "Batch generation only supported for one modality at a time."
638
+ )
639
+
640
+ req_key = self.next_key
641
+ self.next_key += 1
642
+
643
+ with self.dctx.active_key_lock:
644
+ self.dctx.active_key[req_key] = True
645
+
646
+ self.dctx.req_q.put([options, batch_input_ids, req_key, False])
647
+
648
+ try:
649
+ while key_token := self.dctx.res_q.get():
650
+ key, token = key_token
651
+ if key != req_key:
652
+ # Residual from prior calls to generation. Skip.
653
+ continue
654
+ if token is None:
655
+ break
656
+ yield token
657
+ finally:
658
+ with self.dctx.active_key_lock:
659
+ del self.dctx.active_key[req_key]
660
+
661
+ def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
662
+ return next(self.stream(*args, **kwargs))
663
+
664
+ def generate(self, *args, **kwargs) -> torch.LongTensor:
665
+ tokens = [t.id for t in self.stream(*args, **kwargs)]
666
+ if not tokens:
667
+ return torch.LongTensor()
668
+ return torch.stack(tokens).T
669
+
670
+ def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
671
+ return self.token_manager.decode_text(ids)
672
+
673
+ def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
674
+ return self.token_manager.decode_image(ids)
675
+
676
+ def sft_tokenization(self, json_path: str) -> list[dict]:
677
+ with open(json_path, 'r') as input_file:
678
+ jsonl_input = [json.loads(line) for line in input_file]
679
+
680
+ output_data = []
681
+ for entry in tqdm(jsonl_input, desc="Tokenize dataset"):
682
+ # print(i)
683
+ text_tokens = self.token_manager.tokenize_text(entry['text'])
684
+ image_tokens = self.token_manager.tokenize_image(PIL.Image.open(entry['image']))
685
+ entry['text_tokens'] = text_tokens
686
+ entry['image_tokens'] = image_tokens
687
+ output_data.append(entry)
688
+
689
+ return output_data
chameleon/inference/cudagraph.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import functools
7
+ from typing import Any, Callable, TypeVar
8
+
9
+ import torch
10
+
11
+ T = TypeVar("T")
12
+ FN = Callable[..., T] # type: ignore
13
+
14
+
15
+ class CUDAGraphWrapper:
16
+ def __init__(
17
+ self,
18
+ fn: FN[T],
19
+ warmup_iter: int = 1,
20
+ debug_dump_path: str | None = None,
21
+ ):
22
+ self.fn = fn
23
+ self.warmup_iter = warmup_iter
24
+ self.debug_dump_path = debug_dump_path
25
+ self.graph: torch.cuda.CUDAGraph | None = None
26
+ self.result: T | None = None
27
+
28
+ def __call__(self, *args, **kwargs) -> Any: # type: ignore
29
+ if self.warmup_iter > 0:
30
+ self.warmup_iter -= 1
31
+ return self.fn(*args, **kwargs)
32
+
33
+ if self.graph is None:
34
+ self.graph = torch.cuda.CUDAGraph()
35
+ if self.debug_dump_path is not None:
36
+ self.graph.enable_debug_mode()
37
+ recording_kwargs = {}
38
+ if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
39
+ # In PyTorch 2.1+ and nightlies from late Aug 2023,
40
+ # we can do this to maybe avoid watchdog-related crashes
41
+ recording_kwargs["capture_error_mode"] = "thread_local"
42
+ with torch.cuda.graph(self.graph, **recording_kwargs):
43
+ self.result = self.fn(*args, **kwargs)
44
+ torch.cuda.synchronize()
45
+ if self.debug_dump_path is not None:
46
+ self.graph.debug_dump(self.debug_dump_path)
47
+
48
+ assert self.graph is not None
49
+ self.graph.replay()
50
+ return self.result
51
+
52
+
53
+ def cudagraph_wrap(
54
+ *args,
55
+ warmup_iter: int = 1,
56
+ debug_dump_path: str | None = None,
57
+ ) -> Callable[[FN[T]], FN[T]]:
58
+ def wrapper(fn: FN[T]) -> FN[T]:
59
+ graph_wrapper = CUDAGraphWrapper(
60
+ fn, warmup_iter=warmup_iter, debug_dump_path=debug_dump_path
61
+ )
62
+
63
+ @functools.wraps(fn)
64
+ def call_wrapper(*inner_args, **inner_kwargs):
65
+ return graph_wrapper(*inner_args, **inner_kwargs)
66
+
67
+ return call_wrapper
68
+
69
+ # @cudagraph_wrap
70
+ # def fn(...):
71
+ # ...
72
+ #
73
+ # - or -
74
+ #
75
+ # fast_fn = cudagraph_wrap(slow_fn, warmup_iter=2)
76
+ if len(args) == 1 and callable(args[0]):
77
+ return wrapper(args[0])
78
+
79
+ # @cudagraph_wrap(warmup_iter=3)
80
+ # def fn(...):
81
+ # ...
82
+ def decorator(fn: FN[T]) -> FN[T]:
83
+ return wrapper(fn)
84
+
85
+ return decorator
chameleon/inference/generation.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ from transformers import (
10
+ LogitsProcessor,
11
+ LogitsProcessorList,
12
+ )
13
+ from transformers.generation.streamers import BaseStreamer
14
+
15
+ from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
16
+ from chameleon.inference.model_adapter import ModelAdapter
17
+ from chameleon.inference.stopping_criteria import StoppingCriteria, StoppingCriteriaList
18
+ from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector
19
+
20
+
21
+ class ChameleonGenerator:
22
+ @dataclass
23
+ class Token:
24
+ id: torch.LongTensor
25
+ logits: torch.Tensor | None
26
+
27
+ def __init__(
28
+ self,
29
+ model: ModelAdapter,
30
+ input_ids: list[list[int]],
31
+ stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
32
+ logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
33
+ probability_processors: LogitsProcessorList
34
+ | list[LogitsProcessor]
35
+ | None = None,
36
+ token_selector: TokenSelector | None = None,
37
+ alignment: PromptAlignment = AlignPromptLeft(),
38
+ ):
39
+ assert model.supports_alignment(alignment)
40
+
41
+ self.model = model
42
+
43
+ self.stopping_criteria = stopping_criteria
44
+ self.logits_processors = logits_processors
45
+ self.probability_processors = probability_processors
46
+ self.token_selector: TokenSelector = (
47
+ token_selector or MultinomialTokenSelector()
48
+ )
49
+
50
+ self.alignment = alignment
51
+
52
+ self.model.initialize(input_ids)
53
+
54
+ self._inputs = self.alignment.prepare_inputs(
55
+ input_ids
56
+ ) # inputs.shape = [batch, seq-len]
57
+
58
+ self._idx = 0
59
+ self._start_idx = self.alignment.start_index(input_ids)
60
+
61
+ self._original_inputs = self._inputs.clone()
62
+ self._inputs = self._inputs[:, : self._start_idx]
63
+
64
+ def __iter__(self):
65
+ return self
66
+
67
+ @torch.inference_mode()
68
+ def __next__(self) -> Token:
69
+ # Are we done?
70
+ if self.stopping_criteria(self._inputs, None):
71
+ raise StopIteration
72
+
73
+ # Emit initial tokens.
74
+ # Model is not run for these.
75
+ # If you want the logits, you can do a separate forward pass outside generation.
76
+ if self._idx < self._start_idx:
77
+ idx, self._idx = self._idx, self._idx + 1
78
+ return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)
79
+
80
+ # Run the model for the next token.
81
+ self._inputs = self._inputs.contiguous()
82
+ outputs = self.model(self._inputs) # outputs.shape = [batch, seq-len, vocab]
83
+
84
+ # Pull out and process the logits.
85
+ logits = outputs[:, -1, :] # logits.shape = [batch, vocab]
86
+ logits = self.logits_processors(self._inputs, logits)
87
+ probs = logits.softmax(dim=1) # probs.shape = [batch, vocab]
88
+ probs = self.probability_processors(self._inputs, probs)
89
+
90
+ # Select a token and add it to the inputs.
91
+ next_tokens = self.token_selector(
92
+ self._inputs, probs
93
+ ) # next_tokens.shape = [batch]
94
+ self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)
95
+
96
+ # Run alignment specific postprocessing.
97
+ self._inputs = self.alignment.postprocess_inputs(
98
+ self._inputs, self._original_inputs
99
+ )
100
+
101
+ # Return the next step result.
102
+ return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)
103
+
104
+ @property
105
+ def stopping_criteria(self) -> StoppingCriteriaList:
106
+ return self._stopping_criteria
107
+
108
+ @stopping_criteria.setter
109
+ def stopping_criteria(
110
+ self, value: StoppingCriteriaList | list[StoppingCriteria] | None
111
+ ):
112
+ self._stopping_criteria = StoppingCriteriaList(value or [])
113
+
114
+ @property
115
+ def logits_processors(self) -> LogitsProcessorList:
116
+ return self._logits_processors
117
+
118
+ @logits_processors.setter
119
+ def logits_processors(
120
+ self, value: LogitsProcessorList | list[LogitsProcessor] | None
121
+ ):
122
+ self._logits_processors = LogitsProcessorList(value or [])
123
+
124
+ @property
125
+ def probability_processors(self) -> LogitsProcessorList:
126
+ return self._probability_processors
127
+
128
+ @probability_processors.setter
129
+ def probability_processors(
130
+ self, value: LogitsProcessorList | list[LogitsProcessor] | None
131
+ ):
132
+ self._probability_processors = LogitsProcessorList(value or [])
133
+
134
+
135
+ def run_generation(
136
+ model: torch.nn.Module,
137
+ input_ids: list[list[int]],
138
+ stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
139
+ logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
140
+ probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
141
+ token_selector: TokenSelector | None = None,
142
+ alignment: PromptAlignment = AlignPromptLeft(),
143
+ streamer: BaseStreamer | None = None,
144
+ ) -> torch.LongTensor:
145
+ result = torch.empty((len(input_ids), 0), dtype=int)
146
+ for tok in ChameleonGenerator(
147
+ model=model,
148
+ input_ids=input_ids,
149
+ stopping_criteria=stopping_criteria,
150
+ logits_processors=logits_processors,
151
+ probability_processors=probability_processors,
152
+ token_selector=token_selector,
153
+ alignment=alignment,
154
+ ):
155
+ if streamer is not None:
156
+ streamer.put(tok.id)
157
+ result = torch.cat([result, tok.id.view(-1, 1)], dim=1)
158
+
159
+ if streamer is not None:
160
+ streamer.end()
161
+
162
+ return result
chameleon/inference/image_tokenizer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import yaml
10
+ from PIL import Image
11
+
12
+ from chameleon.inference.vqgan import VQModel
13
+
14
+
15
+ class ImageTokenizer:
16
+ def __init__(
17
+ self,
18
+ cfg_path: str,
19
+ ckpt_path: str,
20
+ device: str | torch.device | None = None,
21
+ ):
22
+ with open(cfg_path) as f:
23
+ config = yaml.safe_load(f)
24
+
25
+ params = config["model"]["params"]
26
+ if "lossconfig" in params:
27
+ del params["lossconfig"]
28
+ params["ckpt_path"] = ckpt_path
29
+
30
+ self._vq_model = VQModel(**params)
31
+ self._vq_model.eval()
32
+
33
+ if device is None:
34
+ devices = {p.device for p in self._vq_model.parameters()}
35
+ assert len(devices) == 1
36
+ device = devices.pop()
37
+ else:
38
+ self._vq_model.to(device)
39
+ self._device = device
40
+
41
+ dtypes = {p.dtype for p in self._vq_model.parameters()}
42
+ assert len(dtypes) == 1
43
+ self._dtype = dtypes.pop()
44
+
45
+ def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
46
+ # Check if it's already in RGB format.
47
+ if img.mode == "RGB":
48
+ return img
49
+
50
+ vals_rgba = np.array(img.convert("RGBA"))
51
+
52
+ # If there is no transparency layer, simple convert and return.
53
+ if not (vals_rgba[:, :, 3] < 255).any():
54
+ return img.convert("RGB")
55
+
56
+ # There is a transparency layer, blend it with a white background.
57
+
58
+ # Calculate the alpha proportion for blending.
59
+ alpha = vals_rgba[:, :, 3] / 255.0
60
+ # Blend with white background.
61
+ vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
62
+ :, :, np.newaxis
63
+ ] * vals_rgba[:, :, :3]
64
+ return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")
65
+
66
+ def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
67
+ # Resize with aspect ratio preservation.
68
+ s = min(img.size)
69
+ scale = target_image_size / s
70
+ new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
71
+ img = img.resize(new_size, PIL.Image.LANCZOS)
72
+
73
+ # Center crop.
74
+ x0 = (img.width - target_image_size) // 2
75
+ y0 = (img.height - target_image_size) // 2
76
+ img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))
77
+
78
+ # Convert to tensor.
79
+ np_img = np.array(img) / 255.0 # Normalize to [0, 1]
80
+ np_img = np_img * 2 - 1 # Scale to [-1, 1]
81
+ tensor_img = (
82
+ torch.from_numpy(np_img).permute(2, 0, 1).float()
83
+ ) # (Channels, Height, Width) format.
84
+
85
+ # Add batch dimension.
86
+ return tensor_img.unsqueeze(0)
87
+
88
+ def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
89
+ image = self._whiten_transparency(image)
90
+ vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
91
+ _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
92
+ return img_toks
93
+
94
+ def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
95
+
96
+ # Ensure detachment and move tensor to CPU.
97
+ detached_chw_tensor = chw_tensor.detach().cpu()
98
+
99
+ # Normalize tensor to [0, 1] range from [-1, 1] range.
100
+ normalized_chw_tensor = (
101
+ torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
102
+ ) / 2.0
103
+
104
+ # Permute CHW tensor to HWC format and convert to NumPy array.
105
+ hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()
106
+
107
+ # Convert to an 8-bit unsigned integer format.
108
+ image_array_uint8 = (hwc_array * 255).astype(np.uint8)
109
+
110
+ # Convert NumPy array to PIL Image.
111
+ pil_image = Image.fromarray(image_array_uint8)
112
+
113
+ # Convert image to RGB if it is not already.
114
+ if pil_image.mode != "RGB":
115
+ pil_image = pil_image.convert("RGB")
116
+
117
+ return pil_image
118
+
119
+ def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image:
120
+ emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
121
+ codebook_entry = self._vq_model.quantize.get_codebook_entry(
122
+ img_tensor, (1, 32, 32, emb_dim)
123
+ )
124
+ pixels = self._vq_model.decode(codebook_entry)
125
+ return self._pil_from_chw_tensor(pixels[0])
chameleon/inference/loader.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import glob
7
+ import inspect
8
+ import json
9
+ from pathlib import Path
10
+
11
+ import torch
12
+
13
+ from chameleon.inference.transformer import ModelArgs, Transformer
14
+
15
+
16
+ def _convert(model_args: ModelArgs, consolidated_path: Path) -> Transformer:
17
+ old_default_dtype = torch.get_default_dtype()
18
+ torch.set_default_dtype(torch.bfloat16)
19
+
20
+ model = Transformer(model_args)
21
+
22
+ transfer_results = model.load_state_dict(
23
+ torch.load(str(consolidated_path), map_location='cuda'),
24
+ strict=False,
25
+ )
26
+
27
+ # TODO: More generally, assert missing or unexpected keys are buffers.
28
+ assert transfer_results.missing_keys == []
29
+ assert transfer_results.unexpected_keys == ["rope.freqs"]
30
+
31
+ model.eval()
32
+
33
+ torch.set_default_dtype(old_default_dtype)
34
+ return model
35
+
36
+
37
+ def _get_checkpoint_path(src_dir: Path, rank: int | None) -> Path:
38
+ base_path = src_dir / "consolidated.pth"
39
+ if not rank and base_path.exists():
40
+ return base_path
41
+
42
+ alt_path = src_dir / f"consolidated.{rank:02}.pth"
43
+ if alt_path.exists():
44
+ return alt_path
45
+
46
+ raise ValueError("Consolidated checkpoint not found.")
47
+
48
+
49
+ def load_model(path: str, rank: int | None = None) -> Transformer:
50
+ src_dir = Path(path)
51
+
52
+ with open(src_dir / "params.json", "r") as f:
53
+ params = json.loads(f.read())
54
+ with open(src_dir / "consolidate_params.json", "r") as f:
55
+ consolidate_params = json.loads(f.read())
56
+ params = {**params, **params["model"], **consolidate_params}
57
+
58
+ known_param = inspect.signature(ModelArgs.__init__).parameters
59
+ filtered_params = {k: v for k, v in params.items() if k in known_param}
60
+
61
+ return _convert(
62
+ ModelArgs(**filtered_params),
63
+ _get_checkpoint_path(src_dir, rank),
64
+ )
65
+
66
+
67
+ def detect_shard_count(path: str) -> int:
68
+ src_dir = Path(path)
69
+ if (src_dir / "consolidated.pth").exists():
70
+ return 1
71
+ return len(glob.glob(str(src_dir / "consolidated.*.pth")))
chameleon/inference/logits_processor.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ from transformers import LogitsProcessor
10
+
11
+
12
+ class TopPProbabilityProcessor(LogitsProcessor):
13
+ # Modified version of TopPLogitsWarper to act on probabilities.
14
+ # Changes:
15
+ # * filter_value changed from -inf to 0
16
+ # * removed softmax
17
+ # * renormalize L1
18
+
19
+ def __init__(
20
+ self,
21
+ top_p: float,
22
+ min_tokens_to_keep: int = 1,
23
+ ):
24
+ top_p = float(top_p)
25
+ if top_p < 0 or top_p > 1.0:
26
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
27
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
28
+ raise ValueError(
29
+ f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
30
+ )
31
+
32
+ self.top_p = top_p
33
+ self.min_tokens_to_keep = min_tokens_to_keep
34
+
35
+ def __call__(
36
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
37
+ ) -> torch.FloatTensor:
38
+ # input_ids.shape=[batch, seq-len]
39
+ # probs.shape=[batch, vocab]
40
+ sorted_probs, sorted_indices = torch.sort(probs, descending=False)
41
+ cumulative_probs = sorted_probs.cumsum(dim=-1)
42
+
43
+ # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
44
+ sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
45
+ # Keep at least min_tokens_to_keep
46
+ sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0
47
+
48
+ # scatter sorted tensors to original indexing
49
+ indices_to_remove = sorted_indices_to_remove.scatter(
50
+ 1, sorted_indices, sorted_indices_to_remove
51
+ )
52
+ probs = probs.masked_fill(indices_to_remove, 0.0)
53
+ probs = probs / probs.sum(dim=-1, keepdim=True)
54
+ return probs
55
+
56
+
57
+ class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
58
+ def __init__(
59
+ self, token_ids: list[int], start_index: int, end_index: int | None = None
60
+ ):
61
+ self.token_ids = torch.tensor(token_ids)
62
+ self.start_index = start_index
63
+ self.end_index = end_index if end_index is not None else math.inf
64
+
65
+ def __call__(
66
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
67
+ ) -> torch.FloatTensor:
68
+ current_index = input_ids.shape[1]
69
+ if self.start_index <= current_index < self.end_index:
70
+ logits[:, self.token_ids] = -math.inf
71
+ return logits
72
+
73
+
74
+ class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
75
+ def __init__(self, token_ids: list[int]):
76
+ super().__init__(token_ids, 0)
77
+
78
+
79
+ class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
80
+ def __init__(self, token_ids: list[int], index: int):
81
+ super().__init__(token_ids, index, index + 1)
82
+
83
+
84
+ class DisallowTokensAfterIndexLogitsProcessor(
85
+ DisallowTokensInIndexRangeLogitsProcessor
86
+ ):
87
+ def __init__(self, token_ids: list[int], index: int):
88
+ super().__init__(token_ids, index + 1)
89
+
90
+
91
+ class DisallowTokensAtOrAfterIndexLogitsProcessor(
92
+ DisallowTokensInIndexRangeLogitsProcessor
93
+ ):
94
+ def __init__(self, token_ids: list[int], index: int):
95
+ super().__init__(token_ids, index)
96
+
97
+
98
+ class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
99
+ def __init__(
100
+ self,
101
+ token_ids: list[int],
102
+ start_indices: list[int],
103
+ end_indices: list[int] | None = None,
104
+ ):
105
+ self.token_ids = torch.tensor(token_ids)
106
+ self.start_indices = torch.tensor(start_indices)
107
+ self.end_indices = (
108
+ torch.tensor(end_indices)
109
+ if end_indices is not None
110
+ else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
111
+ )
112
+
113
+ def __call__(
114
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
115
+ ) -> torch.FloatTensor:
116
+ # input_ids.shape = [batch, seq_len]
117
+ # logits.shape = [batch, vocab]
118
+ current_index = input_ids.shape[1]
119
+ mask = (self.start_indices <= current_index) & (
120
+ current_index < self.end_indices
121
+ )
122
+ # The following will fail if the mask is all False.
123
+ # logits[mask, self.token_ids] = -math.inf
124
+ logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
125
+ return logits
126
+
127
+
128
+ class DisallowTokensAtBatchIndexLogitsProcessor(
129
+ DisallowTokensInBatchIndexRangeLogitsProcessor
130
+ ):
131
+ def __init__(self, token_ids: list[int], batch_index: list[int]):
132
+ super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])
133
+
134
+
135
+ class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
136
+ def __init__(
137
+ self, token_ids: list[int], start_index: int, end_index: int | None = None
138
+ ):
139
+ self.token_ids = torch.tensor(token_ids)
140
+ self.start_index = start_index
141
+ self.end_index = end_index if end_index is not None else math.inf
142
+
143
+ def __call__(
144
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
145
+ ) -> torch.FloatTensor:
146
+ current_index = input_ids.shape[1]
147
+ if self.start_index <= current_index < self.end_index:
148
+ replacement = torch.full_like(logits, -math.inf)
149
+ replacement[:, self.token_ids] = logits[:, self.token_ids]
150
+ logits[:] = replacement
151
+ return logits
152
+
153
+
154
+ class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
155
+ def __init__(self, token_ids: list[int]):
156
+ super().__init__(token_ids, 0)
157
+
158
+
159
+ class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
160
+ def __init__(self, token_ids: list[int], index: int):
161
+ super().__init__(token_ids, index, index + 1)
162
+
163
+
164
+ class AllowOnlyTokensAfterIndexLogitsProcessor(
165
+ AllowOnlyTokensInIndexRangeLogitsProcessor
166
+ ):
167
+ def __init__(self, token_ids: list[int], index: int):
168
+ super().__init__(token_ids, index + 1)
169
+
170
+
171
+ class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
172
+ AllowOnlyTokensInIndexRangeLogitsProcessor
173
+ ):
174
+ def __init__(self, token_ids: list[int], index: int):
175
+ super().__init__(token_ids, index)
176
+
177
+
178
+ class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
179
+ def __init__(
180
+ self,
181
+ token_ids: list[int],
182
+ start_indices: list[int],
183
+ end_indices: list[int] | None = None,
184
+ ):
185
+ self.token_ids = torch.tensor(token_ids)
186
+ self.start_indices = torch.tensor(start_indices)
187
+ self.end_indices = (
188
+ torch.tensor(end_indices)
189
+ if end_indices is not None
190
+ else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
191
+ )
192
+
193
+ def __call__(
194
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
195
+ ) -> torch.FloatTensor:
196
+ # input_ids.shape = [batch, seq_len]
197
+ # logits.shape = [batch, vocab]
198
+ current_index = input_ids.shape[1]
199
+ mask = (self.start_indices <= current_index) & (
200
+ current_index < self.end_indices
201
+ )
202
+
203
+ valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
204
+ full_mask = torch.full_like(logits, -math.inf)
205
+ full_mask[valid_batch_indices, self.token_ids] = logits[
206
+ valid_batch_indices, self.token_ids
207
+ ]
208
+
209
+ logits[:] = torch.where(full_mask != -math.inf, full_mask, logits)
210
+ return logits
211
+
212
+
213
+ class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
214
+ def __init__(
215
+ self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
216
+ ):
217
+ self.trigger_token_id = trigger_token_id
218
+ self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
219
+ self.offset = offset
220
+
221
+ def __call__(
222
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
223
+ ) -> torch.FloatTensor:
224
+ # input_ids.shape=[batch, seq_len]
225
+ # logits.shape=[batch, vocab]
226
+ if input_ids.shape[1] < self.offset:
227
+ return logits
228
+
229
+ trigger_positions = (
230
+ input_ids[:, -self.offset] == self.trigger_token_id
231
+ ).unsqueeze(-1)
232
+
233
+ disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
234
+ disallowed_tokens_mask[:, self.subsequent_token_ids] = False
235
+
236
+ return logits.masked_fill_(
237
+ disallowed_tokens_mask & trigger_positions,
238
+ -math.inf,
239
+ )
240
+
241
+
242
+ class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
243
+ def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
244
+ self.trigger_token_id = trigger_token_id
245
+ self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
246
+ 0
247
+ ) # shape: [1, num_allowed_tokens]
248
+ self.width = width
249
+
250
+ def __call__(
251
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
252
+ ) -> torch.FloatTensor:
253
+ # input_ids.shape=[batch, seq_len]
254
+ # logits.shape=[batch, vocab]
255
+ width = min(self.width, input_ids.shape[1])
256
+ trigger_positions = (
257
+ (input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
258
+ )
259
+
260
+ disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
261
+ disallowed_tokens_mask[:, self.allowed_token_ids] = False
262
+
263
+ return logits.masked_fill_(
264
+ disallowed_tokens_mask & trigger_positions,
265
+ -math.inf,
266
+ )
267
+
268
+
269
+ class CFGLogitsProcessor(LogitsProcessor):
270
+ def __init__(
271
+ self,
272
+ guidance_scale: float,
273
+ unconditional_ids: torch.LongTensor,
274
+ model,
275
+ ):
276
+ self.guidance_scale = guidance_scale
277
+ self.unconditional_ids = unconditional_ids
278
+ self.model = model
279
+
280
+ def __call__(
281
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
282
+ ) -> torch.FloatTensor:
283
+ conditioned_logits = logits
284
+
285
+ self.unconditional_ids = torch.cat(
286
+ [self.unconditional_ids, input_ids[:, -1:]], dim=1
287
+ )
288
+ unconditioned_outputs = self.model(self.unconditional_ids)
289
+ unconditioned_logits = unconditioned_outputs[:, -1, :]
290
+ return (
291
+ self.guidance_scale * (conditioned_logits - unconditioned_logits)
292
+ + unconditioned_logits
293
+ )
294
+
295
+
296
+ class InBatchCFGLogitsProcessor(LogitsProcessor):
297
+ def __init__(self, guidance_scale: float):
298
+ self.guidance_scale = guidance_scale
299
+
300
+ def __call__(
301
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
302
+ ) -> torch.FloatTensor:
303
+ # input_ids.shape=[2*batch, seq-len]
304
+ # logits.shape=[2*batch, vocab]
305
+ conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
306
+ mixed_logits = unconditioned_logits + self.guidance_scale * (
307
+ conditioned_logits - unconditioned_logits
308
+ )
309
+ return mixed_logits.repeat(2, 1)
310
+
311
+
312
+ class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
313
+ # See https://arxiv.org/abs/2211.09800
314
+
315
+ def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
316
+ self.guidance_scale_text = guidance_scale_text
317
+ self.guidance_scale_image = guidance_scale_image
318
+
319
+ def __call__(
320
+ self, input_ids: torch.LongTensor, logits: torch.FloatTensor
321
+ ) -> torch.FloatTensor:
322
+ # input_ids.shape=[3*batch, seq-len]
323
+ # logits.shape=[3*batch, vocab]
324
+ (
325
+ full_conditioned_logits,
326
+ image_conditioned_logits,
327
+ unconditioned_logits,
328
+ ) = logits.chunk(3)
329
+ mixed_logits = (
330
+ unconditioned_logits
331
+ + self.guidance_scale_image
332
+ * (image_conditioned_logits - unconditioned_logits)
333
+ + self.guidance_scale_text
334
+ * (full_conditioned_logits - image_conditioned_logits)
335
+ )
336
+ return mixed_logits.repeat(3, 1)
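The nucleus (top-p) filtering at the top of this file sorts probabilities in ascending order, zeroes out the low-probability prefix whose cumulative mass stays within 1 - top_p, and renormalizes. A minimal self-contained sketch of that idea, not part of the commit (the function name top_p_filter and the toy distribution are illustrative only):

import torch

def top_p_filter(probs: torch.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
    # Sort ascending so the least likely tokens come first.
    sorted_probs, sorted_indices = torch.sort(probs, descending=False)
    cumulative_probs = sorted_probs.cumsum(dim=-1)
    # The ascending prefix with mass <= (1 - top_p) can be dropped; the remainder keeps >= top_p mass.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    # Map the removal mask back to the original vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    probs = probs.masked_fill(indices_to_remove, 0.0)
    return probs / probs.sum(dim=-1, keepdim=True)

probs = torch.tensor([[0.05, 0.10, 0.15, 0.30, 0.40]])
print(top_p_filter(probs, top_p=0.8))  # only the most likely tokens covering >= 0.8 of the mass survive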
chameleon/inference/model_adapter.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from abc import ABC, abstractmethod
8
+
9
+ import torch
10
+
11
+ from chameleon.inference import transformer
12
+ from chameleon.inference.alignment import (
13
+ AlignPromptLeft,
14
+ AlignPromptRight,
15
+ PromptAlignment,
16
+ )
17
+ from chameleon.inference.cudagraph import cudagraph_wrap
18
+
19
+
20
+ class ModelAdapter(ABC):
21
+ @abstractmethod
22
+ def initialize(self, prompt_tokens: list[list[int]]):
23
+ ...
24
+
25
+ @abstractmethod
26
+ def supports_alignment(self, alignment: PromptAlignment) -> bool:
27
+ ...
28
+
29
+ @abstractmethod
30
+ @torch.inference_mode()
31
+ def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
32
+ ...
33
+
34
+
35
+ class ChameleonModelAdapter(ModelAdapter):
36
+ """Adapter for Chameleon-style model that handles state, such as cache."""
37
+
38
+ def __init__(
39
+ self,
40
+ model: transformer.Transformer,
41
+ max_seq_len: int,
42
+ dtype: torch.dtype | None = None,
43
+ ):
44
+ super().__init__()
45
+ self._args = model.args
46
+ self._model = model
47
+ self._max_seq_len = max_seq_len
48
+ self._dtype = dtype or next(model.parameters()).data.dtype
49
+
50
+ def initialize(self, prompt_tokens: list[list[int]]):
51
+ self._prompt_lengths = [len(toks) for toks in prompt_tokens]
52
+ batch_size = len(prompt_tokens)
53
+
54
+ self._cache = transformer.make_cache(
55
+ args=self._args,
56
+ length=batch_size * self._max_seq_len,
57
+ dtype=self._dtype,
58
+ )
59
+
60
+ self._local_inputs = torch.zeros([batch_size], dtype=int, device="cuda")
61
+
62
+ self._forward = cudagraph_wrap(self._model.forward_with_attn_bias)
63
+
64
+ self._first_pass = True
65
+
66
+ def supports_alignment(self, alignment: PromptAlignment) -> bool:
67
+ return isinstance(alignment, AlignPromptLeft) or isinstance(
68
+ alignment, AlignPromptRight
69
+ )
70
+
71
+ def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
72
+ # inputs.shape=[batch, seq-len]
73
+ batch_size, seq_len = inputs.shape
74
+
75
+ if self._first_pass:
76
+ attn_seqlen = [min(pl, seq_len) for pl in self._prompt_lengths]
77
+ self._bias = transformer.AttnBias.from_seqlens(
78
+ q_seqlen=attn_seqlen,
79
+ kv_seqlen=attn_seqlen,
80
+ kv_padding=self._max_seq_len,
81
+ )
82
+
83
+ mask = torch.zeros_like(inputs, dtype=torch.bool)
84
+ for i, k in enumerate(self._prompt_lengths):
85
+ mask[i, -k:] = True
86
+
87
+ flat_outputs: torch.Tensor = self._forward( # type: ignore
88
+ token_values=inputs[mask],
89
+ attn_bias=self._bias,
90
+ cache=self._cache,
91
+ )
92
+ self._local_outputs = torch.full(
93
+ (inputs.shape[0], inputs.shape[1], flat_outputs.shape[-1]),
94
+ -math.inf,
95
+ )
96
+ self._local_outputs[mask] = flat_outputs
97
+
98
+ self._vocab_size = self._local_outputs.shape[-1]
99
+
100
+ self._bias.q_seqinfo.seqstart.copy_(
101
+ torch.arange(batch_size + 1, dtype=torch.int)
102
+ )
103
+ self._bias.q_seqinfo.max_seqlen = 1
104
+ self._bias.q_seqinfo.seqstart_py = self._bias.q_seqinfo.seqstart.tolist()
105
+
106
+ self._first_pass = False
107
+
108
+ else:
109
+ self._local_inputs.copy_(inputs[:, -1]) # type: ignore
110
+
111
+ self._local_outputs = self._forward( # type: ignore
112
+ token_values=self._local_inputs,
113
+ attn_bias=self._bias,
114
+ cache=self._cache,
115
+ )
116
+
117
+ self._bias.k_seqinfo.seqlen.add_(1)
118
+ return self._local_outputs.view(batch_size, -1, self._vocab_size)
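As a rough illustration of the ModelAdapter contract above, here is a hedged sketch (not part of the commit) of a trivial, cache-free adapter that wraps any callable mapping [batch, seq_len] token ids to [batch, seq_len, vocab] logits; the name CallableModelAdapter and the random stand-in model are assumptions for the example only, and the classes above are assumed to be in scope:

import torch

class CallableModelAdapter(ModelAdapter):
    def __init__(self, fn):
        self.fn = fn  # any callable producing [batch, seq_len, vocab] logits

    def initialize(self, prompt_tokens: list[list[int]]):
        pass  # stateless: no cache or attention bias to set up

    def supports_alignment(self, alignment: PromptAlignment) -> bool:
        return True

    @torch.inference_mode()
    def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
        return self.fn(inputs)

vocab_size = 16
adapter = CallableModelAdapter(lambda ids: torch.randn(ids.shape[0], ids.shape[1], vocab_size))
adapter.initialize([[0, 1, 2]])
print(adapter(torch.zeros(2, 3, dtype=torch.long)).shape)  # torch.Size([2, 3, 16])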
chameleon/inference/stopping_criteria.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ class StoppingCriteria:
10
+ def __call__(
11
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
12
+ ) -> bool:
13
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
14
+
15
+
16
+ class StoppingCriteriaList(list):
17
+ def __call__(
18
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
19
+ ) -> bool:
20
+ return any(criteria(input_ids, scores, **kwargs) for criteria in self)
21
+
22
+
23
+ class MaxLengthCriteria(StoppingCriteria):
24
+ def __init__(self, max_length: int):
25
+ self.max_length = max_length
26
+
27
+ def __call__(
28
+ self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
29
+ ) -> bool:
30
+ cur_len = input_ids.shape[-1]
31
+ return cur_len >= self.max_length
32
+
33
+
34
+ class StopOnEOS(StoppingCriteria):
35
+ def __init__(self, eos_id: int):
36
+ self._eos_id = eos_id
37
+
38
+ def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
39
+ # input_ids.shape=[batch, seq_len]
40
+ return (input_ids == self._eos_id).sum(dim=1).all()
41
+
42
+
43
+ class StopOnEOSAfterBatchIndex(StoppingCriteria):
44
+ def __init__(self, eos_id: int, batch_index: list[int]):
45
+ self._eos_id = eos_id
46
+ self.batch_index = torch.tensor(batch_index, dtype=torch.long).unsqueeze(1)
47
+
48
+ def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
49
+ # input_ids.shape=[batch, seq_len]
50
+ eos_mask = input_ids == self._eos_id
51
+ consider_eos_mask = (
52
+ torch.arange(input_ids.shape[1]).unsqueeze(0) >= self.batch_index
53
+ )
54
+ valid_eos = eos_mask & consider_eos_mask
55
+ return valid_eos.sum(dim=1).all()
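A small usage sketch for the criteria above, not part of the commit (eos_id=2 and the toy token ids are placeholders; the real EOS id comes from the tokenizer vocabulary):

import torch

criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=8), StopOnEOS(eos_id=2)])
input_ids = torch.tensor([[5, 7, 2], [4, 9, 6]])  # only the first sequence has produced EOS
print(bool(criteria(input_ids, scores=None)))     # False: not every sequence is finished yet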
chameleon/inference/token_selector.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ class TokenSelector:
10
+ def __call__(
11
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
12
+ ) -> torch.FloatTensor:
13
+ # input_ids.shape=[batch, seq_len]
14
+ # probs.shape=[batch, vocab]
15
+ ...
16
+
17
+
18
+ class ArgmaxTokenSelector(TokenSelector):
19
+ def __call__(
20
+ self, _: torch.LongTensor, probs: torch.FloatTensor
21
+ ) -> torch.LongTensor:
22
+ # probs.shape=[batch, vocab]
23
+ return probs.argmax(dim=1)
24
+
25
+
26
+ class MultinomialTokenSelector(TokenSelector):
27
+ def __call__(
28
+ self, _: torch.LongTensor, probs: torch.FloatTensor
29
+ ) -> torch.LongTensor:
30
+ # probs.shape=[batch, vocab]
31
+ return probs.multinomial(num_samples=1).squeeze(1)
32
+
33
+
34
+ class ReplicatedInputTokenSelector(TokenSelector):
35
+ def __init__(self, token_selector: TokenSelector, n: int):
36
+ self.token_selector = token_selector
37
+ self.n = n
38
+
39
+ def __call__(
40
+ self, input_ids: torch.LongTensor, probs: torch.FloatTensor
41
+ ) -> torch.LongTensor:
42
+ # input_ids.shape=[n*batch, seq_len]
43
+ # probs.shape=[n*batch, vocab]
44
+ primary_input_ids = torch.chunk(input_ids, chunks=self.n, dim=0)[0]
45
+ primary_probs = torch.chunk(probs, chunks=self.n, dim=0)[0]
46
+ tokens = self.token_selector(primary_input_ids, primary_probs)
47
+ return tokens.repeat(self.n)
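A toy illustration of the two basic selectors above, not part of the commit (the probabilities are made up and the input ids are dummies, since both selectors ignore them):

import torch

probs = torch.tensor([[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]])
dummy_ids = torch.zeros(2, 1, dtype=torch.long)

print(ArgmaxTokenSelector()(dummy_ids, probs))       # tensor([2, 0])
torch.manual_seed(0)
print(MultinomialTokenSelector()(dummy_ids, probs))  # one random draw per row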
chameleon/inference/transformer.py ADDED
@@ -0,0 +1,421 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ from torch import distributed as dist
10
+ from torch import nn
11
+ from torch.nn import functional as F
12
+ from xformers.ops import RMSNorm, fmha, rope_padded
13
+ from xformers.ops.fmha.attn_bias import (
14
+ BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
15
+ )
16
+
17
+
18
+ @dataclass
19
+ class ModelArgs:
20
+ model_parallel_size: int = 1
21
+ dim: int = 512
22
+ n_layers: int = 8
23
+ n_heads: int = 8
24
+ n_kv_heads: int | None = None
25
+ vocab_size: int = -1
26
+ ffn_dim_multiplier: float | None = None
27
+ multiple_of: int = 256
28
+ norm_eps: float = 1e-5
29
+ rope_theta: float = 10000.0
30
+ qk_normalization: bool = False
31
+ swin_norm: bool = False
32
+
33
+
34
+ LayerCache = tuple[torch.Tensor, torch.Tensor]
35
+
36
+
37
+ class Attention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ model_parallel_size: int,
41
+ dim: int,
42
+ head_dim: int,
43
+ n_heads: int,
44
+ n_kv_heads: int,
45
+ rope_theta: float,
46
+ qk_normalization: bool = False,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.model_parallel_size = model_parallel_size
51
+
52
+ self.head_dim = head_dim
53
+ self.rope_theta = rope_theta
54
+
55
+ self.n_local_heads = n_heads // model_parallel_size
56
+ self.n_local_kv_heads = n_kv_heads // model_parallel_size
57
+
58
+ self.wqkv = nn.Linear(
59
+ dim,
60
+ (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
61
+ bias=False,
62
+ dtype=torch.bfloat16,
63
+ )
64
+ self.wo = nn.Linear(
65
+ self.n_local_heads * head_dim,
66
+ dim,
67
+ bias=False,
68
+ dtype=torch.bfloat16,
69
+ )
70
+
71
+ self.qk_normalization = qk_normalization
72
+ if qk_normalization:
73
+ self.q_normalization = torch.nn.LayerNorm(head_dim)
74
+ self.k_normalization = torch.nn.LayerNorm(head_dim)
75
+
76
+ self._register_load_state_dict_pre_hook(self.load_hook)
77
+
78
+ # This adapter makes sure we can load vanilla
79
+ # Llama checkpoints where wq, wk, and wv are
80
+ # not fused in a single parameter
81
+ def load_hook(
82
+ self,
83
+ state_dict,
84
+ prefix,
85
+ local_metadata,
86
+ strict,
87
+ missing_keys,
88
+ unexpected_keys,
89
+ error_msgs,
90
+ ):
91
+ if prefix + "wq.weight" in state_dict:
92
+ wq = state_dict.pop(prefix + "wq.weight")
93
+ wk = state_dict.pop(prefix + "wk.weight")
94
+ wv = state_dict.pop(prefix + "wv.weight")
95
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
96
+
97
+ def forward(
98
+ self,
99
+ x: torch.Tensor,
100
+ cache: LayerCache,
101
+ attn_bias: AttnBias,
102
+ group: dist.ProcessGroup | None = None,
103
+ ) -> torch.Tensor:
104
+ # x.shape is (sum(seq_lens), dim)
105
+ #
106
+ # Since we support heterogenous sequence
107
+ # lengths, the hidden states are all
108
+ # concatenated together along the usual
109
+ # sequence dimension. The attention below
110
+ # finds out where sequences start & end
111
+ # using the provided attention bias.
112
+ xqkv = self.wqkv(x)
113
+ xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
114
+ xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
115
+ xk, xv = xkv.chunk(2, 1)
116
+
117
+ if self.qk_normalization:
118
+ xq = xq.view(-1, self.n_local_heads, self.head_dim)
119
+ xq = self.q_normalization(xq)
120
+ xq = xq.view(-1, self.n_local_heads * self.head_dim)
121
+
122
+ xk = xk.view(-1, self.n_local_kv_heads, self.head_dim)
123
+ xk = self.k_normalization(xk)
124
+ xk = xk.view(-1, self.n_local_kv_heads * self.head_dim)
125
+
126
+ output_shape = xq.shape
127
+ xq = xq.view(1, xq.shape[0], self.n_local_heads, self.head_dim)
128
+ xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, self.head_dim)
129
+ xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, self.head_dim)
130
+ cache_k, cache_v = cache
131
+
132
+ xq = rope_padded(
133
+ xq=xq,
134
+ xk=xk,
135
+ xv=xv,
136
+ cache_k=cache_k,
137
+ cache_v=cache_v,
138
+ attn_bias=attn_bias,
139
+ theta=self.rope_theta,
140
+ )
141
+
142
+ # Handle GQA
143
+ # Q shape: [B, M, Hkv, Hq // Hkv, K]
144
+ heads_per_group = self.n_local_heads // self.n_local_kv_heads
145
+ cache_k = cache_k.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
146
+ cache_v = cache_v.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
147
+ xq = xq.reshape(
148
+ [*xq.shape[:2], self.n_local_kv_heads, heads_per_group, xq.shape[-1]]
149
+ )
150
+
151
+ # rope_padded() updated the caches, so we
152
+ # call attention directly
153
+ output = fmha.memory_efficient_attention_forward(
154
+ xq, cache_k, cache_v, attn_bias
155
+ )
156
+
157
+ output = self.wo(output.reshape(output_shape))
158
+ if self.model_parallel_size > 1:
159
+ dist.all_reduce(output, group=group)
160
+
161
+ return output
162
+
163
+
164
+ class FeedForward(nn.Module):
165
+ def __init__(
166
+ self,
167
+ model_parallel_size: int,
168
+ dim: int,
169
+ hidden_dim: int,
170
+ multiple_of: int,
171
+ ffn_dim_multiplier: float | None,
172
+ ):
173
+ super().__init__()
174
+
175
+ self.model_parallel_size = model_parallel_size
176
+
177
+ hidden_dim = int(2 * hidden_dim / 3)
178
+ if ffn_dim_multiplier is not None:
179
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
180
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
181
+ assert hidden_dim % model_parallel_size == 0
182
+
183
+ self.w13 = nn.Linear(
184
+ dim,
185
+ 2 * hidden_dim // model_parallel_size,
186
+ bias=False,
187
+ )
188
+ self.w2 = nn.Linear(
189
+ hidden_dim // model_parallel_size,
190
+ dim,
191
+ bias=False,
192
+ )
193
+ self._register_load_state_dict_pre_hook(self.load_hook)
194
+
195
+ # This adapter makes sure we can load vanilla
196
+ # Llama checkpoints where w1 and w3 are not
197
+ # fused in a single parameter
198
+ def load_hook(
199
+ self,
200
+ state_dict,
201
+ prefix,
202
+ local_metadata,
203
+ strict,
204
+ missing_keys,
205
+ unexpected_keys,
206
+ error_msgs,
207
+ ):
208
+ if prefix + "w1.weight" in state_dict:
209
+ w1 = state_dict.pop(prefix + "w1.weight")
210
+ w3 = state_dict.pop(prefix + "w3.weight")
211
+ state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])
212
+
213
+ def forward(
214
+ self, x: torch.Tensor, group: dist.ProcessGroup | None = None
215
+ ) -> torch.Tensor:
216
+ x13 = self.w13(x)
217
+ x1, x3 = x13.chunk(2, -1)
218
+ output = self.w2(F.silu(x1) * x3)
219
+ if self.model_parallel_size > 1:
220
+ dist.all_reduce(output, group=group)
221
+ return output
222
+
223
+
224
+ class TransformerBlock(nn.Module):
225
+ def __init__(self, args: ModelArgs):
226
+ super().__init__()
227
+
228
+ assert args.dim % args.n_heads == 0
229
+ head_dim = args.dim // args.n_heads
230
+ if args.n_kv_heads is not None:
231
+ n_kv_heads = args.n_kv_heads
232
+ else:
233
+ n_kv_heads = args.n_heads
234
+
235
+ model_parallel_size = args.model_parallel_size
236
+ assert args.n_heads % n_kv_heads == 0
237
+ assert args.n_heads % model_parallel_size == 0
238
+ assert n_kv_heads % model_parallel_size == 0
239
+
240
+ self.attention = Attention(
241
+ model_parallel_size=model_parallel_size,
242
+ dim=args.dim,
243
+ head_dim=head_dim,
244
+ n_heads=args.n_heads,
245
+ n_kv_heads=n_kv_heads,
246
+ rope_theta=args.rope_theta,
247
+ qk_normalization=args.qk_normalization,
248
+ )
249
+ self.feed_forward = FeedForward(
250
+ model_parallel_size=model_parallel_size,
251
+ dim=args.dim,
252
+ hidden_dim=4 * args.dim,
253
+ multiple_of=args.multiple_of,
254
+ ffn_dim_multiplier=args.ffn_dim_multiplier,
255
+ )
256
+ self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
257
+ self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
258
+ self.swin_norm = args.swin_norm
259
+
260
+ def forward(
261
+ self,
262
+ x: torch.Tensor,
263
+ cache: LayerCache,
264
+ attn_bias: AttnBias,
265
+ group: dist.ProcessGroup | None = None,
266
+ ) -> torch.Tensor:
267
+ if self.swin_norm:
268
+ h = x + self.attention_norm(
269
+ self.attention.forward(
270
+ x,
271
+ cache,
272
+ attn_bias,
273
+ group=group,
274
+ )
275
+ )
276
+ out = h + self.ffn_norm(self.feed_forward(h, group=group))
277
+ else:
278
+ h = x + self.attention.forward(
279
+ self.attention_norm(x),
280
+ cache,
281
+ attn_bias,
282
+ group=group,
283
+ )
284
+ out = h + self.feed_forward(self.ffn_norm(h), group=group)
285
+ return out
286
+
287
+
288
+ class Transformer(nn.Module):
289
+ def __init__(self, args: ModelArgs):
290
+ super().__init__()
291
+ self.args = args
292
+
293
+ self.model_parallel_size = args.model_parallel_size
294
+ assert args.dim % self.model_parallel_size == 0
295
+ assert args.vocab_size > 0
296
+ assert args.vocab_size % self.model_parallel_size == 0
297
+
298
+ self.tok_embeddings = nn.Embedding(
299
+ num_embeddings=args.vocab_size,
300
+ embedding_dim=args.dim // self.model_parallel_size,
301
+ )
302
+
303
+ self.layers = nn.ModuleList()
304
+ for _ in range(args.n_layers):
305
+ self.layers.append(TransformerBlock(args))
306
+
307
+ self.norm = RMSNorm(args.dim, eps=args.norm_eps)
308
+
309
+ self.output = nn.Linear(
310
+ args.dim,
311
+ args.vocab_size // self.model_parallel_size,
312
+ bias=False,
313
+ )
314
+
315
+ @torch.no_grad()
316
+ def forward_with_attn_bias(
317
+ self,
318
+ token_values: torch.Tensor,
319
+ attn_bias: AttnBias,
320
+ cache: list[LayerCache],
321
+ group: dist.ProcessGroup | None = None,
322
+ ) -> torch.Tensor:
323
+ h = self.tok_embeddings(token_values)
324
+ if self.model_parallel_size > 1:
325
+ gather = [torch.empty_like(h) for _ in range(self.model_parallel_size)]
326
+ dist.all_gather(gather, h, group=group)
327
+ h = torch.cat(gather, dim=-1)
328
+
329
+ for i, layer in enumerate(self.layers):
330
+ h = layer(h, cache[i], attn_bias, group=group)
331
+
332
+ logits = self.output(self.norm(h))
333
+ if self.model_parallel_size > 1:
334
+ gather = [torch.empty_like(logits) for _ in range(self.model_parallel_size)]
335
+ dist.all_gather(gather, logits, group=group)
336
+ logits = torch.cat(gather, dim=-1)
337
+ return logits.float()
338
+
339
+ def forward(
340
+ self,
341
+ token_values: torch.Tensor,
342
+ token_lengths: torch.Tensor,
343
+ start_pos: torch.Tensor,
344
+ cache: list[LayerCache],
345
+ kv_padding: int,
346
+ group: dist.ProcessGroup | None = None,
347
+ ) -> torch.Tensor:
348
+ attn_bias = AttnBias.from_seqlens(
349
+ q_seqlen=token_lengths.tolist(),
350
+ kv_seqlen=(start_pos + token_lengths).tolist(),
351
+ kv_padding=kv_padding,
352
+ )
353
+ return self.forward_with_attn_bias(token_values, attn_bias, cache, group=group)
354
+
355
+
356
+ def make_cache(
357
+ args: ModelArgs,
358
+ length: int,
359
+ device: str | torch.device | None = None,
360
+ n_layers: int | None = None,
361
+ dtype: torch.dtype | None = None,
362
+ ) -> list[LayerCache]:
363
+ """
364
+ Allocate a cache to be used with the Transformer module.
365
+
366
+ Args:
367
+ args (ModelArgs): the model configuration.
368
+ length (int): per layer cache size.
369
+ It is usually budgeted as ``max_batch * max_seq``
370
+ device (torch.device, optional): the device on which
371
+ the cache should be allocated.
372
+ n_layers (int, optional): the number of layers to
373
+ allocate a cache for (defaults to the model
374
+ settings).
375
+ dtype (torch.dtype, optional): the dtype to use for
376
+ cache entries (defaults to the default dtype).
377
+
378
+ Returns:
379
+ The cache object to pass to ``Transformer.forward``.
380
+ """
381
+
382
+ head_dim = args.dim // args.n_heads
383
+ n_kv_heads = args.n_kv_heads
384
+ if n_kv_heads is None:
385
+ n_kv_heads = args.n_heads
386
+ n_local_kv_heads = n_kv_heads // args.model_parallel_size
387
+
388
+ if n_layers is None:
389
+ n_layers = args.n_layers
390
+
391
+ shape = (1, length, n_local_kv_heads, head_dim)
392
+ return [
393
+ (
394
+ torch.zeros(shape, device=device, dtype=dtype),
395
+ torch.zeros(shape, device=device, dtype=dtype),
396
+ )
397
+ for _ in range(n_layers)
398
+ ]
399
+
400
+
401
+ def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
402
+ """
403
+ Take a prefix view of a larger cache.
404
+
405
+ The original cache object keeps its full size and remains valid
406
+ after the shrunken alias has been used. This function is useful
407
+ when a cache was allocated for a larger batch size than what is
408
+ necessary.
409
+
410
+ Args:
411
+ cache: the cache to take a view in.
412
+ length (int): the desired length
413
+
414
+ Returns:
415
+ A view in the input cache object.
416
+ """
417
+
418
+ if len(cache) > 0:
419
+ assert cache[0][0].shape[1] >= length
420
+
421
+ return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
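A short sketch of the cache helpers above, not part of the commit; the toy ModelArgs values are the dataclass defaults and the length budget of max_batch * max_seq is taken from the docstring:

args = ModelArgs(vocab_size=1024)             # defaults otherwise: dim=512, n_layers=8, n_heads=8
cache = make_cache(args, length=4 * 256)      # budget for max_batch=4, max_seq=256
print(len(cache), cache[0][0].shape)          # 8 layers, each key/value buffer is (1, 1024, 8, 64)
small = cache_prefix(cache, length=2 * 256)   # view for a smaller batch; the original stays valid
print(small[0][0].shape)                      # torch.Size([1, 512, 8, 64])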
chameleon/inference/utils.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import socket
7
+ from typing import Generator, Generic, Iterator, TypeVar
8
+
9
+ T = TypeVar("T")
10
+
11
+
12
+ class DynamicGenerator(Generic[T]):
13
+ def __init__(self, gen: Generator[T, None, None]):
14
+ self.gen = gen
15
+
16
+ def __iter__(self) -> Iterator[T]:
17
+ return self
18
+
19
+ def __next__(self) -> T:
20
+ return next(self.gen)
21
+
22
+
23
+ def advance(iterator: Iterator[T], steps: int):
24
+ try:
25
+ for _ in range(steps):
26
+ next(iterator)
27
+ except StopIteration:
28
+ pass
29
+
30
+
31
+ def random_unused_port():
32
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
33
+ s.bind(("", 0))
34
+ return s.getsockname()[1]
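A brief usage sketch for the helpers above (not part of the commit):

it = iter(range(5))
advance(it, 3)
print(next(it))              # 3
print(random_unused_port())  # a free TCP port picked by the OS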
chameleon/inference/vocab.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from functools import cached_property
7
+
8
+ import torch
9
+
10
+
11
+ class VocabInfo:
12
+ def __init__(self, vocab_map: dict[str, int]):
13
+ self.name2val = vocab_map
14
+
15
+ self.bos_id = vocab_map.get("<s>")
16
+ self.eos_id = vocab_map.get("</s>")
17
+ self.boi_id = vocab_map.get("<racm3:break>")
18
+ self.eoi_id = vocab_map.get("<eoss>")
19
+ self.pad_id = vocab_map.get("<pad>")
20
+ self.eot_id = vocab_map.get("<reserved08706>")
21
+
22
+ @property
23
+ def begin_sequence(self) -> int:
24
+ return self.bos_id
25
+
26
+ @property
27
+ def end_sequence(self) -> int:
28
+ return self.eos_id
29
+
30
+ @property
31
+ def begin_image(self) -> int:
32
+ return self.boi_id
33
+
34
+ @property
35
+ def end_image(self) -> int:
36
+ return self.eoi_id
37
+
38
+ @property
39
+ def padding(self) -> int:
40
+ return self.pad_id
41
+
42
+ @property
43
+ def end_turn(self) -> int:
44
+ return self.eot_id
45
+
46
+ @cached_property
47
+ def val2name(self) -> dict[int, str]:
48
+ return {v: k for k, v in self.name2val.items()}
49
+
50
+ @cached_property
51
+ def all_tokens(self) -> list[int]:
52
+ return sorted(self.name2val.values())
53
+
54
+ @cached_property
55
+ def image_tokens(self) -> list[int]:
56
+ return sorted(
57
+ [val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
58
+ )
59
+
60
+ @cached_property
61
+ def special_tokens(self) -> list[int]:
62
+ return sorted(
63
+ [
64
+ val
65
+ for name, val in self.name2val.items()
66
+ if name.startswith("<") and name != "<"
67
+ ]
68
+ )
69
+
70
+ @cached_property
71
+ def text_tokens(self) -> list[int]:
72
+ return sorted(
73
+ set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
74
+ )
75
+
76
+
77
+ class VocabTranslation:
78
+ def __init__(self, vocab_info: VocabInfo, device: str | None = None):
79
+ self._vocab = vocab_info
80
+ self._device = device
81
+
82
+ @cached_property
83
+ def bpe2img(self) -> dict[int, int]: # vocab id => codebook id, i.e. [4:8195] => [0:8191]
84
+ img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} # A-J: 0-9
85
+
86
+ def remap(old_name: str) -> str:
87
+ return "".join(
88
+ img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1] # last chr is 'Z'
89
+ )
90
+ # e.g.: IMGIMGFDZ => FD => 53,
91
+
92
+ return {
93
+ tok: int(remap(self._vocab.val2name[tok]))
94
+ for tok in self._vocab.image_tokens # the token starts with 'IMGIMG', value: [4: 8195]
95
+ }
96
+
97
+ @cached_property
98
+ def img2bpe(self) -> dict[int, int]:
99
+ return {v: k for k, v in self.bpe2img.items()} # codebook id => vocab id, i.e. [0:8191] => [4:8195]
100
+
101
+ @cached_property
102
+ def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
103
+ sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
104
+ sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
105
+ return sorted_bpe, sorted_img
106
+
107
+ @cached_property
108
+ def img2bpe_mapping_tensor(self) -> torch.LongTensor:
109
+ mapping = torch.zeros(
110
+ max(self.img2bpe.keys()) + 1,
111
+ dtype=torch.int,
112
+ device=self._device,
113
+ )
114
+ for k, v in self.img2bpe.items():
115
+ mapping[k] = v
116
+ return mapping
117
+
118
+ def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
119
+ bpe_tok, img_tok = self.bpe2img_search_tensors
120
+ return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]
121
+
122
+ def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
123
+ return self.img2bpe_mapping_tensor[img_batch]
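A hedged sketch (not part of the commit) of the IMGIMG naming convention that bpe2img relies on: the letters A..J encode the digits 0..9, the trailing 'Z' is dropped, and the remaining digits form the codebook id.

mapping = {chr(ord("A") + i): str(i) for i in range(10)}  # A-J -> 0-9
name = "IMGIMGFDZ"
print(int("".join(mapping.get(c, c) for c in name[len("IMGIMG"):-1])))  # 53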
chameleon/inference/vqgan.py ADDED
@@ -0,0 +1,675 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
8
+ [with minimal dependencies]
9
+
10
+ This implementation is inference-only -- training steps and optimizer components
11
+ introduce significant additional dependencies
12
+ """
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class VectorQuantizer2(nn.Module):
21
+ """
22
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
23
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
24
+ """
25
+
26
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
27
+ # backwards compatibility we use the buggy version by default, but you can
28
+ # specify legacy=False to fix it.
29
+ def __init__(
30
+ self,
31
+ n_e,
32
+ e_dim,
33
+ beta,
34
+ remap=None,
35
+ unknown_index="random",
36
+ sane_index_shape=False,
37
+ legacy=True,
38
+ ):
39
+ super().__init__()
40
+ self.n_e = n_e
41
+ self.e_dim = e_dim
42
+ self.beta = beta
43
+ self.legacy = legacy
44
+
45
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
46
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
47
+
48
+ self.remap = remap
49
+ if self.remap is not None:
50
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
51
+ self.re_embed = self.used.shape[0]
52
+ self.unknown_index = unknown_index # "random" or "extra" or integer
53
+ if self.unknown_index == "extra":
54
+ self.unknown_index = self.re_embed
55
+ self.re_embed = self.re_embed + 1
56
+ print(
57
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
58
+ f"Using {self.unknown_index} for unknown indices."
59
+ )
60
+ else:
61
+ self.re_embed = n_e
62
+
63
+ self.sane_index_shape = sane_index_shape
64
+
65
+ def remap_to_used(self, inds):
66
+ ishape = inds.shape
67
+ assert len(ishape) > 1
68
+ inds = inds.reshape(ishape[0], -1)
69
+ used = self.used.to(inds)
70
+ match = (inds[:, :, None] == used[None, None, ...]).long()
71
+ new = match.argmax(-1)
72
+ unknown = match.sum(2) < 1
73
+ if self.unknown_index == "random":
74
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
75
+ device=new.device
76
+ )
77
+ else:
78
+ new[unknown] = self.unknown_index
79
+ return new.reshape(ishape)
80
+
81
+ def unmap_to_all(self, inds):
82
+ ishape = inds.shape
83
+ assert len(ishape) > 1
84
+ inds = inds.reshape(ishape[0], -1)
85
+ used = self.used.to(inds)
86
+ if self.re_embed > self.used.shape[0]: # extra token
87
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
88
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
89
+ return back.reshape(ishape)
90
+
91
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
92
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
93
+ assert rescale_logits is False, "Only for interface compatible with Gumbel"
94
+ assert return_logits is False, "Only for interface compatible with Gumbel"
95
+ # reshape z -> (batch, height, width, channel) and flatten
96
+ z = z.permute(0, 2, 3, 1).contiguous()
97
+ z_flattened = z.view(-1, self.e_dim)
98
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
99
+
100
+ d = (
101
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
102
+ + torch.sum(self.embedding.weight**2, dim=1)
103
+ - 2
104
+ * torch.einsum(
105
+ "bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
106
+ )
107
+ )
108
+
109
+ min_encoding_indices = torch.argmin(d, dim=1)
110
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
111
+ perplexity = None
112
+ min_encodings = None
113
+
114
+ # compute loss for embedding
115
+ if not self.legacy:
116
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
117
+ (z_q - z.detach()) ** 2
118
+ )
119
+ else:
120
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
121
+ (z_q - z.detach()) ** 2
122
+ )
123
+
124
+ # preserve gradients
125
+ z_q = z + (z_q - z).detach()
126
+
127
+ # reshape back to match original input shape
128
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
129
+
130
+ if self.remap is not None:
131
+ min_encoding_indices = min_encoding_indices.reshape(
132
+ z.shape[0], -1
133
+ ) # add batch axis
134
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
135
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
136
+
137
+ if self.sane_index_shape:
138
+ min_encoding_indices = min_encoding_indices.reshape(
139
+ z_q.shape[0], z_q.shape[2], z_q.shape[3]
140
+ )
141
+
142
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
143
+
144
+ def get_codebook_entry(self, indices, shape):
145
+ # shape specifying (batch, height, width, channel)
146
+ if self.remap is not None:
147
+ indices = indices.reshape(shape[0], -1) # add batch axis
148
+ indices = self.unmap_to_all(indices)
149
+ indices = indices.reshape(-1) # flatten again
150
+
151
+ # get quantized latent vectors
152
+ z_q = self.embedding(indices)
153
+
154
+ if shape is not None:
155
+ z_q = z_q.view(shape)
156
+ # reshape back to match original input shape
157
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
158
+
159
+ return z_q
160
+
161
+
162
+ # Alias
163
+ VectorQuantizer = VectorQuantizer2
164
+
165
+
166
+ def nonlinearity(x):
167
+ # swish
168
+ return x * torch.sigmoid(x)
169
+
170
+
171
+ def Normalize(in_channels, num_groups=32):
172
+ return torch.nn.GroupNorm(
173
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
174
+ )
175
+
176
+
177
+ class Upsample(nn.Module):
178
+ def __init__(self, in_channels, with_conv):
179
+ super().__init__()
180
+ self.with_conv = with_conv
181
+ if self.with_conv:
182
+ self.conv = torch.nn.Conv2d(
183
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
184
+ )
185
+
186
+ def forward(self, x):
187
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
188
+ if self.with_conv:
189
+ x = self.conv(x)
190
+ return x
191
+
192
+
193
+ class Downsample(nn.Module):
194
+ def __init__(self, in_channels, with_conv):
195
+ super().__init__()
196
+ self.with_conv = with_conv
197
+ if self.with_conv:
198
+ # no asymmetric padding in torch conv, must do it ourselves
199
+ self.conv = torch.nn.Conv2d(
200
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
201
+ )
202
+
203
+ def forward(self, x):
204
+ if self.with_conv:
205
+ pad = (0, 1, 0, 1)
206
+ x = F.pad(x, pad, mode="constant", value=0)
207
+ x = self.conv(x)
208
+ else:
209
+ x = F.avg_pool2d(x, kernel_size=2, stride=2)
210
+ return x
211
+
212
+
213
+ class ResnetBlock(nn.Module):
214
+ def __init__(
215
+ self,
216
+ *,
217
+ in_channels,
218
+ out_channels=None,
219
+ conv_shortcut=False,
220
+ dropout,
221
+ temb_channels=512,
222
+ ):
223
+ super().__init__()
224
+ self.in_channels = in_channels
225
+ out_channels = in_channels if out_channels is None else out_channels
226
+ self.out_channels = out_channels
227
+ self.use_conv_shortcut = conv_shortcut
228
+
229
+ self.norm1 = Normalize(in_channels)
230
+ self.conv1 = torch.nn.Conv2d(
231
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
232
+ )
233
+ if temb_channels > 0:
234
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
235
+ self.norm2 = Normalize(out_channels)
236
+ self.dropout = torch.nn.Dropout(dropout)
237
+ self.conv2 = torch.nn.Conv2d(
238
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
239
+ )
240
+ if self.in_channels != self.out_channels:
241
+ if self.use_conv_shortcut:
242
+ self.conv_shortcut = torch.nn.Conv2d(
243
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
244
+ )
245
+ else:
246
+ self.nin_shortcut = torch.nn.Conv2d(
247
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
248
+ )
249
+
250
+ def forward(self, x, temb):
251
+ h = x
252
+ h = self.norm1(h)
253
+ h = nonlinearity(h)
254
+ h = self.conv1(h)
255
+
256
+ if temb is not None:
257
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
258
+
259
+ h = self.norm2(h)
260
+ h = nonlinearity(h)
261
+ h = self.dropout(h)
262
+ h = self.conv2(h)
263
+
264
+ if self.in_channels != self.out_channels:
265
+ if self.use_conv_shortcut:
266
+ x = self.conv_shortcut(x)
267
+ else:
268
+ x = self.nin_shortcut(x)
269
+
270
+ return x + h
271
+
272
+
273
+ class AttnBlock(nn.Module):
274
+ def __init__(self, in_channels):
275
+ super().__init__()
276
+ self.in_channels = in_channels
277
+
278
+ self.norm = Normalize(in_channels)
279
+ self.q = torch.nn.Conv2d(
280
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
281
+ )
282
+ self.k = torch.nn.Conv2d(
283
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
284
+ )
285
+ self.v = torch.nn.Conv2d(
286
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
287
+ )
288
+ self.proj_out = torch.nn.Conv2d(
289
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
290
+ )
291
+
292
+ def forward(self, x):
293
+ h_ = x
294
+ h_ = self.norm(h_)
295
+ q = self.q(h_)
296
+ k = self.k(h_)
297
+ v = self.v(h_)
298
+
299
+ # compute attention
300
+ b, c, h, w = q.shape
301
+ q = q.reshape(b, c, h * w)
302
+ q = q.permute(0, 2, 1) # b,hw,c
303
+ k = k.reshape(b, c, h * w) # b,c,hw
304
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
305
+ w_ = w_ * (int(c) ** (-0.5))
306
+ w_ = F.softmax(w_, dim=2)
307
+
308
+ # attend to values
309
+ v = v.reshape(b, c, h * w)
310
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
311
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
312
+ h_ = h_.reshape(b, c, h, w)
313
+
314
+ h_ = self.proj_out(h_)
315
+
316
+ return x + h_
317
+
318
+
319
+ def make_attn(in_channels, attn_type="vanilla"):
320
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
321
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
322
+ if attn_type == "vanilla":
323
+ return AttnBlock(in_channels)
324
+ elif attn_type == "none":
325
+ return nn.Identity(in_channels)
326
+ else:
327
+ raise ValueError("Unexpected attention type")
328
+
329
+
330
+ class Encoder(nn.Module):
331
+ def __init__(
332
+ self,
333
+ *,
334
+ ch,
335
+ out_ch,
336
+ ch_mult=(1, 2, 4, 8),
337
+ num_res_blocks,
338
+ attn_resolutions,
339
+ dropout=0.0,
340
+ resamp_with_conv=True,
341
+ in_channels,
342
+ resolution,
343
+ z_channels,
344
+ double_z=True,
345
+ use_linear_attn=False,
346
+ attn_type="vanilla",
347
+ **ignore_kwargs,
348
+ ):
349
+ super().__init__()
350
+ if use_linear_attn:
351
+ attn_type = "linear"
352
+ self.ch = ch
353
+ self.temb_ch = 0
354
+ self.num_resolutions = len(ch_mult)
355
+ self.num_res_blocks = num_res_blocks
356
+ self.resolution = resolution
357
+ self.in_channels = in_channels
358
+
359
+ # downsampling
360
+ self.conv_in = torch.nn.Conv2d(
361
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
362
+ )
363
+
364
+ curr_res = resolution
365
+ in_ch_mult = (1,) + tuple(ch_mult)
366
+ self.in_ch_mult = in_ch_mult
367
+ self.down = nn.ModuleList()
368
+ for i_level in range(self.num_resolutions):
369
+ block = nn.ModuleList()
370
+ attn = nn.ModuleList()
371
+ block_in = ch * in_ch_mult[i_level]
372
+ block_out = ch * ch_mult[i_level]
373
+ for i_block in range(self.num_res_blocks):
374
+ block.append(
375
+ ResnetBlock(
376
+ in_channels=block_in,
377
+ out_channels=block_out,
378
+ temb_channels=self.temb_ch,
379
+ dropout=dropout,
380
+ )
381
+ )
382
+ block_in = block_out
383
+ if curr_res in attn_resolutions:
384
+ attn.append(make_attn(block_in, attn_type=attn_type))
385
+ down = nn.Module()
386
+ down.block = block
387
+ down.attn = attn
388
+ if i_level != self.num_resolutions - 1:
389
+ down.downsample = Downsample(block_in, resamp_with_conv)
390
+ curr_res = curr_res // 2
391
+ self.down.append(down)
392
+
393
+ # middle
394
+ self.mid = nn.Module()
395
+ self.mid.block_1 = ResnetBlock(
396
+ in_channels=block_in,
397
+ out_channels=block_in,
398
+ temb_channels=self.temb_ch,
399
+ dropout=dropout,
400
+ )
401
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
402
+ self.mid.block_2 = ResnetBlock(
403
+ in_channels=block_in,
404
+ out_channels=block_in,
405
+ temb_channels=self.temb_ch,
406
+ dropout=dropout,
407
+ )
408
+
409
+ # end
410
+ self.norm_out = Normalize(block_in)
411
+ self.conv_out = torch.nn.Conv2d(
412
+ block_in,
413
+ 2 * z_channels if double_z else z_channels,
414
+ kernel_size=3,
415
+ stride=1,
416
+ padding=1,
417
+ )
418
+
419
+ def forward(self, x):
420
+ # timestep embedding
421
+ temb = None
422
+
423
+ # downsampling
424
+ hs = [self.conv_in(x)]
425
+ for i_level in range(self.num_resolutions):
426
+ for i_block in range(self.num_res_blocks):
427
+ h = self.down[i_level].block[i_block](hs[-1], temb)
428
+ if len(self.down[i_level].attn) > 0:
429
+ h = self.down[i_level].attn[i_block](h)
430
+ hs.append(h)
431
+ if i_level != self.num_resolutions - 1:
432
+ hs.append(self.down[i_level].downsample(hs[-1]))
433
+
434
+ # middle
435
+ h = hs[-1]
436
+ h = self.mid.block_1(h, temb)
437
+ h = self.mid.attn_1(h)
438
+ h = self.mid.block_2(h, temb)
439
+
440
+ # end
441
+ h = self.norm_out(h)
442
+ h = nonlinearity(h)
443
+ h = self.conv_out(h)
444
+ return h
445
+
446
+
447
+ class Decoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ *,
451
+ ch,
452
+ out_ch,
453
+ ch_mult=(1, 2, 4, 8),
454
+ num_res_blocks,
455
+ attn_resolutions,
456
+ dropout=0.0,
457
+ resamp_with_conv=True,
458
+ in_channels,
459
+ resolution,
460
+ z_channels,
461
+ give_pre_end=False,
462
+ tanh_out=False,
463
+ use_linear_attn=False,
464
+ attn_type="vanilla",
465
+ **ignorekwargs,
466
+ ):
467
+ super().__init__()
468
+ if use_linear_attn:
469
+ attn_type = "linear"
470
+ self.ch = ch
471
+ self.temb_ch = 0
472
+ self.num_resolutions = len(ch_mult)
473
+ self.num_res_blocks = num_res_blocks
474
+ self.resolution = resolution
475
+ self.in_channels = in_channels
476
+ self.give_pre_end = give_pre_end
477
+ self.tanh_out = tanh_out
478
+
479
+ # compute in_ch_mult, block_in and curr_res at lowest res
480
+ block_in = ch * ch_mult[self.num_resolutions - 1]
481
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
482
+ self.z_shape = (1, z_channels, curr_res, curr_res)
483
+
484
+ # z to block_in
485
+ self.conv_in = torch.nn.Conv2d(
486
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
487
+ )
488
+
489
+ # middle
490
+ self.mid = nn.Module()
491
+ self.mid.block_1 = ResnetBlock(
492
+ in_channels=block_in,
493
+ out_channels=block_in,
494
+ temb_channels=self.temb_ch,
495
+ dropout=dropout,
496
+ )
497
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
498
+ self.mid.block_2 = ResnetBlock(
499
+ in_channels=block_in,
500
+ out_channels=block_in,
501
+ temb_channels=self.temb_ch,
502
+ dropout=dropout,
503
+ )
504
+
505
+ # upsampling
506
+ self.up = nn.ModuleList()
507
+ for i_level in reversed(range(self.num_resolutions)):
508
+ block = nn.ModuleList()
509
+ attn = nn.ModuleList()
510
+ block_out = ch * ch_mult[i_level]
511
+ for i_block in range(self.num_res_blocks + 1):
512
+ block.append(
513
+ ResnetBlock(
514
+ in_channels=block_in,
515
+ out_channels=block_out,
516
+ temb_channels=self.temb_ch,
517
+ dropout=dropout,
518
+ )
519
+ )
520
+ block_in = block_out
521
+ if curr_res in attn_resolutions:
522
+ attn.append(make_attn(block_in, attn_type=attn_type))
523
+ up = nn.Module()
524
+ up.block = block
525
+ up.attn = attn
526
+ if i_level != 0:
527
+ up.upsample = Upsample(block_in, resamp_with_conv)
528
+ curr_res = curr_res * 2
529
+ self.up.insert(0, up) # prepend to get consistent order
530
+
531
+ # end
532
+ self.norm_out = Normalize(block_in)
533
+ self.conv_out = torch.nn.Conv2d(
534
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
535
+ )
536
+
537
+ def forward(self, z):
538
+ # assert z.shape[1:] == self.z_shape[1:]
539
+ self.last_z_shape = z.shape
540
+
541
+ # timestep embedding
542
+ temb = None
543
+
544
+ # z to block_in
545
+ h = self.conv_in(z)
546
+
547
+ # middle
548
+ h = self.mid.block_1(h, temb)
549
+ h = self.mid.attn_1(h)
550
+ h = self.mid.block_2(h, temb)
551
+
552
+ # upsampling
553
+ for i_level in reversed(range(self.num_resolutions)):
554
+ for i_block in range(self.num_res_blocks + 1):
555
+ h = self.up[i_level].block[i_block](h, temb)
556
+ if len(self.up[i_level].attn) > 0:
557
+ h = self.up[i_level].attn[i_block](h)
558
+ if i_level != 0:
559
+ h = self.up[i_level].upsample(h)
560
+
561
+ # end
562
+ if self.give_pre_end:
563
+ return h
564
+
565
+ h = self.norm_out(h)
566
+ h = nonlinearity(h)
567
+ h = self.conv_out(h)
568
+ if self.tanh_out:
569
+ h = torch.tanh(h)
570
+ return h
571
+
572
+
573
+ class VQModel(nn.Module):
574
+ def __init__(
575
+ self,
576
+ ddconfig,
577
+ n_embed,
578
+ embed_dim,
579
+ ckpt_path=None,
580
+ ignore_keys=[],
581
+ image_key="image",
582
+ colorize_nlabels=None,
583
+ monitor=None,
584
+ scheduler_config=None,
585
+ lr_g_factor=1.0,
586
+ remap=None,
587
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
588
+ ):
589
+ super().__init__()
590
+ self.image_key = image_key
591
+ self.encoder = Encoder(**ddconfig)
592
+ self.decoder = Decoder(**ddconfig)
593
+ self.quantize = VectorQuantizer(
594
+ n_embed,
595
+ embed_dim,
596
+ beta=0.25,
597
+ remap=remap,
598
+ sane_index_shape=sane_index_shape,
599
+ )
600
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
601
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
602
+ if ckpt_path is not None:
603
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
604
+ self.image_key = image_key
605
+ if colorize_nlabels is not None:
606
+ assert isinstance(colorize_nlabels, int)
607
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
608
+ if monitor is not None:
609
+ self.monitor = monitor
610
+ self.scheduler_config = scheduler_config
611
+ self.lr_g_factor = lr_g_factor
612
+
613
+ def init_from_ckpt(self, path, ignore_keys=list()):
614
+ sd = torch.load(path, map_location="cpu")["state_dict"]
615
+ keys = list(sd.keys())
616
+ for k in keys:
617
+ for ik in ignore_keys:
618
+ if k.startswith(ik):
619
+ print("Deleting key {} from state_dict.".format(k))
620
+ del sd[k]
621
+ self.load_state_dict(sd, strict=False)
622
+ print(f"VQModel loaded from {path}")
623
+
624
+ def encode(self, x):
625
+ h = self.encoder(x)
626
+ h = self.quant_conv(h)
627
+ quant, emb_loss, info = self.quantize(h)
628
+ return quant, emb_loss, info
629
+
630
+ def decode(self, quant):
631
+ quant = self.post_quant_conv(quant)
632
+ dec = self.decoder(quant)
633
+ return dec
634
+
635
+ def decode_code(self, code_b):
636
+ quant_b = self.quantize.embed_code(code_b)
637
+ dec = self.decode(quant_b)
638
+ return dec
639
+
640
+ def forward(self, input):
641
+ quant, diff, _ = self.encode(input)
642
+ dec = self.decode(quant)
643
+ return dec, diff
644
+
645
+ def get_input(self, batch, k):
646
+ x = batch[k]
647
+ if len(x.shape) == 3:
648
+ x = x[..., None]
649
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
650
+ return x.float()
651
+
652
+ def get_last_layer(self):
653
+ return self.decoder.conv_out.weight
654
+
655
+ def log_images(self, batch, **kwargs):
656
+ log = dict()
657
+ x = self.get_input(batch, self.image_key)
658
+ x = x.to(self.device)
659
+ xrec, _ = self(x)
660
+ if x.shape[1] > 3:
661
+ # colorize with random projection
662
+ assert xrec.shape[1] > 3
663
+ x = self.to_rgb(x)
664
+ xrec = self.to_rgb(xrec)
665
+ log["inputs"] = x
666
+ log["reconstructions"] = xrec
667
+ return log
668
+
669
+ def to_rgb(self, x):
670
+ assert self.image_key == "segmentation"
671
+ if not hasattr(self, "colorize"):
672
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
673
+ x = F.conv2d(x, weight=self.colorize)
674
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
675
+ return x
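A toy sketch (not part of the commit) of the nearest-codebook lookup performed in VectorQuantizer2.forward: each flattened latent vector is replaced by the codebook entry with the smallest squared Euclidean distance, using the same z^2 + e^2 - 2*e*z expansion as above. The codebook and latents here are made-up values.

import torch

codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # n_e=3 entries, e_dim=2
z = torch.tensor([[0.1, -0.1], [1.9, 2.2]])                    # two flattened latents
d = (z ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z @ codebook.T
indices = d.argmin(dim=1)
print(indices)            # tensor([0, 2])
print(codebook[indices])  # the quantized latents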
chameleon/miniviewer/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
chameleon/miniviewer/__main__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from chameleon.miniviewer.miniviewer import main
7
+
8
+ if __name__ == "__main__":
9
+ main()
chameleon/miniviewer/miniviewer.html ADDED
@@ -0,0 +1,409 @@
1
+ <!-- Copyright (c) Meta Platforms, Inc. and affiliates. -->
2
+
3
+ <!-- This source code is licensed under the Chameleon License found in the -->
4
+ <!-- LICENSE file in the root directory of this source tree. -->
5
+ <h1>
6
+ <div id="connection-status"></div>
7
+ MiniViewer:
8
+ </h1>
9
+ <div class="container">
10
+ <div class="sidebar with-padding">
11
+ <h4>Input Controls</h4>
12
+ <div class="input-controls-container">
13
+ <button class="button" onclick="addInput('text')">Add text input</button>
14
+ <button class="button" onclick="addInput('image')">
15
+ Add image input
16
+ </button>
17
+ <button class="button" onclick="addInput('<END-OF-TURN>')">
18
+ Add end-of-turn token
19
+ </button>
20
+ </div>
21
+ <hr />
22
+ <h4>General Options</h4>
23
+ <div class="option">
24
+ <label for="seed">seed</label>
25
+ <input type="number" id="seed" value="0" />
26
+ </div>
27
+ <div class="option">
28
+ <label for="max-seq-len">max sequence length</label>
29
+ <input type="number" id="max-seq-len" value="4096" />
30
+ </div>
31
+ <div class="option">
32
+ <label for="max-gen-len">max generation length</label>
33
+ <input type="number" id="max-gen-len" value="4096" />
34
+ </div>
35
+ <h4>
36
+ <input type="checkbox" id="enable-text" name="enable-text" checked />
37
+ <label for="enable-text">Text Decoder Options</label>
38
+ </h4>
39
+ <div class="option">
40
+ <label for="text-rep-penalty">repetition penalty</label>
41
+ <input type="number" id="text-rep-penalty" value="1.2" step="0.01" />
42
+ </div>
43
+ <div class="option">
44
+ <label for="text-temp">temperature</label>
45
+ <input type="number" id="text-temp" value="0.7" step="0.01" />
46
+ </div>
47
+ <div class="option">
48
+ <label for="text-top-p">top-p</label>
49
+ <input type="number" id="text-top-p" value="0.9" step="0.01" />
50
+ </div>
51
+ <h4>
52
+ <input type="checkbox" id="enable-image" name="enable-image" checked />
53
+ <label for="enable-image">Image Decoder Options</label>
54
+ </h4>
55
+ <div class="option">
56
+ <label for="img-cfg-gstext">cfg text</label>
57
+ <input type="number" id="img-cfg-gstext" value="3.0" step="0.01" />
58
+ </div>
59
+ <div class="option">
60
+ <label for="img-cfg-gsimage">cfg image</label>
61
+ <input type="number" id="img-cfg-gsimage" value="1.2" step="0.01" />
62
+ </div>
63
+ <div class="option">
64
+ <label for="img-temp">temperature</label>
65
+ <input type="number" id="img-temp" value="0.7" step="0.01" />
66
+ </div>
67
+ <div class="option">
68
+ <label for="img-top-p">top-p</label>
69
+ <input type="number" id="img-top-p" value="0.9" step="0.01" />
70
+ </div>
71
+ </div>
72
+ <div class="content with-padding">
73
+ <div class="input-wrapper">
74
+ Inputs:
75
+ <div id="inputs" class="with-padding"></div>
76
+ </div>
77
+ <h4>
78
+ <button id="generate" class="button" onclick="generate()">
79
+ Generate
80
+ </button>
81
+ <button
82
+ id="cancel"
83
+ class="button"
84
+ onclick="cancel()"
85
+ style="display: none"
86
+ >
87
+ Cancel
88
+ </button>
89
+ </h4>
90
+ Results:
91
+ <pre id="results" class="with-padding"></pre>
92
+ <div id="timing" class="with-padding"></div>
93
+ <div id="queue" class="with-padding"></div>
94
+ </div>
95
+ </div>
96
+
97
+ <style>
98
+ .container {
99
+ display: inline-flex;
100
+ }
101
+
102
+ .sidebar {
103
+ flex: 0 0 200px;
104
+ border-right: 2px solid #ddd;
105
+ }
106
+
107
+ #connection-status {
108
+ width: 20px;
109
+ height: 20px;
110
+ border-radius: 10px;
111
+ background-color: grey;
112
+ display: inline-block;
113
+ }
114
+
115
+ .input-controls-container {
116
+ display: inline-grid;
117
+ }
118
+
119
+ .option {
120
+ display: flex;
121
+ margin-bottom: 5px;
122
+ }
123
+
124
+ .option label {
125
+ white-space: nowrap;
126
+ margin-right: 10px;
127
+ }
128
+
129
+ .option input {
130
+ flex-grow: 1;
131
+ text-align: right;
132
+ }
133
+
134
+ .content {
135
+ width: 100%;
136
+ }
137
+
138
+ .with-padding {
139
+ padding: 10px;
140
+ }
141
+
142
+ .input-wrapper {
143
+ border: dotted;
144
+ }
145
+
146
+ .input-container {
147
+ display: flex;
148
+ align-items: center;
149
+ }
150
+
151
+ .input-controls {
152
+ display: inline-flex;
153
+ padding: 2px;
154
+ }
155
+
156
+ #results {
157
+ background: lightgray;
158
+ }
159
+
160
+ button {
161
+ text-align: left;
162
+ }
163
+
164
+ img {
165
+ width: 200px;
166
+ height: 200px;
167
+ }
168
+ </style>
169
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.6.0/socket.io.min.js"></script>
170
+ <script>
171
+ var active_key;
172
+ var socket;
173
+
174
+ function createButton(text, onClick) {
175
+ var button = document.createElement("button");
176
+ button.textContent = text;
177
+ button.onclick = onClick;
178
+ return button;
179
+ }
180
+
181
+ function removeInput(evt) {
182
+ var inputWrapper = evt.target.parentNode.parentNode;
183
+ inputWrapper.parentNode.removeChild(inputWrapper);
184
+ }
185
+
186
+ function moveInputUp(evt) {
187
+ var inputWrapper = evt.target.parentNode.parentNode;
188
+ var prev = inputWrapper.previousElementSibling;
189
+ if (prev) {
190
+ inputWrapper.parentNode.insertBefore(inputWrapper, prev);
191
+ }
192
+ }
193
+
194
+ function moveInputDown(evt) {
195
+ var inputWrapper = evt.target.parentNode.parentNode;
196
+ var next = inputWrapper.nextElementSibling;
197
+ if (next) {
198
+ inputWrapper.parentNode.insertBefore(next, inputWrapper);
199
+ }
200
+ }
201
+
202
+ function readFileAsync(file) {
203
+ return new Promise((resolve, reject) => {
204
+ let reader = new FileReader();
205
+ reader.onload = () => resolve(reader.result);
206
+ reader.onerror = reject;
207
+ reader.readAsDataURL(file);
208
+ });
209
+ }
210
+
211
+ async function loadImageSource(dataTransfer) {
212
+ if (dataTransfer.files.length > 0) {
213
+ return await readFileAsync(dataTransfer.files[0]);
214
+ }
215
+
216
+ let htmlContent = dataTransfer.getData("text/html");
217
+ if (htmlContent) {
218
+ const div = document.createElement("div");
219
+ div.innerHTML = htmlContent;
220
+ return div.querySelector("img").src;
221
+ }
222
+
223
+ return (
224
+ dataTransfer.getData("text/uri-list") ||
225
+ dataTransfer.getData("text/plain")
226
+ );
227
+ }
228
+
229
+ async function showPreview(evt) {
230
+ var wrapper = evt.target.parentElement;
231
+ wrapper.querySelector("img").src = await loadImageSource(evt.target);
232
+ wrapper.querySelector("img").style.display = "block";
233
+ wrapper.querySelector("p").style.display = "none";
234
+ }
235
+
236
+ async function handleDrop(evt) {
237
+ evt.preventDefault();
238
+ var wrapper = evt.target.parentElement;
239
+ var file = evt.dataTransfer.files[0];
240
+ var fileInput = wrapper.querySelector('input[type="file"]');
241
+ fileInput.files = evt.dataTransfer.files;
242
+ wrapper.querySelector("img").src = await loadImageSource(evt.dataTransfer);
243
+ wrapper.querySelector("img").style.display = "block";
244
+ wrapper.querySelector("p").style.display = "none";
245
+ }
246
+
247
+ function addInput(input_kind) {
248
+ var inputs_div = document.getElementById("inputs");
249
+ var wrapper = document.createElement("div");
250
+ wrapper.kind = input_kind;
251
+ wrapper.className = "input-container";
252
+
253
+ var new_inputs = [];
254
+ if (input_kind === "text") {
255
+ new_inputs.push(document.createElement("textarea"));
256
+ } else if (input_kind === "image") {
257
+ wrapper.setAttribute("draggable", true);
258
+ wrapper.ondragover = (evt) => evt.preventDefault();
259
+ wrapper.ondrop = handleDrop;
260
+
261
+ var hiddenImageFromFile = document.createElement("input");
262
+ hiddenImageFromFile.type = "file";
263
+ hiddenImageFromFile.accept = "image/*";
264
+ hiddenImageFromFile.addEventListener("change", showPreview);
265
+ hiddenImageFromFile.style.display = "none";
266
+ wrapper.onclick = function () {
267
+ hiddenImageFromFile.click();
268
+ };
269
+ new_inputs.push(hiddenImageFromFile);
270
+
271
+ var description = document.createElement("p");
272
+ description.textContent =
273
+ "Drag and drop your image here, or click to select.";
274
+ new_inputs.push(description);
275
+
276
+ var preview = document.createElement("img");
277
+ preview.style.display = "none";
278
+ new_inputs.push(preview);
279
+ } else {
280
+ var span = document.createElement("span");
281
+ span.textContent = input_kind;
282
+ new_inputs.push(span);
283
+ }
284
+
285
+ const input_controls = document.createElement("div");
286
+ input_controls.className = "input-controls";
287
+ input_controls.appendChild(createButton("-", removeInput));
288
+ input_controls.appendChild(createButton("↓", moveInputDown));
289
+ input_controls.appendChild(createButton("↑", moveInputUp));
290
+
291
+ wrapper.appendChild(input_controls);
292
+ for (var new_input of new_inputs) {
293
+ wrapper.appendChild(new_input);
294
+ }
295
+ wrapper.appendChild(document.createElement("br"));
296
+
297
+ inputs_div.appendChild(wrapper);
298
+ }
299
+
300
+ async function generate() {
301
+ document.getElementById("generate").style.display = "none";
302
+ document.getElementById("cancel").style.display = "block";
303
+ document.getElementById("results").innerHTML = "";
304
+ document.getElementById("timing").innerHTML = "";
305
+ document.getElementById("queue").innerHTML = "";
306
+
307
+ active_key = `key_${Math.random()
308
+ .toString(36)
309
+ .substring(2, 11)}_${Date.now()}`;
310
+
311
+ const user_options = {};
312
+ for (const option of document.getElementsByClassName("option")) {
313
+ const input = option.querySelector("input");
314
+ user_options[input.id] = Number(input.value);
315
+ }
316
+
317
+ user_options["enable-text"] =
318
+ document.getElementById("enable-text").checked;
319
+ user_options["enable-image"] =
320
+ document.getElementById("enable-image").checked;
321
+
322
+ const user_inputs = [];
323
+ const inputs_div = document.getElementById("inputs");
324
+
325
+ const input_elems = Array.from(inputs_div.children).map((wrapper) =>
326
+ wrapper.querySelector("textarea, input, span")
327
+ );
328
+
329
+ const image_promises = Array.from(inputs_div.children)
330
+ .filter((wrapper) => wrapper.kind === "image")
331
+ .map((wrapper) => {
332
+ const file_input = wrapper.querySelector('input[type="file"]');
333
+ return file_input.files[0]
334
+ ? readFileAsync(file_input.files[0])
335
+ : Promise.resolve(null);
336
+ });
337
+
338
+ const images = await Promise.all(image_promises);
339
+
340
+ for (const wrapper of inputs_div.children) {
341
+ if (wrapper.kind === "text") {
342
+ user_inputs.push({
343
+ type: "text",
344
+ value: wrapper.querySelector("textarea").value,
345
+ });
346
+ } else if (wrapper.kind === "image") {
347
+ user_inputs.push({ type: "image", value: images.shift() });
348
+ } else {
349
+ user_inputs.push({ type: "sentinel", value: wrapper.kind });
350
+ }
351
+ }
352
+
353
+ socket.emit("generate", active_key, user_options, user_inputs);
354
+ }
355
+
356
+ function cancel() {
357
+ document.getElementById("generate").style.display = "block";
358
+ document.getElementById("cancel").style.display = "none";
359
+ document.getElementById("queue").innerHTML = "";
360
+ socket.emit("cancel", active_key);
361
+ active_key = null;
362
+ }
363
+
364
+ function connectSocket() {
365
+ socket = io();
366
+
367
+ socket.on("connect", function() {
368
+ document.getElementById("connection-status").style.backgroundColor = 'green';
369
+ });
370
+
371
+ socket.on("disconnect", function(reason) {
372
+ cancel();
373
+ document.getElementById("connection-status").style.backgroundColor = 'red';
374
+ });
375
+
376
+ socket.on("progress", function (data) {
377
+ if (data.key != active_key) {
378
+ return;
379
+ }
380
+
381
+ document.getElementById("queue").innerHTML = "";
382
+ if (data.type == "queue") {
383
+ document.getElementById(
384
+ "queue"
385
+ ).innerHTML = `queue position ${data.value}`;
386
+ }
387
+
388
+ if (data.type == "text") {
389
+ document.getElementById("results").innerHTML += data.value;
390
+ } else if (data.type == "image_start") {
391
+ document.getElementById("results").appendChild(new Image());
392
+ } else if (data.type == "image") {
393
+ document.getElementById("results").lastElementChild.src = data.value;
394
+ } else if (data.type == "image_end") {
395
+ } else if (data.type == "done") {
396
+ document.getElementById(
397
+ "timing"
398
+ ).innerHTML = `Generation time: ${data.value.toFixed(2)} sec`;
399
+ document.getElementById("generate").style.display = "block";
400
+ document.getElementById("cancel").style.display = "none";
401
+ active_key = null;
402
+ }
403
+ });
404
+ }
405
+
406
+ window.onload = (evt) => {
407
+ connectSocket();
408
+ };
409
+ </script>
chameleon/miniviewer/miniviewer.py ADDED
@@ -0,0 +1,254 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import base64
+ import os
+ import threading
+ import time
+ from dataclasses import dataclass
+ from enum import Enum
+ from pathlib import Path
+
+ import click
+ import torch
+ from flask import Flask, request
+ from flask_socketio import SocketIO
+
+ from chameleon.inference.chameleon import ChameleonInferenceModel, Options, TokenManager
+
+
+ @dataclass
+ class Request:
+     room: str
+     key: str
+     options: dict[str, int | float | bool]
+     prompt_ui: list[dict]
+
+
+ def convert_options(ui_options: dict) -> Options:
+     txt = None
+     if ui_options["enable-text"]:
+         txt = Options.Text(
+             repetition_penalty=ui_options["text-rep-penalty"],
+             temp=ui_options["text-temp"],
+             top_p=ui_options["text-top-p"],
+         )
+     img = None
+     if ui_options["enable-image"]:
+         img = Options.Image(
+             cfg=Options.Image.CFG(
+                 guidance_scale_image=ui_options["img-cfg-gsimage"],
+                 guidance_scale_text=ui_options["img-cfg-gstext"],
+             ),
+             temp=ui_options["img-temp"],
+             top_p=ui_options["img-top-p"],
+         )
+     return Options(
+         max_seq_len=ui_options["max-seq-len"],
+         max_gen_len=ui_options["max-gen-len"],
+         seed=ui_options["seed"],
+         txt=txt,
+         img=img,
+     )
+
+
+ class UIDecoder:
+     class State(Enum):
+         TXT = 1
+         IMG = 2
+         IMG_END = 3
+
+     def __init__(self, token_manager: TokenManager):
+         self.token_manager = token_manager
+         self.state = UIDecoder.State.TXT
+         self.image_builder = []
+         self.image_yield_every_n = 32
+         self.image_has_updated = False
+
+     def _image_progress(self) -> dict:
+         self.image_has_updated = False
+         png = self.token_manager.png_from_bpe_tokens(torch.cat(self.image_builder))
+         return {
+             "type": "image",
+             "value": "data:image/png;base64," + base64.b64encode(png).decode(),
+         }
+
+     def next(self, gpu_token: torch.LongTensor) -> dict | None:
+         if self.state == UIDecoder.State.TXT:
+             cpu_tok = gpu_token.item()
+
+             if cpu_tok == self.token_manager.vocab.begin_image:
+                 self.state = UIDecoder.State.IMG
+                 return {"type": "image_start"}
+
+             return {
+                 "type": "text",
+                 "value": self.token_manager.tokenizer.decode([cpu_tok]),
+             }
+
+         elif self.state == UIDecoder.State.IMG:
+             self.image_builder.append(gpu_token)
+             self.image_has_updated = True
+             if len(self.image_builder) == 1024:
+                 self.state = UIDecoder.State.IMG_END
+             if len(self.image_builder) % self.image_yield_every_n == 0:
+                 return self._image_progress()
+
+         elif self.state == UIDecoder.State.IMG_END:
+             # assert gpu_token == end_image
+             self.state = UIDecoder.State.TXT
+             progress = self._image_progress() if self.image_has_updated else None
+             self.image_builder = []
+             return progress
+
+
+ @dataclass
+ class State:
+     room_keys: dict[str, set[str]]
+     pending_requests: list[Request]
+     cond: threading.Condition
+
+     def __enter__(self, *args, **kwargs):
+         self.cond.__enter__(*args, **kwargs)
+         return self
+
+     def __exit__(self, *args, **kwargs):
+         self.cond.__exit__(*args, **kwargs)
+         return self
+
+
+ GlobalState = State(room_keys={}, pending_requests=[], cond=threading.Condition())
+
+ app = Flask(__name__)
+ socketio = SocketIO(app, max_http_buffer_size=16 * 1024 * 1024)
+
+
+ @app.route("/")
+ def index():
+     with open(Path(__file__).parent / "miniviewer.html") as f:
+         return f.read()
+
+
+ @socketio.on("disconnect")
+ def handle_disconnect():
+     with GlobalState as state:
+         try:
+             del state.room_keys[request.sid]
+         except KeyError:
+             pass
+
+
+ @socketio.on("cancel")
+ def handle_cancel(key):
+     with GlobalState as state:
+         try:
+             state.room_keys[request.sid].remove(key)
+         except KeyError:
+             pass
+
+
+ @socketio.on("generate")
+ def handle_generate(key, options, prompt_ui):
+     with GlobalState as state:
+         if request.sid not in state.room_keys:
+             state.room_keys[request.sid] = set()
+         state.room_keys[request.sid].add(key)
+         state.pending_requests.append(Request(request.sid, key, options, prompt_ui))
+         state.cond.notify_all()
+
+
+ def generation_thread(model: ChameleonInferenceModel):
+     while True:
+         with GlobalState as state:
+             state.cond.wait_for(lambda: state.pending_requests)
+             req = state.pending_requests.pop(0)
+
+         start = time.time()
+         ui_decoder = UIDecoder(model.token_manager)
+         options = convert_options(req.options)
+
+         if not options.txt:
+             progress = ui_decoder.next(
+                 torch.tensor([model.token_manager.vocab.begin_image])
+             )
+             socketio.emit(
+                 "progress",
+                 {"key": req.key, **progress},
+                 room=req.room,
+             )
+
+         for token in model.stream(
+             prompt_ui=req.prompt_ui,
+             options=options,
+         ):
+             with GlobalState as state:
+                 if req.key not in state.room_keys.get(req.room, {}):
+                     break
+
+             if progress := ui_decoder.next(token.id):
+                 socketio.emit(
+                     "progress",
+                     {"key": req.key, **progress},
+                     room=req.room,
+                 )
+
+         timing = time.time() - start
+         socketio.emit(
+             "progress",
+             {"key": req.key, "type": "done", "value": timing},
+             room=req.room,
+         )
+
+
+ def queue_position_thread():
+     local_pending_requests = []
+     while True:
+         with GlobalState as state:
+             state.cond.wait_for(
+                 lambda: local_pending_requests != state.pending_requests
+             )
+             local_pending_requests = state.pending_requests[:]
+
+         for i, req in enumerate(local_pending_requests):
+             progress = {
+                 "type": "queue",
+                 "key": req.key,
+                 "value": i + 1,
+             }
+             socketio.emit("progress", progress, room=req.room)
+
+
+ @click.command()
+ @click.option("--data-path", type=click.Path(), default="./data")
+ @click.option(
+     "--model-size", type=click.Choice(["7b", "30b"], case_sensitive=False), default="7b"
+ )
+ def main(data_path, model_size):
+     data_path = Path(data_path)
+
+     model_path = str(data_path / "models" / model_size)
+     tokenizer_path = str(data_path / "tokenizer/text_tokenizer.json")
+     vqgan_cfg_path = str(data_path / "tokenizer/vqgan.yaml")
+     vqgan_ckpt_path = str(data_path / "tokenizer/vqgan.ckpt")
+
+     if not os.path.exists(model_path):
+         raise ValueError(
+             "Model not found. Did you run python -m chameleon.download_data {PRESIGNED_URL}"
+         )
+
+     cm3v2_inference_model = ChameleonInferenceModel(
+         model_path, tokenizer_path, vqgan_cfg_path, vqgan_ckpt_path
+     )
+     threading.Thread(
+         target=generation_thread,
+         args=(cm3v2_inference_model,),
+         daemon=True,
+     ).start()
+     threading.Thread(target=queue_position_thread, daemon=True).start()
+     socketio.run(app, debug=False)
+
+
+ if __name__ == "__main__":
+     main()
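The miniviewer's Socket.IO events can also be driven from a script instead of the HTML page above. A hedged sketch follows, using the python-socketio client package; the URL/port (Flask-SocketIO's usual default), the request key, and the option dictionary keys are assumptions copied from miniviewer.html, not a documented API.

```python
import socketio

sio = socketio.Client()

@sio.on("progress")
def on_progress(data):
    # Mirrors the browser handler: types are "queue", "text", "image_start", "image", "done".
    if data.get("type") == "text":
        print(data["value"], end="", flush=True)
    elif data.get("type") == "done":
        print(f"\ndone in {data['value']:.2f}s")
        sio.disconnect()

sio.connect("http://localhost:5000")  # assumed default Flask-SocketIO address
options = {
    "seed": 0, "max-seq-len": 4096, "max-gen-len": 4096,
    "text-rep-penalty": 1.2, "text-temp": 0.7, "text-top-p": 0.9,
    "img-cfg-gstext": 3.0, "img-cfg-gsimage": 1.2, "img-temp": 0.7, "img-top-p": 0.9,
    "enable-text": True, "enable-image": True,
}
inputs = [{"type": "text", "value": "Describe a chameleon."}]
sio.emit("generate", ("example_key", options, inputs))  # matches handle_generate(key, options, prompt_ui)
sio.wait()
```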
chameleon/viewer/backend/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
chameleon/viewer/backend/data_types.py ADDED
@@ -0,0 +1,90 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from enum import Enum
+ from typing import Literal
+
+ from pydantic import BaseModel, Extra, Field
+
+ from chameleon.viewer.backend.models.abstract_model import (
+     DEFAULT_MULTIMODAL_CFG_IMAGE,
+     DEFAULT_MULTIMODAL_CFG_TEXT,
+ )
+
+
+ class WSMessageType(str, Enum):
+     GENERATE_IMAGE = "GENERATE_IMAGE"
+     GENERATE_TEXT = "GENERATE_TEXT"
+     GENERATE_MULTIMODAL = "GENERATE_MULTIMODAL"
+     PARTIAL_OUTPUT = "PARTIAL_OUTPUT"
+     FULL_OUTPUT = "FULL_OUTPUT"
+     COMPLETE = "COMPLETE"
+     ERROR = "ERROR"
+     QUEUE_STATUS = "QUEUE_STATUS"
+
+
+ class ContentType(str, Enum):
+     TEXT = "TEXT"
+     IMAGE = "IMAGE"
+
+
+ class Content(BaseModel):
+     content_type: ContentType
+     content: str
+
+     class Config:
+         extra = Extra.forbid
+
+
+ class NoOptionsForPartial(BaseModel):
+     message_type: Literal[WSMessageType.PARTIAL_OUTPUT] = WSMessageType.PARTIAL_OUTPUT
+
+
+ class NoOptionsForFull(BaseModel):
+     message_type: Literal[WSMessageType.FULL_OUTPUT] = WSMessageType.FULL_OUTPUT
+
+
+ class NoOptionsForComplete(BaseModel):
+     message_type: Literal[WSMessageType.COMPLETE] = WSMessageType.COMPLETE
+
+
+ class NoOptionsForError(BaseModel):
+     message_type: Literal[WSMessageType.ERROR] = WSMessageType.ERROR
+
+
+ class NoOptionsForQueueStatus(BaseModel):
+     message_type: Literal[WSMessageType.QUEUE_STATUS] = WSMessageType.QUEUE_STATUS
+
+
+ class MultimodalGeneratorOptions(BaseModel):
+     message_type: Literal[
+         WSMessageType.GENERATE_MULTIMODAL
+     ] = WSMessageType.GENERATE_MULTIMODAL
+     temp: float = 0.7
+     top_p: float = 0.9
+     cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE
+     cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT
+     yield_every_n: int = 32
+     max_gen_tokens: int = 4096
+     repetition_penalty: float = 1.2
+     suffix_tokens: list[str] | None = None
+     seed: int | None = None
+
+     class Config:
+         extra = Extra.forbid
+
+
+ class WSMultimodalMessage(BaseModel):
+     message_type: WSMessageType
+     content: list[Content]
+     options: (
+         MultimodalGeneratorOptions
+         | NoOptionsForPartial
+         | NoOptionsForFull
+         | NoOptionsForError
+         | NoOptionsForComplete
+         | NoOptionsForQueueStatus
+     ) = Field(..., discriminator="message_type")
+     debug_info: dict[str, str] = {}
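As a quick illustration of how the discriminated `options` union above is meant to be used, here is a hedged sketch in pydantic v1 style (matching the `Extra`/`Field` usage in this file); the prompt text and seed are placeholders.

```python
from chameleon.viewer.backend.data_types import (
    Content,
    ContentType,
    MultimodalGeneratorOptions,
    WSMessageType,
    WSMultimodalMessage,
)

msg = WSMultimodalMessage(
    message_type=WSMessageType.GENERATE_MULTIMODAL,
    content=[Content(content_type=ContentType.TEXT, content="Draw a chameleon.")],
    options=MultimodalGeneratorOptions(temp=0.7, top_p=0.9, seed=42),
)

payload = msg.json()                              # serialize for the websocket
parsed = WSMultimodalMessage.parse_raw(payload)   # discriminator selects the right options class
assert isinstance(parsed.options, MultimodalGeneratorOptions)
```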
chameleon/viewer/backend/model_viewer.py ADDED
@@ -0,0 +1,66 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import hydra
+ import torch
+ from omegaconf import DictConfig
+
+ from chameleon.inference import loader
+ from chameleon.viewer.backend.models.chameleon_distributed import (
+     ChameleonDistributedGenerator,
+ )
+ from chameleon.viewer.backend.models.chameleon_local import ChameleonLocalGenerator
+ from chameleon.viewer.backend.models.service import serve
+ from chameleon.viewer.backend.utils import configure_rich_logging, get_logger
+
+ logger = get_logger(__name__)
+
+ VERSION = "2.0"
+ SEED = 42
+
+
+ def create_chameleon_generator(cfg: DictConfig):
+     world_size = loader.detect_shard_count(cfg.model_path)
+     if world_size > 1:
+         torch.multiprocessing.set_start_method("spawn")
+         generator = ChameleonDistributedGenerator(
+             model_path=cfg.model_path,
+             tokenizer_path=cfg.tokenizer_path,
+             vqgan_config_path=cfg.vqgan_config_path,
+             vqgan_ckpt_path=cfg.vqgan_ckpt_path,
+             additional_eos_tokens=cfg.additional_eos_tokens,
+             world_size=world_size,
+             master_address=cfg.distributed.master_address,
+             master_port=cfg.distributed.master_port,
+             redis_port=cfg.redis_port,
+         )
+     else:
+         generator = ChameleonLocalGenerator(
+             model_path=cfg.model_path,
+             tokenizer_path=cfg.tokenizer_path,
+             vqgan_config_path=cfg.vqgan_config_path,
+             vqgan_ckpt_path=cfg.vqgan_ckpt_path,
+             additional_eos_tokens=cfg.additional_eos_tokens,
+         )
+     return generator
+
+
+ @hydra.main("../../../config", config_name="model_viewer", version_base="1.3.2")
+ def main(cfg: DictConfig) -> None:
+     configure_rich_logging()
+     torch.set_default_tensor_type("torch.cuda.FloatTensor")
+     logger.info("Starting viewer server with hydra cfg: %s", cfg)
+
+     serve(
+         create_chameleon_generator(cfg),
+         cfg.host,
+         cfg.port,
+         debug=cfg.debug,
+         redis_port=cfg.redis_port,
+     )
+
+
+ if __name__ == "__main__":
+     main()
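The entry point above pulls everything it needs from the hydra config. Purely for illustration, a hedged sketch of an equivalent config built programmatically; the key names are taken from the code above, but every value is a placeholder rather than a documented default.

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model_path": "./data/models/7b",  # placeholder paths
        "tokenizer_path": "./data/tokenizer/text_tokenizer.json",
        "vqgan_config_path": "./data/tokenizer/vqgan.yaml",
        "vqgan_ckpt_path": "./data/tokenizer/vqgan.ckpt",
        "additional_eos_tokens": None,
        "distributed": {"master_address": "0.0.0.0", "master_port": 29500},
        "redis_port": 6379,
        "host": "0.0.0.0",
        "port": 5000,
        "debug": False,
    }
)
generator = create_chameleon_generator(cfg)  # picks local vs. distributed by shard count
```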
chameleon/viewer/backend/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
chameleon/viewer/backend/models/abstract_model.py ADDED
@@ -0,0 +1,67 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Chameleon License found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import abc
+ from dataclasses import dataclass
+ from typing import Generator
+
+ import PIL.Image
+
+ # images, joined retrieval queries, retrieval images
+ MixedTokenType = str | PIL.Image.Image
+ MixedSequenceType = list[MixedTokenType]
+
+
+ @dataclass
+ class StreamingImage:
+     image: PIL.Image.Image
+     final: bool
+
+
+ DEFAULT_MULTIMODAL_CFG_IMAGE = 1.2
+ DEFAULT_MULTIMODAL_CFG_TEXT = 3.0
+ DEFAULT_IMAGE_CFG_IMAGE = 3.0
+ DEFAULT_IMAGE_CFG_TEXT = 3.0
+
+
+ class AbstractMultimodalGenerator(abc.ABC):
+     @abc.abstractmethod
+     def generate_text_streaming(
+         self,
+         prompts: list[MixedSequenceType],
+         temp: float = 1.0,
+         top_p: float = 0.8,
+         seed: int | None = None,
+     ) -> Generator[list[str], None, None]:
+         pass
+
+     @abc.abstractmethod
+     def generate_image_streaming(
+         self,
+         prompt: MixedSequenceType,
+         temp: float = 1.0,
+         top_p: float = 0.8,
+         cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
+         cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
+         yield_every_n: int = 32,
+         seed: int | None = None,
+     ) -> Generator[PIL.Image.Image, None, None]:
+         pass
+
+     @abc.abstractmethod
+     def generate_multimodal_streaming(
+         self,
+         prompt: MixedSequenceType,
+         temp: float = 1.0,
+         top_p: float = 0.8,
+         cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
+         cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
+         yield_every_n: int = 32,
+         max_gen_tokens: int = 4096,
+         repetition_penalty: float = 1.2,
+         suffix_tokens: list[str] | None = None,
+         seed: int | None = None,
+     ) -> Generator[MixedSequenceType, None, None]:
+         pass
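To make the interface concrete, a hedged sketch of a caller that consumes `generate_multimodal_streaming` from some concrete generator. The synchronous, local flavour is assumed here (the distributed one exposes async generators); the prompt and output path are placeholders.

```python
from chameleon.viewer.backend.models.abstract_model import StreamingImage

def consume(generator, prompt):
    # `generator` is assumed to be a concrete AbstractMultimodalGenerator whose
    # streaming method is a plain (non-async) generator.
    for chunk in generator.generate_multimodal_streaming(prompt):
        if isinstance(chunk, str):
            print(chunk, end="", flush=True)       # streamed text
        elif isinstance(chunk, StreamingImage) and chunk.final:
            chunk.image.save("generation.png")     # finished image
```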
chameleon/viewer/backend/models/chameleon_distributed.py ADDED
@@ -0,0 +1,827 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import asyncio
7
+ import json
8
+ import multiprocessing
9
+ import os
10
+ import random
11
+ import sys
12
+ import threading
13
+ import time
14
+ import traceback
15
+ from functools import partial
16
+ from typing import Any, Generator, TypeVar
17
+
18
+ import redis
19
+ import redis.asyncio as async_redis
20
+ import torch
21
+ from tokenizers import Tokenizer
22
+
23
+ from chameleon.inference.image_tokenizer import ImageTokenizer
24
+ from chameleon.inference.loader import load_model
25
+ from chameleon.inference.vocab import VocabInfo
26
+ from chameleon.viewer.backend.data_types import WSMessageType
27
+ from chameleon.viewer.backend.models.abstract_model import (
28
+ DEFAULT_IMAGE_CFG_IMAGE,
29
+ DEFAULT_IMAGE_CFG_TEXT,
30
+ DEFAULT_MULTIMODAL_CFG_IMAGE,
31
+ DEFAULT_MULTIMODAL_CFG_TEXT,
32
+ AbstractMultimodalGenerator,
33
+ MixedSequenceType,
34
+ StreamingImage,
35
+ )
36
+ from chameleon.viewer.backend.models.chameleon_local import (
37
+ ChameleonForwardMixin,
38
+ ChameleonTokenizationMixin,
39
+ )
40
+ from chameleon.viewer.backend.utils import get_logger
41
+
42
+ logger = get_logger(__name__)
43
+
44
+ START = "START"
45
+
46
+ T = TypeVar("T")
47
+
48
+
49
+ def find_any(queue_by_id: dict[str, list]) -> str | None:
50
+ for candidate_queue_id, candidate_queue in queue_by_id.items():
51
+ if len(candidate_queue) > 0:
52
+ return candidate_queue_id
53
+ return None
54
+
55
+
56
+ class RedisQueue:
57
+ def __init__(self, redis_client: redis.Redis, name: str, interval: float = 0.1):
58
+ self.redis_client = redis_client
59
+ self.name = name
60
+ self.interval = interval
61
+ self.lock = redis.lock.Lock(redis_client, f"lock_for_{name}")
62
+
63
+ def reset(self):
64
+ self.redis_client.set(self.name, json.dumps({}))
65
+ try:
66
+ self.lock.release()
67
+ except redis.lock.LockError:
68
+ pass
69
+
70
+ def size(self) -> int:
71
+ maybe_queue_by_id = self.redis_client.get(self.name)
72
+ if maybe_queue_by_id is None:
73
+ return 0
74
+ else:
75
+ return len(json.loads(maybe_queue_by_id))
76
+
77
+ def clear(self, queue_id: str):
78
+ with self.lock:
79
+ maybe_queue_by_id = self.redis_client.get(self.name)
80
+ if maybe_queue_by_id is None:
81
+ queue_by_id: dict[str, list] = {}
82
+ else:
83
+ queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
84
+ queue_by_id[queue_id] = []
85
+ self.redis_client.set(self.name, json.dumps(queue_by_id))
86
+
87
+ def put(self, queue_id: str, value: T):
88
+ logger.debug(
89
+ "Thread %s: Starting PUT(%s) for %s",
90
+ threading.get_ident(),
91
+ self.name,
92
+ queue_id,
93
+ )
94
+ with self.lock:
95
+ maybe_queue_by_id = self.redis_client.get(self.name)
96
+ if maybe_queue_by_id is None:
97
+ queue_by_id: dict[str, list[T]] = {}
98
+ else:
99
+ queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
100
+
101
+ if queue_id not in queue_by_id:
102
+ queue_by_id[queue_id] = []
103
+ queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
104
+ self.redis_client.set(self.name, json.dumps(queue_by_id))
105
+
106
+ logger.debug(
107
+ "Thread %s: Finished PUT(%s) for %s",
108
+ threading.get_ident(),
109
+ self.name,
110
+ queue_id,
111
+ )
112
+
113
+ def get(self, queue_id: str | None) -> tuple[str, T]:
114
+ """
115
+ Get the next value in the queue.
116
+
117
+ if queue_id is None, will get a value from any queue
118
+
119
+ if queue_id is not None, will wait to get a value from a specific queue
120
+ """
121
+ logger.debug(
122
+ "Thread %s: Starting GET(%s) for %s",
123
+ threading.get_ident(),
124
+ self.name,
125
+ queue_id,
126
+ )
127
+ while True:
128
+ with self.lock:
129
+ # Initialization hasn't happened, so wait for it to happen
130
+ maybe_queue_by_id = self.redis_client.get(self.name)
131
+ if maybe_queue_by_id is None:
132
+ continue
133
+ queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
134
+ if queue_id is None:
135
+ queue_id = find_any(queue_by_id)
136
+
137
+ # Ensure a queue_id was found or that it already existed
138
+ if queue_id is not None and queue_id in queue_by_id:
139
+ queue = queue_by_id[queue_id]
140
+ if len(queue) == 0:
141
+ continue
142
+ value = queue.pop(-1)
143
+ # queue is mutated and queue_by_id references it, so this works
144
+ self.redis_client.set(self.name, json.dumps(queue_by_id))
145
+ logger.debug(
146
+ "Thread %s: Finished GET(%s) for %s",
147
+ threading.get_ident(),
148
+ self.name,
149
+ queue_id,
150
+ )
151
+ return queue_id, value
152
+ time.sleep(self.interval)
153
+
154
+
155
+ class AsyncRedisQueue:
156
+ def __init__(
157
+ self, redis_client: async_redis.Redis, name: str, interval: float = 0.1
158
+ ) -> None:
159
+ self.redis_client = redis_client
160
+ self.name = name
161
+ self.interval = interval
162
+ self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")
163
+
164
+ async def reset(self):
165
+ await self.redis_client.set(self.name, json.dumps({}))
166
+ try:
167
+ await self.lock.release()
168
+ except async_redis.lock.LockError:
169
+ pass
170
+
171
+ async def size(self) -> int:
172
+ maybe_queue_by_id = await self.redis_client.get(self.name)
173
+ if maybe_queue_by_id is None:
174
+ return 0
175
+ else:
176
+ return len(json.loads(maybe_queue_by_id))
177
+
178
+ async def clear(self, queue_id: str):
179
+ logger.debug(
180
+ "ASYNC Thread %s: Starting CLEAR(%s) for %s",
181
+ threading.get_ident(),
182
+ self.name,
183
+ queue_id,
184
+ )
185
+ async with self.lock:
186
+ maybe_queue_by_id = await self.redis_client.get(self.name)
187
+ if maybe_queue_by_id is None:
188
+ queue_by_id: dict[str, list] = {}
189
+ else:
190
+ queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
191
+ queue_by_id[queue_id] = []
192
+ await self.redis_client.set(self.name, json.dumps(queue_by_id))
193
+
194
+ logger.debug(
195
+ "ASYNC Thread %s: Finished CLEAR(%s) for %s",
196
+ threading.get_ident(),
197
+ self.name,
198
+ queue_id,
199
+ )
200
+
201
+ async def put(self, queue_id: str, value: T):
202
+ logger.debug(
203
+ "ASYNC Thread %s: Starting PUT(%s) for %s",
204
+ threading.get_ident(),
205
+ self.name,
206
+ queue_id,
207
+ )
208
+
209
+ async with self.lock:
210
+ maybe_queue_by_id = await self.redis_client.get(self.name)
211
+ if maybe_queue_by_id is None:
212
+ queue_by_id: dict[str, list[T]] = {}
213
+ else:
214
+ queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
215
+
216
+ if queue_id not in queue_by_id:
217
+ queue_by_id[queue_id] = []
218
+ queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
219
+ await self.redis_client.set(self.name, json.dumps(queue_by_id))
220
+
221
+ logger.debug(
222
+ "ASYNC Thread %s: Finished PUT(%s) for %s",
223
+ threading.get_ident(),
224
+ self.name,
225
+ queue_id,
226
+ )
227
+
228
+ async def get(self, queue_id: str | None):
229
+ """
230
+ Get the next value in the queue.
231
+
232
+ if queue_id is None, will get a value from any queue
233
+
234
+ if queue_id is not None, will wait to get a value from a specific queue
235
+ """
236
+ logger.debug(
237
+ "ASYNC Thread %s: Starting GET(%s) for %s",
238
+ threading.get_ident(),
239
+ self.name,
240
+ queue_id,
241
+ )
242
+ while True:
243
+ async with self.lock:
244
+ maybe_queue_by_id = await self.redis_client.get(self.name)
245
+ if maybe_queue_by_id is None:
246
+ continue
247
+ queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
248
+ if queue_id is None:
249
+ queue_id = find_any(queue_by_id)
250
+
251
+ # Ensure a queue_id was found or that it already existed
252
+ if queue_id is not None and queue_id in queue_by_id:
253
+ queue: list = queue_by_id[queue_id]
254
+ if len(queue) == 0:
255
+ continue
256
+ value = queue.pop(-1)
257
+ # queue is mutated and queue_by_id references it, so this works
258
+ await self.redis_client.set(self.name, json.dumps(queue_by_id))
259
+ logger.debug(
260
+ "ASYNC Thread %s: Finished GET(%s) for %s",
261
+ threading.get_ident(),
262
+ self.name,
263
+ queue_id,
264
+ )
265
+ return queue_id, value
266
+ await asyncio.sleep(self.interval)
267
+
268
+
269
+ class AsyncRedisCounter:
270
+ def __init__(self, redis_client: async_redis.Redis, name: str) -> None:
271
+ self.redis_client = redis_client
272
+ self.name = name
273
+ self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")
274
+
275
+ async def reset(self) -> int:
276
+ try:
277
+ await self.lock.release()
278
+ except async_redis.lock.LockError:
279
+ pass
280
+ await self.redis_client.set(self.name, 0)
281
+
282
+ async def add(self, n: int) -> int:
283
+ async with self.lock:
284
+ current_val = await self.redis_client.get(self.name)
285
+ if current_val is None:
286
+ current_val = 0
287
+ else:
288
+ current_val = int(current_val)
289
+ new_val = current_val + n
290
+ await self.redis_client.set(self.name, new_val)
291
+ return new_val
292
+
293
+ async def sub(self, n: int) -> int:
294
+ async with self.lock:
295
+ current_val = await self.redis_client.get(self.name)
296
+ if current_val is None:
297
+ raise ValueError("Invalid sub counter when counter does not exist")
298
+ current_val = int(current_val)
299
+ if current_val <= 0:
300
+ raise ValueError("Invalid sub counter to counter that is already zero")
301
+ new_val = current_val - n
302
+ await self.redis_client.set(self.name, new_val)
303
+ return new_val
304
+
305
+ async def count(self) -> int:
306
+ value = await self.redis_client.get(self.name)
307
+ if value is None:
308
+ return 0
309
+ else:
310
+ return int(value)
311
+
312
+
313
+ def distributed_workers(
314
+ model_args: dict,
315
+ master_address: str,
316
+ master_port: str,
317
+ world_size: int,
318
+ rank: int,
319
+ redis_port: int,
320
+ worker_queues: dict[int, multiprocessing.Queue],
321
+ ) -> None:
322
+ redis_client = redis.Redis("redis", redis_port)
323
+ request_queue = RedisQueue(redis_client, "request")
324
+ response_queue = RedisQueue(redis_client, "response")
325
+
326
+ os.environ["MASTER_ADDR"] = master_address
327
+ os.environ["MASTER_PORT"] = str(master_port)
328
+
329
+ torch.set_default_tensor_type("torch.cuda.FloatTensor")
330
+
331
+ torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
332
+ assert rank == torch.distributed.get_rank()
333
+
334
+ torch.cuda.set_device(rank)
335
+
336
+ is_coord = rank == 0
337
+
338
+ worker = ChameleonWorker(
339
+ rank=rank,
340
+ model_path=model_args["model_path"],
341
+ tokenizer_path=model_args["tokenizer_path"],
342
+ additional_eos_tokens=model_args["additional_eos_tokens"],
343
+ )
344
+ worker_id = id(worker)
345
+ logger.info("Rank %s, master_port=%s worker=%s", rank, master_port, worker_id)
346
+
347
+ step = 0
348
+ while True:
349
+ step += 1
350
+ redis_client.set(f"status_rank_{rank}", "Pre-coordinator sync")
351
+ if is_coord:
352
+ distributed_objs = [request_queue.get(None)]
353
+ logger.info("Objects from queue: %s", distributed_objs)
354
+ for worker_rank in range(1, world_size):
355
+ worker_message = {"message": START, "src": rank, "dst": worker_rank}
356
+ logger.info("Rank %s Sending: %s", rank, worker_message)
357
+ worker_queues[worker_rank].put(worker_message)
358
+ else:
359
+ distributed_objs = [None]
360
+ logger.info("Rank %s worker %s waiting for rank 0", rank, worker_id)
361
+ message_from_rank_0 = worker_queues[rank].get()
362
+ logger.info(
363
+ "Received message from rank 0 in rank %s: %s", rank, message_from_rank_0
364
+ )
365
+ if message_from_rank_0["message"] != START:
366
+ raise ValueError(
367
+ f"Unexpected message from rank 0: {message_from_rank_0['message']}"
368
+ )
369
+ redis_client.set(f"status_rank_{rank}", "Post-coordinator sync")
370
+
371
+ try:
372
+ logger.info(
373
+ "Broadcast Starting: Rank %s, worker %s, step %s",
374
+ rank,
375
+ worker_id,
376
+ step,
377
+ )
378
+ redis_client.set(f"status_rank_{rank}", "Pre-torch sync")
379
+ torch.distributed.broadcast_object_list(distributed_objs, src=0)
380
+ redis_client.set(f"status_rank_{rank}", "Post-torch sync")
381
+ logger.info(
382
+ "Broadcast Complete: Rank %s, worker %s, step %s",
383
+ rank,
384
+ worker_id,
385
+ step,
386
+ )
387
+ except RuntimeError as e:
388
+ logger.error(
389
+ "Rank %s, worker %s, step %s, Error detected in torch broadcast: %s",
390
+ rank,
391
+ worker_id,
392
+ step,
393
+ str(e),
394
+ )
395
+ raise
396
+
397
+ logger.info("rank %s, objs %s", rank, distributed_objs)
398
+ queue_id, data = distributed_objs[0]
399
+ mode = data.pop("mode")
400
+ request_id = data.pop("request_id")
401
+ assert queue_id == request_id
402
+ tokenized_prompt = data.pop("tokenized_prompt")
403
+ try:
404
+ match mode:
405
+ case WSMessageType.GENERATE_TEXT:
406
+ generator_fn = partial(
407
+ worker._generate_text_streaming, tokenized_prompt, **data
408
+ )
409
+ case WSMessageType.GENERATE_IMAGE:
410
+ generator_fn = partial(
411
+ worker._generate_image_streaming, tokenized_prompt, **data
412
+ )
413
+ case WSMessageType.GENERATE_MULTIMODAL:
414
+ generator_fn = partial(
415
+ worker._generate_multimodal_streaming, tokenized_prompt, **data
416
+ )
417
+ case _:
418
+ logger.error(
419
+ "Encountered unknown mode, crashing the program: %s", mode
420
+ )
421
+ response_queue.put(
422
+ queue_id, {"error": True, "final": True, "message": mode}
423
+ )
424
+ raise ValueError("Unknown mode")
425
+ logger.info("Rank: %s, Processing request: %s", rank, request_id)
426
+ i = 0
427
+ redis_client.set(f"status_rank_{rank}", "Pre-generate")
428
+ for output in generator_fn():
429
+ i += 1
430
+ if is_coord:
431
+ response = {"final": False, "output": output, "error": False}
432
+ logger.info(
433
+ "Rank: %s, Adding to response queue: %.100s",
434
+ rank,
435
+ response,
436
+ )
437
+ redis_client.set(f"status_rank_{rank}", f"Generate Pre Put {i}")
438
+ response_queue.put(queue_id, response)
439
+ redis_client.set(f"status_rank_{rank}", f"Generate Post Put {i}")
440
+ else:
441
+ redis_client.set(f"status_rank_{rank}", f"Generate {i}")
442
+ redis_client.set(f"step_on_rank_{rank}", i)
443
+ redis_client.set(f"status_rank_{rank}", "Post-generate")
444
+ if is_coord:
445
+ logger.info("Rank: %s, Adding final result to output queue", rank)
446
+ response_queue.put(queue_id, {"final": True, "error": False})
447
+ except torch.cuda.OutOfMemoryError as e:
448
+ logger.error("Encountered OOM, crashing the program: %s", e)
449
+ response_queue.put(
450
+ queue_id, {"error": True, "final": True, "message": str(e)}
451
+ )
452
+ crash_program()
453
+ except RuntimeError as e:
454
+ message = str(e)
455
+ if "CUDA" in message:
456
+ logger.error("Encountered CUDA error, crashing the program: %s", e)
457
+ response_queue.put(
458
+ queue_id, {"error": True, "final": True, "message": str(e)}
459
+ )
460
+ crash_program()
461
+ else:
462
+ logger.error(
463
+ "Encountered unexpected runtime error, crashing the program: %s %s",
464
+ e,
465
+ traceback.format_exc(),
466
+ )
467
+ response_queue.put(
468
+ queue_id, {"error": True, "final": True, "message": str(e)}
469
+ )
470
+ crash_program()
471
+ except Exception as e:
472
+ logger.error(
473
+ "Encountered unexpected exception: %s %s",
474
+ str(e),
475
+ traceback.format_exc(),
476
+ )
477
+ response_queue.put(
478
+ queue_id, {"error": True, "final": True, "message": str(e)}
479
+ )
480
+ crash_program()
481
+
482
+
483
+ class ChameleonWorker(ChameleonForwardMixin):
484
+ def __init__(
485
+ self,
486
+ *,
487
+ rank: int,
488
+ model_path: str,
489
+ tokenizer_path: str,
490
+ additional_eos_tokens: list[str] | None,
491
+ ) -> None:
492
+ self.rank = rank
493
+ self.model_path = model_path
494
+ self.additional_eos_tokens = additional_eos_tokens
495
+ torch.set_default_device(f"cuda:{rank}")
496
+ self.model = load_model(model_path, rank)
497
+ self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
498
+ self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
499
+ logger.info(
500
+ "Rank: %s, Model loaded in worker_obj: %s",
501
+ rank,
502
+ id(self),
503
+ )
504
+
505
+
506
+ def crash_program() -> None:
507
+ logger.error(
508
+ "Crashing the program as instructed, likely due to distributed worker failures"
509
+ )
510
+ sys.exit(1)
511
+
512
+
513
+ class ChameleonDistributedGenerator(AbstractMultimodalGenerator, ChameleonTokenizationMixin):
514
+ def __init__(
515
+ self,
516
+ *,
517
+ world_size: int,
518
+ model_path: str,
519
+ master_port: int,
520
+ tokenizer_path: str,
521
+ vqgan_config_path: str,
522
+ vqgan_ckpt_path: str | None = None,
523
+ master_address: str = "0.0.0.0",
524
+ additional_eos_tokens: list[str] | None = None,
525
+ redis_port: int | None = None,
526
+ ) -> None:
527
+ self.master_port = master_port
528
+ self.master_address = master_address
529
+ self.additional_eos_tokens = additional_eos_tokens
530
+ logger.info("Loading tokenizer...")
531
+ tokenizer_path = tokenizer_path
532
+ self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
533
+ self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
534
+
535
+ logger.info("Loading VQGAN...")
536
+ self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
537
+ self.redis_port = redis_port
538
+ self.redis_pool = async_redis.ConnectionPool.from_url(
539
+ f"redis://redis:{redis_port}"
540
+ )
541
+ self.redis_client = async_redis.Redis.from_pool(self.redis_pool)
542
+ self.request_queue = AsyncRedisQueue(self.redis_client, "request")
543
+ self.response_queue = AsyncRedisQueue(self.redis_client, "response")
544
+ self.worker_queues: dict[int, multiprocessing.Queue] = {
545
+ rank: multiprocessing.Queue() for rank in range(world_size)
546
+ }
547
+ self.procs: list[multiprocessing.Process] = []
548
+ model_args = {
549
+ "model_path": model_path,
550
+ "master_address": master_address,
551
+ "master_port": master_port,
552
+ "tokenizer_path": tokenizer_path,
553
+ "additional_eos_tokens": additional_eos_tokens,
554
+ }
555
+ logger.info("Launching paralle model with world_size=%s", world_size)
556
+ for i in range(world_size):
557
+ proc = multiprocessing.Process(
558
+ target=distributed_workers,
559
+ args=(
560
+ model_args,
561
+ master_address,
562
+ master_port,
563
+ world_size,
564
+ i,
565
+ self.redis_port,
566
+ self.worker_queues,
567
+ ),
568
+ daemon=True,
569
+ )
570
+ self.procs.append(proc)
571
+ proc.start()
572
+
573
+ def check_error(self, output: dict) -> None:
574
+ if output["error"]:
575
+ import sys
576
+ print(f"check_error({output})", file=sys.stderr)
577
+ self.kill_procs()
578
+ logger.error(
579
+ "COORDINATOR: Encountered error in managed processes, exiting: %s",
580
+ output,
581
+ )
582
+ crash_program()
583
+
584
+ def __del__(self) -> None:
585
+ self.kill_procs(error=False)
586
+
587
+ def kill_procs(self, error: bool = True) -> None:
588
+ if error:
589
+ log_fn = logger.error
590
+ else:
591
+ log_fn = logger.info
592
+ log_fn("Error encountered, killing worker procs: %s", self.procs)
593
+ for p in self.procs:
594
+ try:
595
+ log_fn("Killing: %s", p)
596
+ p.kill()
597
+ except:
598
+ log_fn("Encountered issue killing process and ignoring: %s", p)
599
+
600
+ # ALLOW_ANY(get_next_output.return)
601
+ async def get_next_output(self, request_id: str) -> Any:
602
+ logger.info("Waiting for response for request_id=%s", request_id)
603
+ queue_id, output = await self.response_queue.get(request_id)
604
+ assert queue_id == request_id
605
+ return output
606
+
607
+ async def generate_text_streaming(
608
+ self,
609
+ prompt: MixedSequenceType,
610
+ max_gen_tokens: int = 256,
611
+ temp: float = 1.0,
612
+ top_p: float = 0.8,
613
+ repetition_penalty: float = 1.2,
614
+ seed: int | None = None,
615
+ debug: dict | None = None,
616
+ ) -> Generator[str, None, None]:
617
+ tokenized_prompt = self.tokens_from_inputs(prompt)
618
+ request_id = f"request_{random.randint(100_000, 200_000)}"
619
+ if seed is None:
620
+ seed = random.randint(1, 2048)
621
+ if debug is not None:
622
+ debug["seed"] = seed
623
+ if len(tokenized_prompt) > (4096 - 3):
624
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
625
+ return
626
+ assert not isinstance(tokenized_prompt, torch.Tensor)
627
+ request = {
628
+ "mode": WSMessageType.GENERATE_TEXT.value,
629
+ "request_id": request_id,
630
+ "tokenized_prompt": tokenized_prompt,
631
+ "max_gen_tokens": max_gen_tokens,
632
+ "temp": temp,
633
+ "top_p": top_p,
634
+ "repetition_penalty": repetition_penalty,
635
+ "seed": seed,
636
+ }
637
+ logger.info(
638
+ "Sending request_id=%s: %s",
639
+ request_id,
640
+ request,
641
+ )
642
+ await asyncio.gather(
643
+ self.request_queue.clear(request_id),
644
+ self.response_queue.clear(request_id),
645
+ )
646
+ logger.info("Cleared request/response queue for %s", request_id)
647
+ await self.request_queue.put(request_id, request)
648
+ logger.info("Sent request to coordinator %s", request_id)
649
+ try:
650
+ while True:
651
+ output = await self.get_next_output(request_id)
652
+ logger.info("Received response for %s", request_id)
653
+ self.check_error(output)
654
+ if output["final"]:
655
+ break
656
+
657
+ n_outs = len(output["output"])
658
+ if n_outs != 1:
659
+ logger.error(
660
+ "Encountered unexpected number of %s arguments in: %s",
661
+ n_outs,
662
+ output["output"],
663
+ )
664
+ tokens = output["output"]
665
+ assert not isinstance(tokens, torch.Tensor)
666
+ logger.info("output info: type=%s, value=%.20s", type(tokens), tokens)
667
+ yield self.tokenizer.decode(tokens)
668
+ finally:
669
+ logger.info("Cleaning up queues in request_id=%s", request_id)
670
+ await asyncio.gather(
671
+ self.request_queue.clear(request_id),
672
+ self.response_queue.clear(request_id),
673
+ )
674
+ logger.info("Completed cleaning for request_id=%s", request_id)
675
+
676
+ async def generate_image_streaming(
677
+ self,
678
+ prompt: MixedSequenceType,
679
+ temp: float = 1.0,
680
+ top_p: float = 0.8,
681
+ cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
682
+ cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
683
+ yield_every_n: int = 32,
684
+ debug: dict | None = None,
685
+ seed: int | None = None,
686
+ ) -> Generator[StreamingImage, None, None]:
687
+ tokenized_prompt = self.tokens_from_inputs(prompt)
688
+ tokenized_prompt.append(self.vocab.begin_image)
689
+ assert not isinstance(tokenized_prompt, torch.Tensor)
690
+ request_id = f"request_{random.randint(100_000, 200_000)}"
691
+ if seed is None:
692
+ seed = random.randint(1, 2048)
693
+ if debug is not None:
694
+ debug["seed"] = seed
695
+ if len(tokenized_prompt) > (4096 - 3 - 1024):
696
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
697
+ return
698
+ request = {
699
+ "mode": WSMessageType.GENERATE_IMAGE.value,
700
+ "request_id": request_id,
701
+ "tokenized_prompt": tokenized_prompt,
702
+ "cfg_image_weight": cfg_image_weight,
703
+ "cfg_text_weight": cfg_text_weight,
704
+ "yield_every_n": yield_every_n,
705
+ "temp": temp,
706
+ "top_p": top_p,
707
+ "seed": seed,
708
+ }
709
+ logger.info(
710
+ "Sending request_id=%s: %s",
711
+ request_id,
712
+ request,
713
+ )
714
+ await asyncio.gather(
715
+ self.request_queue.clear(request_id),
716
+ self.response_queue.clear(request_id),
717
+ )
718
+ logger.info("Cleared request/response queue for %s", request_id)
719
+ await self.request_queue.put(request_id, request)
720
+ logger.info("Sent request to coordinator %s", request_id)
721
+ try:
722
+ while True:
723
+ output = await self.get_next_output(request_id)
724
+ logger.info("Received response for %s", request_id)
725
+ self.check_error(output)
726
+ if output["final"]:
727
+ break
728
+ n_outs = len(output["output"])
729
+ if n_outs != 2:
730
+ logger.error(
731
+ "Encountered unexpected number of %s arguments in: %s",
732
+ n_outs,
733
+ output["output"],
734
+ )
735
+ tokens, final = output["output"]
736
+ assert not isinstance(tokens, torch.Tensor)
737
+ yield StreamingImage(
738
+ image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
739
+ )
740
+ finally:
741
+ logger.info("Cleaning up queues in request_id=%s", request_id)
742
+ await asyncio.gather(
743
+ self.request_queue.clear(request_id),
744
+ self.response_queue.clear(request_id),
745
+ )
746
+ logger.info("Completed cleaning for request_id=%s", request_id)
747
+
748
+ async def generate_multimodal_streaming(
749
+ self,
750
+ prompt: MixedSequenceType,
751
+ temp: float = 1.0,
752
+ top_p: float = 0.8,
753
+ cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
754
+ cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
755
+ yield_every_n: int = 32,
756
+ max_gen_tokens: int = 4096,
757
+ repetition_penalty: float = 1.2,
758
+ suffix_tokens: list[str] | None = None,
759
+ seed: int | None = None,
760
+ debug: dict | None = None,
761
+ ) -> Generator[MixedSequenceType, None, None]:
762
+ tokenized_prompt = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
763
+ assert not isinstance(tokenized_prompt, torch.Tensor)
764
+ request_id = f"request_{random.randint(100_000, 200_000)}"
765
+ if seed is None:
766
+ seed = random.randint(1, 2048)
767
+ if debug is not None:
768
+ debug["seed"] = seed
769
+ if len(tokenized_prompt) > (4096 - 3):
770
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
771
+ return
772
+
773
+ request = {
774
+ "mode": WSMessageType.GENERATE_MULTIMODAL.value,
775
+ "request_id": request_id,
776
+ "tokenized_prompt": tokenized_prompt,
777
+ "cfg_image_weight": cfg_image_weight,
778
+ "cfg_text_weight": cfg_text_weight,
779
+ "repetition_penalty": repetition_penalty,
780
+ "yield_every_n": yield_every_n,
781
+ "max_gen_tokens": max_gen_tokens,
782
+ "temp": temp,
783
+ "top_p": top_p,
784
+ "seed": seed,
785
+ }
786
+ logger.info(
787
+ "Sending request_id=%s: %s",
788
+ request_id,
789
+ request,
790
+ )
791
+ await asyncio.gather(
792
+ self.request_queue.clear(request_id),
793
+ self.response_queue.clear(request_id),
794
+ )
795
+ logger.info("Cleared request/response queue for %s", request_id)
796
+ await self.request_queue.put(request_id, request)
797
+ logger.info("Sent request to coordinator %s", request_id)
798
+ try:
799
+ while True:
800
+ output = await self.get_next_output(request_id)
801
+ logger.info("Received response for %s", request_id)
802
+ self.check_error(output)
803
+ if output["final"]:
804
+ break
805
+ n_outs = len(output["output"])
806
+ if n_outs != 3:
807
+ logger.error(
808
+ "Encountered unexpected number of %s arguments in: %s",
809
+ n_outs,
810
+ output["output"],
811
+ )
812
+ token_type, tokens, image_is_final = output["output"]
813
+ assert not isinstance(tokens, torch.Tensor)
814
+ match token_type:
815
+ case "TEXT":
816
+ yield self.tokenizer.decode(tokens)
817
+ case "IMAGE":
818
+ yield StreamingImage(
819
+ image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
820
+ final=image_is_final,
821
+ )
822
+ case _:
823
+ raise ValueError("Unknown token type")
824
+ finally:
825
+ logger.info("Cleaning up queues in request_id=%s", request_id)
826
+ await self.request_queue.clear(request_id)
827
+ await self.response_queue.clear(request_id)
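
Both streaming methods in this file yield either plain text chunks (`str`) or `StreamingImage` objects whose `final` flag marks a finished image. A minimal consumer sketch, assuming `generator` is an already-constructed instance of this distributed class with its request/response queues wired up; only the imported name is taken from this commit, everything else is illustrative:

```python
# Sketch only: `generator` is assumed to be an already-running instance of the
# distributed generator above; only the StreamingImage import comes from this commit.
from chameleon.viewer.backend.models.abstract_model import StreamingImage


async def collect(generator, prompt) -> tuple[str, list]:
    text_parts: list[str] = []
    images = []
    async for chunk in generator.generate_multimodal_streaming(prompt):
        if isinstance(chunk, StreamingImage):
            if chunk.final:               # non-final images are progressive previews
                images.append(chunk.image)
        else:
            text_parts.append(chunk)      # text arrives as already-decoded strings
    return "".join(text_parts), images
```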
chameleon/viewer/backend/models/chameleon_local.py ADDED
@@ -0,0 +1,642 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import io
7
+ import json
8
+ from typing import Generator
9
+
10
+ import PIL.Image
11
+ import torch
12
+ import transformers
13
+ from tokenizers import Tokenizer
14
+ from transformers import (
15
+ MaxLengthCriteria,
16
+ RepetitionPenaltyLogitsProcessor,
17
+ TemperatureLogitsWarper,
18
+ TopPLogitsWarper,
19
+ )
20
+
21
+ from chameleon.inference.alignment import AlignPromptRight
22
+ from chameleon.inference.generation import ChameleonGenerator
23
+ from chameleon.inference.image_tokenizer import ImageTokenizer
24
+ from chameleon.inference.loader import load_model
25
+ from chameleon.inference.logits_processor import (
26
+ AllowOnlyTokensAfterIndexLogitsProcessor,
27
+ AllowOnlyTokensLogitsProcessor,
28
+ InBatchInstructCFGLogitsProcessor,
29
+ )
30
+ from chameleon.inference.model_adapter import ChameleonModelAdapter
31
+ from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
32
+ from chameleon.inference.token_selector import (
33
+ MultinomialTokenSelector,
34
+ ReplicatedInputTokenSelector,
35
+ )
36
+ from chameleon.inference.vocab import VocabInfo, VocabTranslation
37
+ from chameleon.viewer.backend.models.abstract_model import (
38
+ DEFAULT_IMAGE_CFG_IMAGE,
39
+ DEFAULT_IMAGE_CFG_TEXT,
40
+ DEFAULT_MULTIMODAL_CFG_IMAGE,
41
+ DEFAULT_MULTIMODAL_CFG_TEXT,
42
+ AbstractMultimodalGenerator,
43
+ MixedSequenceType,
44
+ StreamingImage,
45
+ )
46
+ from chameleon.viewer.backend.utils import get_logger
47
+
48
+ logger = get_logger(__name__)
49
+
50
+
51
+ def set_seed(seed: int) -> None:
52
+ transformers.enable_full_determinism(seed, warn_only=True)
53
+
54
+
55
+ def get_rank() -> int:
56
+ if torch.distributed.is_initialized():
57
+ return torch.distributed.get_rank()
58
+ else:
59
+ return 0
60
+
61
+
62
+ class ChameleonTokenizationMixin:
63
+ def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
64
+ img = self.pillow_from_bpe_tokens(bpe_tokens)
65
+
66
+ img_io = io.BytesIO()
67
+ img.save(img_io, format="PNG")
68
+ return img_io.getvalue()
69
+
70
+ def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
71
+ image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
72
+ if image_tensor.shape[0] < 1024:
73
+ padding = (
74
+ torch.ones([1024 - image_tensor.shape[0]], dtype=int) * image_tensor[0]
75
+ )
76
+ image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
77
+
78
+ return self.image_tokenizer.pil_from_img_toks(image_tensor)
79
+
80
+ def tokens_from_inputs(
81
+ self,
82
+ inputs: MixedSequenceType,
83
+ suffix_tokens: list[str] | None = None,
84
+ ) -> list[int]:
85
+ tokens = [self.vocab.bos_id]
86
+ for input_ in inputs:
87
+ if isinstance(input_, str):
88
+ tokens.extend(self.tokenizer.encode(input_.strip()).ids)
89
+ elif isinstance(input_, PIL.Image.Image):
90
+ tokens.append(self.vocab.begin_image)
91
+ imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
92
+ tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
93
+ tokens.append(self.vocab.end_image)
94
+ else:
95
+ raise ValueError(f"Unknown input type: {type(input_)}")
96
+
97
+ if suffix_tokens is not None:
98
+ for t in suffix_tokens:
99
+ tokens.extend(self.tokenizer.encode(t).ids)
100
+ sanitized_tokens = []
101
+ for t in tokens:
102
+ if isinstance(t, torch.Tensor):
103
+ sanitized_tokens.append(t.item())
104
+ else:
105
+ sanitized_tokens.append(t)
106
+ return sanitized_tokens
107
+
108
+
109
+ class GeneratorWrapper:
110
+ def __init__(self, gen):
111
+ self.gen = gen
112
+
113
+ def __iter__(self):
114
+ return self
115
+
116
+ def __next__(self):
117
+ return next(self.gen)
118
+
119
+
120
+ class Decoder:
121
+ def __init__(
122
+ self,
123
+ chameleon_generator: "ChameleonLocalGenerator",
124
+ input_ids: list[int],
125
+ ):
126
+ ...
127
+
128
+ def __next__(self) -> tuple[str, list[int], list[int], bool, type["Decoder"] | None]:
129
+ ...
130
+
131
+
132
+ class TextDecoder(Decoder):
133
+ def __init__(
134
+ self,
135
+ chameleon_generator: "ChameleonLocalGenerator",
136
+ input_ids: list[int],
137
+ *,
138
+ temp: float,
139
+ top_p: float,
140
+ max_seq_len: int,
141
+ # TODO: Propagate setting upwards
142
+ repetition_penalty: float,
143
+ **kwargs,
144
+ ):
145
+ self.chameleon_generator = chameleon_generator
146
+ assert chameleon_generator.vocab.eos_id is not None
147
+
148
+ stopping_criteria = [
149
+ StopOnEOS(chameleon_generator.vocab.eos_id),
150
+ MaxLengthCriteria(max_seq_len),
151
+ ]
152
+ if chameleon_generator.additional_eos_tokens is not None:
153
+ for token in chameleon_generator.additional_eos_tokens:
154
+ stopping_criteria.append(
155
+ StopOnEOSAfterBatchIndex(
156
+ chameleon_generator.tokenizer.token_to_id(token), [len(input_ids)]
157
+ )
158
+ )
159
+
160
+ logits_processors = [
161
+ AllowOnlyTokensLogitsProcessor(
162
+ chameleon_generator.vocab.text_tokens
163
+ + [chameleon_generator.vocab.eos_id, chameleon_generator.vocab.begin_image]
164
+ ),
165
+ # Don't allow any more images near the end since there isn't enough room
166
+ AllowOnlyTokensAfterIndexLogitsProcessor(
167
+ chameleon_generator.vocab.text_tokens + [chameleon_generator.vocab.eos_id],
168
+ # TODO: Calculate exact
169
+ 1024 * 3 - 3,
170
+ ),
171
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
172
+ TemperatureLogitsWarper(temp),
173
+ TopPLogitsWarper(top_p),
174
+ ]
175
+
176
+ self.gen = ChameleonGenerator(
177
+ model=ChameleonModelAdapter(chameleon_generator.model, max_seq_len=max_seq_len),
178
+ input_ids=[input_ids],
179
+ stopping_criteria=stopping_criteria,
180
+ logits_processors=logits_processors,
181
+ )
182
+ for _ in range(len(input_ids)):
183
+ next(self.gen)
184
+
185
+ def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
186
+ gpu_tok = next(self.gen).id.item()
187
+ cpu_tok = gpu_tok
188
+ if cpu_tok == self.chameleon_generator.vocab.begin_image:
189
+ # return "TEXT", [cpu_tok], [], False, ImageDecoder
190
+ raise StopIteration()
191
+
192
+ return (
193
+ "TEXT",
194
+ [cpu_tok],
195
+ [cpu_tok],
196
+ False,
197
+ None,
198
+ )
199
+
200
+
201
+ class ImageDecoder(Decoder):
202
+ def __init__(
203
+ self,
204
+ chameleon_generator: "ChameleonLocalGenerator",
205
+ input_ids: list[int],
206
+ *,
207
+ cfg_image_weight: float,
208
+ cfg_text_weight: float,
209
+ temp: float,
210
+ top_p: float,
211
+ yield_every_n: int,
212
+ **kwargs,
213
+ ):
214
+ self.yield_every_n = yield_every_n
215
+ self.chameleon_generator = chameleon_generator
216
+ logits_processors = [
217
+ InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
218
+ AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
219
+ TemperatureLogitsWarper(temp),
220
+ TopPLogitsWarper(top_p),
221
+ ]
222
+
223
+ image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
224
+ chameleon_generator.vocab.bos_id,
225
+ chameleon_generator.vocab.begin_image,
226
+ chameleon_generator.vocab.end_image,
227
+ }
228
+
229
+ full_conditioned = input_ids
230
+ image_conditioned = [
231
+ in_id for in_id in input_ids if in_id in image_conditioned_allowed
232
+ ]
233
+ unconditioned = [
234
+ chameleon_generator.vocab.bos_id,
235
+ chameleon_generator.vocab.begin_image,
236
+ ]
237
+
238
+ self.gen = ChameleonGenerator(
239
+ model=ChameleonModelAdapter(
240
+ chameleon_generator.model, max_seq_len=len(input_ids) + 1024
241
+ ),
242
+ input_ids=[full_conditioned, image_conditioned, unconditioned],
243
+ logits_processors=logits_processors,
244
+ alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
245
+ token_selector=ReplicatedInputTokenSelector(
246
+ MultinomialTokenSelector(), n=3
247
+ ),
248
+ )
249
+ for _ in range(len(input_ids)):
250
+ next(self.gen)
251
+ self.image_builder: list[torch.LongTensor] = []
252
+ self.gpu_tok_batch: list[torch.LongTensor] = []
253
+
254
+ def __next__(self) -> tuple[str, list[int], list[int], bool, type[Decoder] | None]:
255
+ while True:
256
+ gpu_tok = next(self.gen)
257
+ gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
258
+
259
+ self.image_builder.append(gpu_tok)
260
+ self.gpu_tok_batch.append(gpu_tok)
261
+
262
+ if len(self.image_builder) == 1024:
263
+ return (
264
+ "IMAGE",
265
+ torch.tensor(self.gpu_tok_batch).tolist()
266
+ + [self.chameleon_generator.vocab.end_image],
267
+ torch.tensor(self.image_builder).tolist(),
268
+ True,
269
+ TextDecoder,
270
+ )
271
+ elif len(self.image_builder) % self.yield_every_n == 0:
272
+ cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
273
+ self.gpu_tok_batch = []
274
+
275
+ return (
276
+ "IMAGE",
277
+ cpu_toks,
278
+ torch.tensor(self.image_builder).tolist(),
279
+ False,
280
+ None,
281
+ )
282
+
283
+
284
+ class ChameleonForwardMixin:
285
+ @torch.inference_mode()
286
+ def _generate_text_streaming(
287
+ self,
288
+ input_ids: list[int],
289
+ max_gen_tokens: int = 256,
290
+ temp: float = 1.0,
291
+ top_p: float = 0.8,
292
+ repetition_penalty: float = 1.2,
293
+ seed: int | None = None,
294
+ ) -> Generator[str, None, None]:
295
+ if seed is not None:
296
+ set_seed(seed)
297
+ logger.info(
298
+ "Rank: %s, set seed: %s",
299
+ get_rank(),
300
+ seed,
301
+ )
302
+
303
+ logits_processors = [
304
+ # Only allow text tokens and end-of-sequence.
305
+ AllowOnlyTokensLogitsProcessor(
306
+ self.vocab.text_tokens + [self.vocab.eos_id]
307
+ ),
308
+ # Don't allow the first token to be end-of-sequence.
309
+ # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
310
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
311
+ TemperatureLogitsWarper(temp),
312
+ TopPLogitsWarper(top_p),
313
+ ]
314
+
315
+ stopping_criteria = [
316
+ StopOnEOS(self.vocab.eos_id),
317
+ MaxLengthCriteria(len(input_ids) + max_gen_tokens),
318
+ ]
319
+ if self.additional_eos_tokens is not None:
320
+ for token in self.additional_eos_tokens:
321
+ stopping_criteria.append(
322
+ StopOnEOSAfterBatchIndex(
323
+ self.tokenizer.token_to_id(token), [len(input_ids)]
324
+ )
325
+ )
326
+ for tok in ChameleonGenerator(
327
+ model=ChameleonModelAdapter(
328
+ self.model,
329
+ max_seq_len=len(input_ids) + max_gen_tokens,
330
+ ),
331
+ input_ids=[input_ids],
332
+ stopping_criteria=stopping_criteria,
333
+ logits_processors=logits_processors,
334
+ ):
335
+ yield tok.tolist()
336
+
337
+ @torch.inference_mode()
338
+ def _generate_batched_text_streaming(
339
+ self,
340
+ batch: list[list[int]],
341
+ max_gen_tokens: int = 256,
342
+ temp: float = 1.0,
343
+ top_p: float = 0.8,
344
+ repetition_penalty: float = 1.2,
345
+ seed: int | None = None,
346
+ ) -> Generator[list[str], None, None]:
347
+ if seed is not None:
348
+ set_seed(seed)
349
+ logits_processors = [
350
+ # Only allow text tokens and end-of-sequence.
351
+ AllowOnlyTokensLogitsProcessor(
352
+ self.vocab.text_tokens + [self.vocab.eos_id]
353
+ ),
354
+ # Don't allow the first token to be end-of-sequence.
355
+ # DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
356
+ RepetitionPenaltyLogitsProcessor(repetition_penalty),
357
+ TemperatureLogitsWarper(temp),
358
+ TopPLogitsWarper(top_p),
359
+ ]
360
+
361
+ max_batch_size = max(len(p) for p in batch)
362
+ stopping_criteria = [
363
+ StopOnEOS(self.vocab.eos_id),
364
+ MaxLengthCriteria(max_batch_size + max_gen_tokens),
365
+ ]
366
+ if self.additional_eos_tokens is not None:
367
+ for token in self.additional_eos_tokens:
368
+ stopping_criteria.append(
369
+ StopOnEOSAfterBatchIndex(
370
+ self.tokenizer.token_to_id(token), [len(x) for x in batch]
371
+ )
372
+ )
373
+ for tok in ChameleonGenerator(
374
+ model=ChameleonModelAdapter(
375
+ self.model,
376
+ max_seq_len=max_batch_size + max_gen_tokens,
377
+ ),
378
+ input_ids=batch,
379
+ stopping_criteria=stopping_criteria,
380
+ logits_processors=logits_processors,
381
+ ):
382
+ yield tok.unsqueeze(1).tolist()
383
+
384
+ @torch.inference_mode()
385
+ def _generate_image_streaming(
386
+ self,
387
+ tokenized_prompt: list[int],
388
+ temp: float = 1.0,
389
+ top_p: float = 0.8,
390
+ cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
391
+ cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
392
+ yield_every_n: int = 32,
393
+ seed: int | None = None,
394
+ ) -> Generator[tuple[list[int], bool], None, None]:
395
+ if seed is not None:
396
+ set_seed(seed)
397
+ logger.info(
398
+ "Rank: %s, set seed: %s",
399
+ get_rank(),
400
+ seed,
401
+ )
402
+
403
+ decoder = ImageDecoder(
404
+ self,
405
+ tokenized_prompt,
406
+ cfg_image_weight=cfg_image_weight,
407
+ cfg_text_weight=cfg_text_weight,
408
+ temp=temp,
409
+ top_p=top_p,
410
+ yield_every_n=yield_every_n,
411
+ )
412
+
413
+ for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
414
+ if next_decoder is not None:
415
+ break
416
+
417
+ yield torch.tensor(frontend_tokens).tolist(), is_final
418
+
419
+ @torch.inference_mode()
420
+ def _generate_multimodal_streaming(
421
+ self,
422
+ input_ids: list[int],
423
+ temp: float = 1.0,
424
+ top_p: float = 0.8,
425
+ cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
426
+ cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
427
+ yield_every_n: int = 32,
428
+ max_gen_tokens: int = 4096,
429
+ repetition_penalty: float = 1.2,
430
+ seed: int | None = None,
431
+ ) -> Generator[tuple[str, list[int], bool], None, None]:
432
+ if seed is not None:
433
+ set_seed(seed)
434
+ logger.info(
435
+ "Rank: %s, set seed: %s",
436
+ get_rank(),
437
+ seed,
438
+ )
439
+ max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
440
+ gen_wrapper = GeneratorWrapper(
441
+ TextDecoder(
442
+ self,
443
+ input_ids,
444
+ temp=temp,
445
+ top_p=top_p,
446
+ max_seq_len=max_seq_len,
447
+ repetition_penalty=repetition_penalty,
448
+ )
449
+ )
450
+
451
+ for (
452
+ message_type,
453
+ cpu_toks,
454
+ frontend_tokens,
455
+ is_final,
456
+ next_decoder,
457
+ ) in gen_wrapper:
458
+ input_ids.extend(cpu_toks)
459
+ if len(frontend_tokens) > 0:
460
+ yield message_type, frontend_tokens, is_final
461
+ if next_decoder is not None:
462
+ gen_wrapper.gen = next_decoder(
463
+ self,
464
+ input_ids,
465
+ temp=temp,
466
+ top_p=top_p,
467
+ max_seq_len=max_seq_len,
468
+ cfg_image_weight=cfg_image_weight,
469
+ cfg_text_weight=cfg_text_weight,
470
+ yield_every_n=yield_every_n,
471
+ repetition_penalty=repetition_penalty,
472
+ )
473
+
474
+
475
+ class ChameleonLocalGenerator(
476
+ AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
477
+ ):
478
+ def __init__(
479
+ self,
480
+ model_path: str,
481
+ tokenizer_path: str,
482
+ vqgan_config_path: str,
483
+ vqgan_ckpt_path: str | None = None,
484
+ additional_eos_tokens: list[str] | None = None,
485
+ ) -> None:
486
+ super().__init__()
487
+ logger.info("Loading model...")
488
+ self.model = load_model(model_path)
489
+ self.additional_eos_tokens = additional_eos_tokens
490
+
491
+ logger.info("Loading tokenizer...")
492
+ tokenizer_path = tokenizer_path
493
+ self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
494
+ self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
495
+
496
+ logger.info("Loading VQGAN...")
497
+ self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
498
+
499
+ @torch.inference_mode()
500
+ def generate_batched_text(
501
+ self,
502
+ prompts: list[MixedSequenceType],
503
+ max_gen_tokens: int = 256,
504
+ temp: float = 1.0,
505
+ top_p: float = 0.8,
506
+ repetition_penalty: float = 1.2,
507
+ seed: int | None = None,
508
+ ) -> list[str]:
509
+ outputs = [""] * len(prompts)
510
+ for vals in self.generate_batched_text_streaming(
511
+ prompts,
512
+ max_gen_tokens=max_gen_tokens,
513
+ temp=temp,
514
+ top_p=top_p,
515
+ repetition_penalty=repetition_penalty,
516
+ seed=seed,
517
+ ):
518
+ for idx, val in enumerate(vals):
519
+ outputs[idx] += val
520
+ return outputs
521
+
522
+ @torch.inference_mode()
523
+ def generate_batched_text_streaming(
524
+ self,
525
+ prompts: list[MixedSequenceType],
526
+ max_gen_tokens: int = 256,
527
+ temp: float = 1.0,
528
+ top_p: float = 0.8,
529
+ repetition_penalty: float = 1.2,
530
+ seed: int | None = None,
531
+ ) -> Generator[list[str], None, None]:
532
+ batch = []
533
+ for prompt in prompts:
534
+ batch.append(self.tokens_from_inputs(prompt))
535
+
536
+ for tok in self._generate_batched_text_streaming(
537
+ batch,
538
+ max_gen_tokens=max_gen_tokens,
539
+ temp=temp,
540
+ top_p=top_p,
541
+ repetition_penalty=repetition_penalty,
542
+ seed=seed,
543
+ ):
544
+ yield self.tokenizer.decode_batch(tok)
545
+
546
+ @torch.inference_mode()
547
+ async def generate_text_streaming(
548
+ self,
549
+ prompt: MixedSequenceType,
550
+ max_gen_tokens: int = 256,
551
+ temp: float = 1.0,
552
+ top_p: float = 0.8,
553
+ repetition_penalty: float = 1.2,
554
+ seed: int | None = None,
555
+ debug: dict | None = None,
556
+ ) -> Generator[str, None, None]:
557
+ tokenized_prompt = self.tokens_from_inputs(prompt)
558
+ if len(tokenized_prompt) > (4096 - 3):
559
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
560
+ return
561
+ for out in self.generate_batched_text_streaming(
562
+ [prompt],
563
+ max_gen_tokens=max_gen_tokens,
564
+ temp=temp,
565
+ top_p=top_p,
566
+ repetition_penalty=repetition_penalty,
567
+ seed=seed,
568
+ ):
569
+ yield out[0]
570
+
571
+ @torch.inference_mode()
572
+ async def generate_image_streaming(
573
+ self,
574
+ prompt: MixedSequenceType,
575
+ temp: float = 1.0,
576
+ top_p: float = 0.8,
577
+ cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
578
+ cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
579
+ yield_every_n: int = 32,
580
+ seed: int | None = None,
581
+ debug: dict | None = None,
582
+ ) -> Generator[StreamingImage, None, None]:
583
+ assert isinstance(prompt, list)
584
+ tokenized_prompt = self.tokens_from_inputs(prompt)
585
+ tokenized_prompt.append(self.vocab.begin_image)
586
+ if len(tokenized_prompt) > (4096 - 3 - 1024):
587
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
588
+ return
589
+ for tokens, final in self._generate_image_streaming(
590
+ tokenized_prompt,
591
+ temp=temp,
592
+ top_p=top_p,
593
+ cfg_image_weight=cfg_image_weight,
594
+ cfg_text_weight=cfg_text_weight,
595
+ yield_every_n=yield_every_n,
596
+ seed=seed,
597
+ ):
598
+ yield StreamingImage(
599
+ image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
600
+ )
601
+
602
+ @torch.inference_mode()
603
+ async def generate_multimodal_streaming(
604
+ self,
605
+ prompt: MixedSequenceType,
606
+ temp: float = 1.0,
607
+ top_p: float = 0.8,
608
+ cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
609
+ cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
610
+ yield_every_n: int = 32,
611
+ max_gen_tokens: int = 4096,
612
+ repetition_penalty: float = 1.2,
613
+ suffix_tokens: list[str] | None = None,
614
+ seed: int | None = None,
615
+ debug: dict | None = None,
616
+ ) -> Generator[MixedSequenceType, None, None]:
617
+ input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
618
+ if len(input_ids) > (4096 - 3):
619
+ yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
620
+ return
621
+
622
+ for token_type, tokens, is_final in self._generate_multimodal_streaming(
623
+ input_ids,
624
+ temp=temp,
625
+ top_p=top_p,
626
+ cfg_image_weight=cfg_image_weight,
627
+ cfg_text_weight=cfg_text_weight,
628
+ yield_every_n=yield_every_n,
629
+ max_gen_tokens=max_gen_tokens,
630
+ repetition_penalty=repetition_penalty,
631
+ seed=seed,
632
+ ):
633
+ match token_type:
634
+ case "TEXT":
635
+ yield self.tokenizer.decode(tokens)
636
+ case "IMAGE":
637
+ yield StreamingImage(
638
+ image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
639
+ final=is_final,
640
+ )
641
+ case _:
642
+ raise ValueError("Unknown token type")
chameleon/viewer/backend/models/service.py ADDED
@@ -0,0 +1,300 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import base64
7
+ import io
8
+ import socket
9
+ import subprocess
10
+ import time
11
+ from functools import partial
12
+
13
+ import fastapi
14
+ import PIL
15
+ import pydantic
16
+ import redis.asyncio as async_redis
17
+ import uvicorn
18
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException
19
+ from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
20
+
21
+ from chameleon.viewer.backend.data_types import (
22
+ Content,
23
+ ContentType,
24
+ NoOptionsForComplete,
25
+ NoOptionsForFull,
26
+ NoOptionsForPartial,
27
+ NoOptionsForQueueStatus,
28
+ WSMessageType,
29
+ WSMultimodalMessage,
30
+ )
31
+ from chameleon.viewer.backend.models.abstract_model import (
32
+ AbstractMultimodalGenerator,
33
+ StreamingImage,
34
+ )
35
+ from chameleon.viewer.backend.models.chameleon_distributed import AsyncRedisCounter
36
+ from chameleon.viewer.backend.utils import get_logger
37
+
38
+ logger = get_logger(__name__)
39
+
40
+
41
+ def nvidia_smi() -> str:
42
+ return subprocess.check_output(["nvidia-smi"], text=True)
43
+
44
+
45
+ async def await_generate_message(websocket: WebSocket) -> WSMultimodalMessage:
46
+ while True:
47
+ rec_message = await websocket.receive_json()
48
+ try:
49
+ maybe_message = WSMultimodalMessage.parse_obj(rec_message)
50
+ except pydantic.ValidationError:
51
+ maybe_message = None
52
+ logger.info("Got invalid message", maybe_message)
53
+ if maybe_message is not None:
54
+ return maybe_message
55
+
56
+
57
+ async def async_acquire_lock(
58
+ *,
59
+ websocket: WebSocket,
60
+ counter: AsyncRedisCounter,
61
+ lock: async_redis.lock.Lock,
62
+ interval=0.1,
63
+ status_interval=1,
64
+ hostname: str | None = None,
65
+ ):
66
+ start = time.time()
67
+ await counter.add(1)
68
+ while True:
69
+ acquired = await lock.acquire(blocking_timeout=interval)
70
+ if acquired:
71
+ break
72
+ elapsed = time.time() - start
73
+ if elapsed > status_interval:
74
+ n_requests = await counter.count()
75
+ message = WSMultimodalMessage(
76
+ message_type=WSMessageType.QUEUE_STATUS,
77
+ content=[
78
+ Content(
79
+ content_type=ContentType.TEXT,
80
+ content=f"n_requests={n_requests}",
81
+ )
82
+ ],
83
+ options=NoOptionsForQueueStatus(),
84
+ debug_info={"hostname": hostname},
85
+ ).dict()
86
+ await websocket.send_json(message)
87
+ start = time.time()
88
+ await counter.sub(1)
89
+
90
+
91
+ COORDINATOR = "coordinator"
92
+
93
+
94
+ def web_app(
95
+ generator: AbstractMultimodalGenerator,
96
+ debug: bool = True,
97
+ redis_port: int | None = None,
98
+ ) -> FastAPI:
99
+ app = FastAPI(debug=debug)
100
+ if redis_port is None:
101
+ redis_client = None
102
+ redis_lock = None
103
+ queue_counter = None
104
+ else:
105
+ redis_client = async_redis.Redis.from_url(f"redis://redis:{redis_port}")
106
+ redis_lock = async_redis.lock.Lock(redis_client, COORDINATOR)
107
+ queue_counter = AsyncRedisCounter(redis_client, "count_pending")
108
+ hostname = socket.gethostname()
109
+
110
+ @app.get("/api/2.0/status")
111
+ def alive() -> dict:
112
+ return {
113
+ "status": "alive",
114
+ "hostname": hostname,
115
+ "nvidia-smi": nvidia_smi(),
116
+ }
117
+
118
+ @app.websocket("/ws/chameleon/v2/{client_id}")
119
+ async def websocket_chameleon_v2(*, websocket: WebSocket, client_id: str):
120
+ logger.info("Requested client_id: %s", client_id)
121
+ await websocket.accept()
122
+ logger.info("Client opened %s with generator id %s", client_id, id(generator))
123
+
124
+ try:
125
+ while True:
126
+ generate_message = await await_generate_message(websocket)
127
+ logger.info("Got generate message: %s", str(generate_message)[:300])
128
+ parsed_prompt = []
129
+ for c in generate_message.content:
130
+ match c.content_type:
131
+ case ContentType.TEXT:
132
+ parsed_prompt.append(c.content)
133
+ case ContentType.IMAGE:
134
+ image_parts = c.content.split(",", 1)
135
+ if len(image_parts) < 2:
136
+ logger.error(
137
+ "Encountered invalid image: %s", image_parts
138
+ )
139
+ raise WebSocketException(
140
+ code=fastapi.status.WS_1008_POLICY_VIOLATION,
141
+ reason=f"Invalid image: {image_parts}",
142
+ )
143
+ image_data = image_parts[1]
144
+ base64_image = base64.b64decode(image_data)
145
+ image_file = io.BytesIO(base64_image)
146
+ parsed_prompt.append(PIL.Image.open(image_file))
147
+ case _:
148
+ raise ValueError("Unknown content type")
149
+ logger.info("Prompt: %s", parsed_prompt)
150
+ partial_outputs = []
151
+ final_contents: list[Content] = []
152
+
153
+ match generate_message.message_type:
154
+ case WSMessageType.GENERATE_TEXT:
155
+ output_generator = generator.generate_text_streaming
156
+ case WSMessageType.GENERATE_IMAGE:
157
+ output_generator = generator.generate_image_streaming
158
+ case WSMessageType.GENERATE_MULTIMODAL:
159
+ output_generator = generator.generate_multimodal_streaming
160
+ case _:
161
+ raise WebSocketException(
162
+ code=fastapi.status.WS_1008_POLICY_VIOLATION,
163
+ reason="Unknown message type",
164
+ )
165
+
166
+ logger.info(
167
+ "Acquiring lock for client %s generation with options: %s",
168
+ client_id,
169
+ generate_message.options,
170
+ )
171
+ option_args = generate_message.options.dict()
172
+ debug_info = {"hostname": hostname}
173
+ del option_args["message_type"]
174
+ output_generator = partial(
175
+ output_generator,
176
+ **option_args,
177
+ debug=debug_info,
178
+ )
179
+ if redis_lock is not None:
180
+ await async_acquire_lock(
181
+ websocket=websocket,
182
+ lock=redis_lock,
183
+ hostname=hostname,
184
+ counter=queue_counter,
185
+ )
186
+ await redis_client.set("has_lock", client_id)
187
+
188
+ logger.info(
189
+ "Starting locked generation for client %s with options: %s",
190
+ client_id,
191
+ generate_message.options,
192
+ )
193
+ try:
194
+ async for output_token in output_generator(parsed_prompt):
195
+ if isinstance(output_token, str):
196
+ content_type = ContentType.TEXT
197
+ content = output_token
198
+ message_type = WSMessageType.PARTIAL_OUTPUT
199
+ options = NoOptionsForPartial()
200
+ partial_outputs.extend(output_token)
201
+ elif isinstance(output_token, StreamingImage):
202
+ content_type = ContentType.IMAGE
203
+ image = output_token.image
204
+ img_io = io.BytesIO()
205
+ image.save(img_io, format="png")
206
+ content = (
207
+ "data:image/png;base64,"
208
+ + base64.b64encode(img_io.getvalue()).decode()
209
+ )
210
+ if output_token.final:
211
+ message_type = WSMessageType.FULL_OUTPUT
212
+ options = NoOptionsForFull()
213
+ else:
214
+ message_type = WSMessageType.PARTIAL_OUTPUT
215
+ options = NoOptionsForPartial()
216
+
217
+ if output_token.final:
218
+ partial_outputs.append(output_token.image)
219
+ else:
220
+ raise ValueError(f"Invalid output_token: {output_token}")
221
+
222
+ message_content = Content(
223
+ content_type=content_type, content=content
224
+ )
225
+ match content_type:
226
+ case ContentType.TEXT:
227
+ final_contents.append(message_content)
228
+ case ContentType.IMAGE:
229
+ if message_type == WSMessageType.FULL_OUTPUT:
230
+ final_contents.append(message_content)
231
+ case _:
232
+ pass
233
+
234
+ message = WSMultimodalMessage(
235
+ message_type=message_type,
236
+ content=[message_content],
237
+ options=options,
238
+ debug_info=debug_info,
239
+ ).dict()
240
+ await websocket.send_json(message)
241
+ finally:
242
+ if redis_lock is not None:
243
+ logger.info(
244
+ "Attempting release of lock for client %s generation with options: %s",
245
+ client_id,
246
+ generate_message.options,
247
+ )
248
+ owned = await redis_lock.owned()
249
+ if owned:
250
+ await redis_client.set("has_lock", "")
251
+ try:
252
+ await redis_lock.release()
253
+ except async_redis.lock.LockError:
254
+ pass
255
+
256
+ logger.info(
257
+ "Released lock for client %s generation with options: %s",
258
+ client_id,
259
+ generate_message.options,
260
+ )
261
+ await websocket.send_json(
262
+ WSMultimodalMessage(
263
+ message_type=WSMessageType.COMPLETE,
264
+ content=final_contents,
265
+ options=NoOptionsForComplete(),
266
+ debug_info=debug_info,
267
+ ).dict()
268
+ )
269
+ except WebSocketDisconnect:
270
+ logger.info("Client disconnected %s", client_id)
271
+ except ConnectionClosedError:
272
+ logger.info("Client forced a close %s", client_id)
273
+ except ConnectionClosedOK:
274
+ logger.info("Connection closed ok %s", client_id)
275
+ finally:
276
+ if redis_lock is not None:
277
+ logger.info("Checking for client holding lock: %s", client_id)
278
+ owned = await redis_lock.owned()
279
+ if owned:
280
+ try:
281
+ logger.info("Attempted to release owned lock: %s", client_id)
282
+ await redis_lock.release()
283
+ except async_redis.lock.LockError:
284
+ pass
285
+ await redis_client.set("has_lock", "")
286
+
287
+ return app
288
+
289
+
290
+ def serve(
291
+ model: AbstractMultimodalGenerator,
292
+ host: str,
293
+ port: int,
294
+ debug: bool = True,
295
+ redis_port: int | None = None,
296
+ ) -> None:
297
+ app = web_app(model, debug=debug, redis_port=redis_port)
298
+ # TODO: convert this to a subprocess call so enable more
299
+ # uvicorn features like multiple workers
300
+ uvicorn.run(app, host=host, port=port)
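
The websocket endpoint above exchanges `WSMultimodalMessage` JSON payloads. A minimal client sketch, assuming the backend is served at `ws://localhost:5000` and that the enum string values below match what `chameleon/viewer/backend/data_types.py` actually defines (those values are assumptions, not confirmed by this commit):

```python
import asyncio
import json

import websockets  # websockets==12.0, as pinned in requirements.txt


async def main() -> None:
    uri = "ws://localhost:5000/ws/chameleon/v2/demo-client"  # host/port are assumptions
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({
            "message_type": "GENERATE_TEXT",                       # assumed enum value
            "content": [{"content_type": "TEXT", "content": "Hello"}],
            "options": {"message_type": "GENERATE_TEXT"},           # assumed options shape
            "debug_info": {},
        }))
        while True:
            msg = json.loads(await ws.recv())
            print(msg["message_type"], str(msg.get("content"))[:80])
            if msg["message_type"] == "COMPLETE":                   # assumed terminal type
                break


asyncio.run(main())
```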
chameleon/viewer/backend/requirements.txt ADDED
@@ -0,0 +1,35 @@
1
+ # If black/isort/pytest change, then update `.circleci/config.yml`
2
+ black==23.7.0
3
+ isort==5.12.0
4
+ pytest==7.4.0
5
+ rich==13.5.*
6
+ ipython
7
+
8
+ # Do not change, python 3.11 needs this
9
+ hydra-core==1.3.2
10
+ typer==0.9.0
11
+ httpx==0.24.1
12
+ pylint==2.17.5
13
+ submitit==1.4.2
14
+ pudb==2022.1.3
15
+
16
+ # These do/should match dependency versions
17
+ # This is so that the viewer can run without any other deps outside of this file
18
+ Pillow==10.0.*
19
+ fastapi==0.101.1
20
+ pydantic==1.10.*
21
+ requests==2.31.*
22
+ uvicorn==0.23.2
23
+ python-multipart==0.0.6
24
+ ruff==0.1.2
25
+ websockets==12.0
26
+ redis[hiredis]==5.0.1
27
+ psutil==5.9.7
28
+
29
+ # For inference
30
+ albumentations==1.3.1
31
+ einops==0.7.0
32
+ pytorch_lightning==2.1.2
33
+ transformers==4.36.2
34
+ xformers==0.0.23
35
+ torchvision==0.16.*
chameleon/viewer/backend/utils.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Chameleon License found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import types
8
+
9
+ from rich.logging import RichHandler
10
+
11
+
12
+ def configure_rich_logging():
13
+ FORMAT = "%(message)s"
14
+ logging.basicConfig(
15
+ level=logging.INFO,
16
+ handlers=[RichHandler(rich_tracebacks=True)],
17
+ format=FORMAT,
18
+ force=True,
19
+ )
20
+
21
+
22
+ configure_rich_logging()
23
+
24
+
25
+ def get_logger(module: types.ModuleType) -> logging.Logger:
26
+ """This forces logging.basicConfig to be called first."""
27
+ logger = logging.getLogger(module)
28
+ return logger
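
Usage mirrors the other backend modules in this commit: importing the module runs `configure_rich_logging()` once as a side effect, so callers only need:

```python
# Minimal usage sketch, matching how the other backend modules above call it.
from chameleon.viewer.backend.utils import get_logger

logger = get_logger(__name__)
logger.info("logging configured via RichHandler")
```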
chameleon/viewer/frontend/README.md ADDED
@@ -0,0 +1,11 @@
1
+ # Install
2
+
3
+ ```
4
+ npm install
5
+ ```
6
+
7
+
8
+ # Run locally
9
+ ```
10
+ npm run dev
11
+ ```
chameleon/viewer/frontend/index.html ADDED
@@ -0,0 +1,17 @@
1
+ <!-- Copyright (c) Meta Platforms, Inc. and affiliates. -->
2
+
3
+ <!-- This source code is licensed under the Chameleon License found in the -->
4
+ <!-- LICENSE file in the root directory of this source tree. -->
5
+ <!doctype html>
6
+ <html lang="en">
7
+ <head>
8
+ <meta charset="UTF-8" />
9
+ <link rel="icon" type="image/svg+xml" href="/vite.svg" />
10
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
11
+ <title>Chameleon Viewer</title>
12
+ </head>
13
+ <body>
14
+ <div id="root"></div>
15
+ <script type="module" src="/src/main.tsx"></script>
16
+ </body>
17
+ </html>
chameleon/viewer/frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
chameleon/viewer/frontend/package.json ADDED
@@ -0,0 +1,62 @@
1
+ {
2
+ "name": "chameleon-frontend",
3
+ "private": true,
4
+ "version": "0.0.0",
5
+ "type": "module",
6
+ "scripts": {
7
+ "dev": "vite --host 0.0.0.0 --port 7654",
8
+ "staging": "vite --mode staging --host 0.0.0.0",
9
+ "datadev": "vite --mode datadev --host 0.0.0.0",
10
+ "check-build": "tsc && vite build",
11
+ "build": "vite build",
12
+ "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
13
+ "preview": "vite preview",
14
+ "check-format": "prettier --check src",
15
+ "format": "prettier --write src",
16
+ "test": "vitest"
17
+ },
18
+ "dependencies": {
19
+ "@carbon/icons-react": "^11.25.0",
20
+ "@lexical/react": "^0.12.2",
21
+ "axios": "^1.4.0",
22
+ "lexical": "^0.12.2",
23
+ "prettier": "^3.0.3",
24
+ "react": "^18.2.0",
25
+ "react-cookie": "^6.1.1",
26
+ "react-daisyui": "^4.1.0",
27
+ "react-dnd": "^16.0.1",
28
+ "react-dnd-html5-backend": "^16.0.1",
29
+ "react-dom": "^18.2.0",
30
+ "react-dropzone": "^14.2.3",
31
+ "react-hotkeys-hook": "^4.4.1",
32
+ "react-markdown": "^9.0.1",
33
+ "react-router-dom": "^6.15.0",
34
+ "react-use-websocket": "^4.5.0",
35
+ "react18-json-view": "^0.2.4",
36
+ "remark-gfm": "^4.0.0",
37
+ "unique-username-generator": "^1.2.0",
38
+ "ws": "^8.14.2",
39
+ "zod": "^3.22.2",
40
+ "zustand": "^4.4.1"
41
+ },
42
+ "devDependencies": {
43
+ "@tailwindcss/typography": "^0.5.9",
44
+ "@types/react": "^18.2.15",
45
+ "@types/react-dom": "^18.2.7",
46
+ "@types/ws": "^8.5.9",
47
+ "@typescript-eslint/eslint-plugin": "^6.0.0",
48
+ "@typescript-eslint/parser": "^6.0.0",
49
+ "@vitejs/plugin-react": "^4.0.3",
50
+ "autoprefixer": "^10.4.15",
51
+ "daisyui": "^3.9.2",
52
+ "eslint": "^8.45.0",
53
+ "eslint-plugin-react-hooks": "^4.6.0",
54
+ "eslint-plugin-react-refresh": "^0.4.3",
55
+ "postcss": "^8.4.28",
56
+ "prettier": "^3.0.3",
57
+ "tailwindcss": "^3.3.3",
58
+ "typescript": "^5.0.2",
59
+ "vite": "^4.4.5",
60
+ "vitest": "^0.34.6"
61
+ }
62
+ }
chameleon/viewer/frontend/postcss.config.cjs ADDED
@@ -0,0 +1,13 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the Chameleon License found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ module.exports = {
9
+ plugins: {
10
+ tailwindcss: {},
11
+ autoprefixer: {},
12
+ },
13
+ };
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_DisplayVF_W_Wght.woff2 ADDED
Binary file (36.5 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff ADDED
Binary file (28.9 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff2 ADDED
Binary file (23.6 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff ADDED
Binary file (28.6 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff2 ADDED
Binary file (23.4 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff ADDED
Binary file (28.7 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff2 ADDED
Binary file (23.6 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_TextVF_W_Wght.woff2 ADDED
Binary file (36.3 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff ADDED
Binary file (28.8 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff2 ADDED
Binary file (23.5 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff ADDED
Binary file (28.7 kB). View file
 
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff2 ADDED
Binary file (23.5 kB). View file