Spaces: Running on Zero
xuefengli committed
Commit 7362797
Parent(s): e0974a9
update
This view is limited to 50 files because it contains too many changes. See the raw diff for the full set of changes.
- README.md +3 -3
- app.py +77 -0
- chameleon/__init__.py +4 -0
- chameleon/download_data.py +88 -0
- chameleon/inference/__init__.py +4 -0
- chameleon/inference/alignment.py +79 -0
- chameleon/inference/chameleon.py +689 -0
- chameleon/inference/cudagraph.py +85 -0
- chameleon/inference/generation.py +162 -0
- chameleon/inference/image_tokenizer.py +125 -0
- chameleon/inference/loader.py +71 -0
- chameleon/inference/logits_processor.py +336 -0
- chameleon/inference/model_adapter.py +118 -0
- chameleon/inference/stopping_criteria.py +55 -0
- chameleon/inference/token_selector.py +47 -0
- chameleon/inference/transformer.py +421 -0
- chameleon/inference/utils.py +34 -0
- chameleon/inference/vocab.py +123 -0
- chameleon/inference/vqgan.py +675 -0
- chameleon/miniviewer/__init__.py +4 -0
- chameleon/miniviewer/__main__.py +9 -0
- chameleon/miniviewer/miniviewer.html +409 -0
- chameleon/miniviewer/miniviewer.py +254 -0
- chameleon/viewer/backend/__init__.py +4 -0
- chameleon/viewer/backend/data_types.py +90 -0
- chameleon/viewer/backend/model_viewer.py +66 -0
- chameleon/viewer/backend/models/__init__.py +4 -0
- chameleon/viewer/backend/models/abstract_model.py +67 -0
- chameleon/viewer/backend/models/chameleon_distributed.py +827 -0
- chameleon/viewer/backend/models/chameleon_local.py +642 -0
- chameleon/viewer/backend/models/service.py +300 -0
- chameleon/viewer/backend/requirements.txt +35 -0
- chameleon/viewer/backend/utils.py +28 -0
- chameleon/viewer/frontend/README.md +11 -0
- chameleon/viewer/frontend/index.html +17 -0
- chameleon/viewer/frontend/package-lock.json +0 -0
- chameleon/viewer/frontend/package.json +62 -0
- chameleon/viewer/frontend/postcss.config.cjs +13 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_DisplayVF_W_Wght.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_TextVF_W_Wght.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff2 +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff +0 -0
- chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff2 +0 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: Anole
-emoji:
-colorFrom:
+emoji: 🏆
+colorFrom: green
 colorTo: red
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.37.2
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,77 @@
import spaces
import subprocess
import shutil
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download
import json
import os

# Specify the repository ID
repo_id = "GAIR/Anole-7b-v0.1"

if not os.path.exists("./Anole-7b-v0.1"):
    os.system("git lfs install")
    os.system("git clone https://huggingface.co/GAIR/Anole-7b-v0.1")

subprocess.run(["/bin/bash", "install.sh"], capture_output=True, text=True)
result = subprocess.run(["/bin/bash", "install.sh"], capture_output=True, text=True)

@spaces.GPU(duration=90)
def text_to_image(instruction):
    result = subprocess.run(["python", "text2image.py", "-i", instruction, "-b", "1"], capture_output=True, text=True)
    if result.returncode == 0:
        return gr.update(value="Image Generated. Check the display below.", visible=True), "outputs/text2image/1.png"
    else:
        return "Error: " + result.stderr, None

@spaces.GPU(duration=150)
def text_to_interleaved(instruction):
    result = subprocess.run(["python", "interleaved_generation.py", "-i", instruction], capture_output=True, text=True)
    if result.returncode == 0:
        outputs = [None for i in range(7)]
        box_index = 0

        # Read the segments.jsonl file
        with open('./segments.jsonl', 'r') as file:
            for line in file:
                line_dict = json.loads(line.strip())
                if line_dict['type'] == 'text':
                    if box_index % 2 != 0:
                        box_index += 1
                    outputs[box_index] = line_dict['content']
                elif line_dict['type'] == 'image':
                    if box_index % 2 == 0:
                        box_index += 1
                    outputs[box_index] = Image.open(line_dict['content'])
                    box_index += 1

        return outputs[0], outputs[1], outputs[2], outputs[3], outputs[4], outputs[5], outputs[6]
    else:
        return ("Error: " + result.stderr, ) * 7

# Use Blocks to organize the interfaces side by side
with gr.Blocks() as demo:
    # Create a row to place columns side by side
    with gr.Row():
        # First column for the Text-to-Image interface
        with gr.Column():
            gr.Interface(
                fn=text_to_image,  # Function that generates an image from a text instruction
                inputs=gr.Textbox(label="Enter Instruction for Image Generation"),  # Input textbox for user instructions
                outputs=[gr.Text(label="Status"), gr.Image(label="Generated Image")],  # Outputs: status message and generated image
                title="Anole: Text-to-Image",  # Title of the interface
                description="Generate images based on text instructions. Check https://github.com/GAIR-NLP/anole for more information. Model can be downloaded at: https://huggingface.co/GAIR/Anole-7b-v0.1."
            )
        # Second column for the Text-to-Interleaved Image-Text interface
        with gr.Column():
            gr.Interface(
                fn=text_to_interleaved,
                inputs=gr.Textbox(label="Enter Instruction for Interleaved Content"),
                outputs=[gr.Text(label="Text Output 1"), gr.Image(label="Image Output 1"), gr.Text(label="Text Output 2"), gr.Image(label="Image Output 2"), gr.Text(label="Text Output 3"), gr.Image(label="Image Output 3"), gr.Text(label="Text Output 4")],
                title="Anole: Text-to-Interleaved",  # Title of the interface
                description="Generate interleaved text and images based on text instructions. Check https://github.com/GAIR-NLP/anole for more information. Model can be downloaded at: https://huggingface.co/GAIR/Anole-7b-v0.1."
            )

# Launch the entire Blocks interface
demo.launch()
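For reference, text_to_interleaved above reads ./segments.jsonl written by interleaved_generation.py, one JSON object per line with a "type" ("text" or "image") and a "content" field (a text segment or an image path). A hypothetical two-line example of such a file (the wording and path below are illustrative, not taken from this commit):

{"type": "text", "content": "Here is the first generated paragraph."}
{"type": "image", "content": "outputs/interleaved/0.png"}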
chameleon/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
chameleon/download_data.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Chameleon License Agreement.

import hashlib
import subprocess
import sys
from pathlib import Path


def download_file(url: str, output_path: Path):
    print(f"Downloading {output_path}")
    subprocess.check_call(["wget", "--continue", url, "-O", str(output_path)])


def validate_checksum(folder: Path):
    chks_parts = (folder / "checklist.chk").read_text().split()
    for expected_checksum, file in zip(chks_parts[::2], chks_parts[1::2]):
        file_path = folder / file
        checksum = hashlib.md5(file_path.read_bytes()).hexdigest()
        if checksum != expected_checksum:
            print(f"Checksum mismatch for {file_path}")
            sys.exit(1)


def download_tokenizer(presigned_url: str, target_folder: Path):
    tokenizer_folder = target_folder / "tokenizer"
    tokenizer_folder.mkdir(parents=True, exist_ok=True)

    for filename in [
        "text_tokenizer.json",
        "vqgan.ckpt",
        "vqgan.yaml",
        "checklist.chk",
    ]:
        download_file(
            presigned_url.replace("*", f"tokenizer/{filename}"),
            tokenizer_folder / filename,
        )

    validate_checksum(tokenizer_folder)


def download_model(presigned_url: str, target_folder: Path, model: str):
    model_folder = target_folder / "models" / model
    model_folder.mkdir(parents=True, exist_ok=True)

    download_filenames = ["params.json", "consolidate_params.json", "checklist.chk"]

    if model == "7b":
        download_filenames += ["consolidated.pth"]
    elif model == "30b":
        download_filenames += [f"consolidated.{i:02}.pth" for i in range(4)]
    else:
        print(f"Unknown model: {model}")
        sys.exit(1)

    for filename in download_filenames:
        download_file(
            presigned_url.replace("*", f"{model}/{filename}"),
            model_folder / filename,
        )

    validate_checksum(model_folder)


def main():
    presigned_url = (
        sys.argv[1] if len(sys.argv) > 1 else input("Enter the URL from email: ")
    )

    target_folder = Path("./data")
    target_folder.mkdir(parents=True, exist_ok=True)

    download_tokenizer(presigned_url, target_folder)

    model_size = input(
        "Enter the list of models to download without spaces (7B,30B), or press Enter for all: "
    )
    if not model_size:
        model_size = "7B,30B"

    for model in model_size.split(","):
        model = model.strip().lower()
        download_model(presigned_url, target_folder, model)


if __name__ == "__main__":
    main()
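As a usage sketch (not part of this commit; the URL is a placeholder and the helpers are assumed to be importable from the repo root), the downloader can be driven either from the command line via python chameleon/download_data.py <presigned-url> or programmatically:

from pathlib import Path

from chameleon.download_data import download_model, download_tokenizer

presigned_url = "https://example.com/chameleon/*?Signature=PLACEHOLDER"  # placeholder; use the URL from the access email
target = Path("./data")
target.mkdir(parents=True, exist_ok=True)

download_tokenizer(presigned_url, target)    # fetches tokenizer + VQGAN files into ./data/tokenizer
download_model(presigned_url, target, "7b")  # fetches the 7b checkpoint into ./data/models/7b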
chameleon/inference/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
chameleon/inference/alignment.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import torch


class PromptAlignment(ABC):
    @abstractmethod
    def start_index(self, input_ids: list[list[int]]) -> int:
        ...

    @abstractmethod
    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        ...

    @abstractmethod
    def postprocess_inputs(
        self, inputs: torch.Tensor, original_inputs: torch.Tensor
    ) -> torch.Tensor:
        ...


class AlignPromptRight(PromptAlignment):
    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        return max(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                ([self.pad_id] * (max_length - len(sublist))) + sublist
                for sublist in input_ids
            ],
            requires_grad=False,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        return inputs


class AlignPromptLeft(PromptAlignment):
    def __init__(self, pad_id: int = -1):
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        return min(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        max_length = max(len(sublist) for sublist in input_ids)
        return torch.tensor(
            [
                sublist + ([self.pad_id] * (max_length - len(sublist)))
                for sublist in input_ids
            ],
            requires_grad=False,
        )

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        max_init_len = original_inputs.shape[1]
        if inputs.shape[1] <= max_init_len:
            original_inputs_limited = original_inputs[:, : inputs.shape[1]]
            mask = original_inputs_limited != self.pad_id
            inputs[mask] = original_inputs_limited[mask]
        return inputs
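To make the two strategies concrete, here is a small illustrative snippet (toy token ids, not from this commit): AlignPromptRight left-pads so every prompt ends at the same position, while AlignPromptLeft right-pads so every prompt starts at position zero.

from chameleon.inference.alignment import AlignPromptLeft, AlignPromptRight

batch = [[5, 6, 7], [8, 9]]  # two prompts of unequal length

right = AlignPromptRight(pad_id=0)
print(right.prepare_inputs(batch))  # tensor([[5, 6, 7], [0, 8, 9]])
print(right.start_index(batch))     # 3 -> generation starts after the longest prompt

left = AlignPromptLeft(pad_id=-1)
print(left.prepare_inputs(batch))   # tensor([[ 5,  6,  7], [ 8,  9, -1]])
print(left.start_index(batch))      # 2 -> generation starts after the shortest prompt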
chameleon/inference/chameleon.py
ADDED
@@ -0,0 +1,689 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import base64
import io
import json
import math
import queue
import threading
from dataclasses import dataclass, field
from tqdm import tqdm
from enum import Enum
from multiprocessing import managers, queues, synchronize
from typing import Literal, Union

import PIL
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from PIL.Image import Image
from tokenizers import Tokenizer
from transformers import (
    LogitsProcessor,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopPLogitsWarper,
    enable_full_determinism,
)

from chameleon.inference import loader
from chameleon.inference.alignment import AlignPromptRight
from chameleon.inference.generation import ChameleonGenerator
from chameleon.inference.image_tokenizer import ImageTokenizer
from chameleon.inference.logits_processor import (
    AllowOnlyTokensLogitsProcessor,
    DisallowTokensAtOrAfterIndexLogitsProcessor,
    InBatchInstructCFGLogitsProcessor,
)
from chameleon.inference.model_adapter import ChameleonModelAdapter
from chameleon.inference.stopping_criteria import (
    MaxLengthCriteria,
    StopOnEOSAfterBatchIndex,
)
from chameleon.inference.token_selector import (
    ArgmaxTokenSelector,
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)
from chameleon.inference.transformer import Transformer
from chameleon.inference.utils import DynamicGenerator, advance, random_unused_port
from chameleon.inference.vocab import VocabInfo, VocabTranslation


@dataclass
class Options:
    @dataclass
    class Text:
        repetition_penalty: float = 1.2
        temp: float = 1.0
        top_p: float = 0.9
        greedy: bool = False

    @dataclass
    class Image:
        @dataclass
        class CFG:
            guidance_scale_text: float = 3.0
            guidance_scale_image: float = 1.2

        cfg: CFG = field(default_factory=CFG)
        temp: float = 0.7
        top_p: float = 0.9
        greedy: bool = False

    max_seq_len: int = 4096
    max_gen_len: int = 4096
    seed: int | None = None
    txt: Text | bool = True
    img: Image | bool = True
    extra_eos_tokens: list[int | str] = field(default_factory=lambda: [])

    def __post_init__(self):
        if self.txt is True:
            self.txt = Options.Text()
        if self.img is True:
            self.img = Options.Image()


class TokenManager:
    def __init__(
        self,
        tokenizer_path: str,
        vqgan_cfg_path: str,
        vqgan_ckpt_path: str,
        device: str | None = None,
    ):
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
        self.translation = VocabTranslation(self.vocab, device=device)
        self.image_tokenizer = ImageTokenizer(
            cfg_path=vqgan_cfg_path, ckpt_path=vqgan_ckpt_path, device=device
        )

    def pil_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image:
        image_tensor = self.translation.convert_bpe2img(bpe_tokens)
        if image_tensor.shape[0] < 1024:
            padding = (
                torch.ones(
                    [1024 - image_tensor.shape[0]],
                    dtype=int,
                    device=image_tensor.device,
                )
                * image_tensor[0]
            )
            image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)

        return self.image_tokenizer.pil_from_img_toks(image_tensor)

    def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
        pil = self.pil_from_bpe_tokens(bpe_tokens)
        img_io = io.BytesIO()
        pil.save(img_io, format="PNG")
        return img_io.getvalue()

    def tokenize_text(self, text: str) -> list[int]:
        return self.tokenizer.encode(text).ids

    def tokenize_image(self, img: Image) -> list[int]:
        return (
            [self.vocab.begin_image]
            + self.translation.convert_img2bp2(
                self.image_tokenizer.img_tokens_from_pil(img)  # [0 : 8191], vqgan codebook ids
            ).tolist()
            + [self.vocab.end_image]
        )

    def tokenize_b64img(self, b64img: str) -> list[int]:
        image_data = base64.b64decode(b64img)
        image_file = io.BytesIO(image_data)
        return self.tokenize_image(PIL.Image.open(image_file))

    def tokens_from_ui(self, inputs: list[dict]) -> list[int]:
        tokens = [self.vocab.bos_id]
        for input_ in inputs:
            if input_["type"] == "text":
                tokens += self.tokenize_text(input_["value"])
            elif input_["type"] == "image":
                if isinstance(input_["value"], str):
                    if input_["value"].startswith("data:"):
                        # Value Format: 'data:image/[^;]+;base64,[A-Za-z0-9+/]+={0,2}'
                        tokens += self.tokenize_b64img(input_["value"].split(",", 1)[1])
                    elif input_["value"].startswith("file:"):
                        tokens += self.tokenize_image(
                            PIL.Image.open(input_["value"].split(":", 1)[1])
                        )
                    else:
                        raise ValueError("Unknown image format.")
                elif isinstance(input_["value"], Image):
                    tokens += self.tokenize_image(input_["value"])
                else:
                    raise ValueError("Unknown image type.")
            elif input_["type"] == "sentinel":
                tokens += [
                    {
                        "<START-OF-IMAGE>": self.vocab.begin_image,
                        "<END-OF-TURN>": self.vocab.eot_id,
                    }[input_["value"]]
                ]
            elif input_["type"] == "ids":
                tokens += input_["value"]
            else:
                raise ValueError("Unknown input type.")
        return tokens

    def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
        if isinstance(ids, torch.Tensor):
            ids = ids.tolist()

        for row, values in enumerate(ids):
            try:
                ids[row] = values[: values.index(self.vocab.eos_id)]
            except ValueError:
                pass

        return self.tokenizer.decode_batch(ids)

    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
        return [self.pil_from_bpe_tokens(sample) for sample in ids]


@dataclass
class DecodePiece:
    token: ChameleonGenerator.Token
    next_decoder: type["Decoder"] | None


class Decoder:
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[int],
    ): ...

    def __next__(self) -> DecodePiece: ...


class TextDecoder(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        self.vocab = vocab
        self.options = options
        assert vocab.eos_id is not None

        prompt_lens = [len(inp) for inp in input_ids]
        max_prompt_len = max(prompt_lens)
        max_seq_len = min(options.max_seq_len, max_prompt_len + options.max_gen_len)

        self.eos_ids = [vocab.eos_id]
        for extra_eos_token in options.extra_eos_tokens:
            if isinstance(extra_eos_token, str):
                extra_eos_token = vocab.name2val[extra_eos_token]
            assert isinstance(extra_eos_token, int)
            self.eos_ids.append(extra_eos_token)

        stopping_criteria = [
            MaxLengthCriteria(max_seq_len),
        ] + [StopOnEOSAfterBatchIndex(eos_id, [max_prompt_len] * len(prompt_lens)) for eos_id in self.eos_ids]

        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(model, max_seq_len=max_seq_len),
            input_ids=input_ids,
            stopping_criteria=stopping_criteria,
            logits_processors=self._logits_processors(),
            alignment=AlignPromptRight(vocab.pad_id),
            token_selector=(
                ArgmaxTokenSelector()
                if options.txt.greedy
                else MultinomialTokenSelector()
            ),
        )
        advance(self.gen, max_prompt_len)

    def _allowed_tokens(self) -> list[int]:
        allowed_tokens = [self.vocab.eos_id]
        if self.options.txt:
            allowed_tokens += self.vocab.text_tokens
        if self.options.img:
            allowed_tokens += [self.vocab.begin_image]
        return allowed_tokens

    def _logits_processors(self) -> list[LogitsProcessor]:
        logits_processors = [
            AllowOnlyTokensLogitsProcessor(self._allowed_tokens()),
        ]
        if isinstance(self.options.img, Options.Image):
            logits_processors += [
                DisallowTokensAtOrAfterIndexLogitsProcessor(
                    [self.vocab.begin_image],
                    self.options.max_seq_len - 1026,
                ),
            ]
        if isinstance(self.options.txt, Options.Text):
            logits_processors += [
                RepetitionPenaltyLogitsProcessor(self.options.txt.repetition_penalty),
                TemperatureLogitsWarper(self.options.txt.temp),
                TopPLogitsWarper(self.options.txt.top_p),
            ]
        return logits_processors

    def __next__(self) -> DecodePiece:
        tok = next(self.gen)
        next_decoder = None
        if (
            self.vocab.begin_image not in self.eos_ids
            and (tok.id == self.vocab.begin_image).all()
        ):
            next_decoder = ImageDecoder
        return DecodePiece(tok, next_decoder)


class ImageDecoder(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        assert isinstance(options.img, Options.Image)
        self.vocab = vocab
        self.options = options
        self.batch_size = len(input_ids)
        logits_processors = [
            InBatchInstructCFGLogitsProcessor(
                options.img.cfg.guidance_scale_text,
                options.img.cfg.guidance_scale_image,
            ),
            AllowOnlyTokensLogitsProcessor(vocab.image_tokens),
            TemperatureLogitsWarper(options.img.temp),
            TopPLogitsWarper(options.img.top_p),
        ]

        for inp in input_ids:
            if inp[-1] != self.vocab.begin_image:
                inp.append(self.vocab.begin_image)

        max_prompt_len = max(len(inp) for inp in input_ids)
        self.gen = ChameleonGenerator(
            model=ChameleonModelAdapter(model, max_seq_len=max_prompt_len + 1024),
            input_ids=self._split_inputs_for_cfg(input_ids),
            logits_processors=logits_processors,
            alignment=AlignPromptRight(vocab.pad_id),
            token_selector=ReplicatedInputTokenSelector(
                (
                    ArgmaxTokenSelector()
                    if options.img.greedy
                    else MultinomialTokenSelector()
                ),
                n=3,
            ),
        )
        advance(self.gen, max_prompt_len)
        self.gen_count = 0

    def _split_inputs_for_cfg(self, input_ids: list[list[int]]) -> list[list[int]]:
        image_conditioned_allowed = set(self.vocab.image_tokens) | {
            self.vocab.bos_id,
            self.vocab.begin_image,
            self.vocab.end_image,
        }

        full_conditioned = input_ids

        image_conditioned = [
            [id for id in sample if id in image_conditioned_allowed]
            for sample in input_ids
        ]

        unconditioned = [
            [
                self.vocab.bos_id,
                self.vocab.begin_image,
            ]
        ] * self.batch_size

        return full_conditioned + image_conditioned + unconditioned

    def __next__(self) -> DecodePiece:
        if self.gen_count == 1024:
            id = torch.tensor([self.vocab.end_image] * self.batch_size)
            logits = torch.full(
                (self.batch_size, len(self.vocab.all_tokens)), -math.inf
            )
            logits[:, self.vocab.end_image] = 0
            return DecodePiece(
                ChameleonGenerator.Token(id=id, logits=logits),
                TextDecoder,
            )

        tok = next(self.gen)
        tok.id = tok.id.chunk(3)[0]
        self.gen_count += 1
        return DecodePiece(tok, None)


class Generator(Decoder):
    def __init__(
        self,
        model: Transformer,
        vocab: VocabInfo,
        options: Options,
        input_ids: list[list[int]],
    ):
        if options.seed is not None:
            enable_full_determinism(options.seed, warn_only=True)

        self.model = model
        self.vocab = vocab
        self.input_ids = input_ids[:]
        self.generated_token_ids: list[torch.LongTensor] = []
        self.options = options
        if not self.options.txt:
            self.dyngen = DynamicGenerator(
                ImageDecoder(model, vocab, options, input_ids)
            )
        else:
            self.dyngen = DynamicGenerator(
                TextDecoder(model, vocab, options, input_ids)
            )

    def __iter__(self):
        return self

    def __next__(self) -> ChameleonGenerator.Token:
        piece = next(self.dyngen)
        self.generated_token_ids.append(piece.token.id)
        if piece.next_decoder is not None:
            if not self.options.txt:
                raise StopIteration

            self.input_ids = [
                old_list + generated
                for old_list, generated in zip(
                    self.input_ids, torch.stack(self.generated_token_ids).T.tolist()
                )
            ]
            self.generated_token_ids = []
            self.dyngen.gen = piece.next_decoder(
                self.model,
                self.vocab,
                self.options,
                self.input_ids,
            )
        return piece.token


class DistributedMode(Enum):
    AUTO = 0
    THREAD = 1
    PROCESS = 2


@dataclass
class _DistributedContext:
    req_q: Union[queue.Queue, queues.Queue]
    res_q: Union[queue.Queue, queues.Queue]
    active_key: Union[dict[int, Literal[True]], managers.DictProxy]
    active_key_lock: Union[threading.Lock, synchronize.Lock]
    ready_barrier: Union[threading.Barrier, synchronize.Barrier]
    worker_launcher: Union[type[threading.Thread], type[mp.Process]]

    @staticmethod
    def make_for_threading(world_size: int):
        return _DistributedContext(
            req_q=queue.Queue(),
            res_q=queue.Queue(),
            active_key={},
            active_key_lock=threading.Lock(),
            ready_barrier=threading.Barrier(world_size + 1),
            worker_launcher=threading.Thread,
        )

    @staticmethod
    def make_for_multiprocessing(world_size: int):
        local_mp = mp.get_context("spawn")
        return _DistributedContext(
            req_q=local_mp.Queue(),
            res_q=local_mp.Queue(),
            active_key=local_mp.Manager().dict(),
            active_key_lock=local_mp.Lock(),
            ready_barrier=local_mp.Barrier(world_size + 1),
            worker_launcher=local_mp.Process,
        )

    @staticmethod
    def make(mode: DistributedMode, world_size: int):
        if mode == DistributedMode.AUTO:
            mode = DistributedMode.PROCESS

        if mode == DistributedMode.THREAD:
            return _DistributedContext.make_for_threading(world_size)
        elif mode == DistributedMode.PROCESS:
            return _DistributedContext.make_for_multiprocessing(world_size)
        else:
            raise ValueError("Unknown DistributedMode")


def _worker_impl(
    init_method: str,
    model: Transformer | str,
    world_size: int,
    rank: int,
    vocab: VocabInfo,
    dctx: _DistributedContext,
):
    dist.init_process_group(
        "nccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    torch.set_default_device(f"cuda:{rank}")
    torch.cuda.set_device(rank)
    if isinstance(model, str):
        model = loader.load_model(model, rank=rank)
    dctx.ready_barrier.wait()

    is_coord = rank == 0

    while True:
        req = [Options(), [], 0, False]
        if is_coord:
            req = dctx.req_q.get()

        dist.broadcast_object_list(req, src=0)
        options, input_ids, key, shutdown = req
        if shutdown:
            break

        for token in Generator(
            model=model,
            vocab=vocab,
            options=options,
            input_ids=input_ids,
        ):
            if is_coord:
                dctx.res_q.put((key, token))

            to_continue = [True]
            if is_coord:
                with dctx.active_key_lock:
                    to_continue = [key in dctx.active_key]
            dist.broadcast_object_list(to_continue, src=0)
            if not to_continue[0]:
                break

        if is_coord:
            dctx.res_q.put((key, None))


class ChameleonInferenceModel:
    def __init__(
        self,
        model: Transformer | str,
        tokenizer_path: str,
        vqgan_cfg_path: str,
        vqgan_ckpt_path: str,
        *,
        options: Options | None = None,
        distributed_mode: DistributedMode = DistributedMode.AUTO,
    ):
        self.options = options or Options()
        self.next_key = 0

        self.token_manager = TokenManager(
            tokenizer_path=tokenizer_path,
            vqgan_cfg_path=vqgan_cfg_path,
            vqgan_ckpt_path=vqgan_ckpt_path,
            device="cuda",
        )
        self.vocab = self.token_manager.vocab

        world_size = 1
        if isinstance(model, str):
            world_size = loader.detect_shard_count(model)
        self.dctx = _DistributedContext.make(distributed_mode, world_size)

        init_method = f"tcp://0.0.0.0:{random_unused_port()}"
        self.workers = [
            self.dctx.worker_launcher(
                target=_worker_impl,
                args=(init_method, model, world_size, i, self.vocab, self.dctx),
                daemon=True,
            )
            for i in range(world_size)
        ]
        for w in self.workers:
            w.start()
        self.dctx.ready_barrier.wait()

    def __del__(self):
        try:
            with self.dctx.active_key_lock:
                self.dctx.active_key.clear()
            self.dctx.req_q.put([None, None, None, True])
            for w in self.workers:
                w.join()
        except FileNotFoundError:
            pass

    def stream(
        self,
        *,
        input_ids: list[int] | None = None,
        prompt_text: str | None = None,
        prompt_ui: list[dict] | None = None,
        batch_input_ids: list[list[int]] | None = None,
        batch_prompt_text: list[str] | None = None,
        batch_prompt_ui: list[list[dict]] | None = None,
        options: Options | None = None,
    ):
        # NOTE: Not thread-safe! Only one instance of generate may be run at a time.

        if (
            sum(
                x is not None
                for x in [
                    input_ids,
                    prompt_text,
                    prompt_ui,
                    batch_input_ids,
                    batch_prompt_text,
                    batch_prompt_ui,
                ]
            )
            != 1
        ):
            raise ValueError(
                "Must specify exactly one of: input_ids, prompt_text, prompt_ui, batch_input_ids, batch_prompt_text, batch_prompt_ui"
            )

        options = options or self.options

        if prompt_text is not None:
            batch_prompt_text = [prompt_text]
        if prompt_ui is not None:
            batch_prompt_ui = [prompt_ui]
        if input_ids is not None:
            batch_input_ids = [input_ids]
        if batch_prompt_text is not None:
            batch_prompt_ui = [
                [{"type": "text", "value": prompt_text}]
                for prompt_text in batch_prompt_text
            ]
        if batch_prompt_ui is not None:
            batch_input_ids = [
                self.token_manager.tokens_from_ui(prompt_ui)
                for prompt_ui in batch_prompt_ui
            ]

        assert batch_input_ids

        if not options.txt and not options.img:
            raise ValueError("Must specify at least one modality.")
        if options.txt and options.img and len(batch_input_ids) > 1:
            raise ValueError(
                "Batch generation only supported for one modality at a time."
            )

        req_key = self.next_key
        self.next_key += 1

        with self.dctx.active_key_lock:
            self.dctx.active_key[req_key] = True

        self.dctx.req_q.put([options, batch_input_ids, req_key, False])

        try:
            while key_token := self.dctx.res_q.get():
                key, token = key_token
                if key != req_key:
                    # Residual from prior calls to generation. Skip.
                    continue
                if token is None:
                    break
                yield token
        finally:
            with self.dctx.active_key_lock:
                del self.dctx.active_key[req_key]

    def step(self, *args, **kwargs) -> ChameleonGenerator.Token:
        return next(self.stream(*args, **kwargs))

    def generate(self, *args, **kwargs) -> torch.LongTensor:
        tokens = [t.id for t in self.stream(*args, **kwargs)]
        if not tokens:
            return torch.LongTensor()
        return torch.stack(tokens).T

    def decode_text(self, ids: torch.LongTensor | list[list[int]]) -> list[str]:
        return self.token_manager.decode_text(ids)

    def decode_image(self, ids: torch.LongTensor) -> list[PIL.Image]:
        return self.token_manager.decode_image(ids)

    def sft_tokenization(self, json_path: str) -> list[dict]:
        with open(json_path, 'r') as input_file:
            jsonl_input = [json.loads(line) for line in input_file]

        output_data = []
        for entry in tqdm(jsonl_input, desc="Tokenize dataset"):
            text_tokens = self.token_manager.tokenize_text(entry['text'])
            image_tokens = self.token_manager.tokenize_image(PIL.Image.open(entry['image']))
            entry['text_tokens'] = text_tokens
            entry['image_tokens'] = image_tokens
            output_data.append(entry)

        return output_data
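A minimal usage sketch for the class above (the checkpoint and tokenizer paths are hypothetical and simply follow the ./data layout produced by chameleon/download_data.py; this snippet is not part of the commit):

from chameleon.inference.chameleon import ChameleonInferenceModel, Options

model = ChameleonInferenceModel(
    "./data/models/7b",                      # sharded checkpoint dir, or an already-loaded Transformer
    "./data/tokenizer/text_tokenizer.json",
    "./data/tokenizer/vqgan.yaml",
    "./data/tokenizer/vqgan.ckpt",
    options=Options(max_gen_len=256),
)

tokens = model.generate(prompt_text="Describe a chameleon in one sentence.")
print(model.decode_text(tokens)[0])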
chameleon/inference/cudagraph.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Any, Callable, TypeVar

import torch

T = TypeVar("T")
FN = Callable[..., T]  # type: ignore


class CUDAGraphWrapper:
    def __init__(
        self,
        fn: FN[T],
        warmup_iter: int = 1,
        debug_dump_path: str | None = None,
    ):
        self.fn = fn
        self.warmup_iter = warmup_iter
        self.debug_dump_path = debug_dump_path
        self.graph: torch.cuda.CUDAGraph | None = None
        self.result: T | None = None

    def __call__(self, *args, **kwargs) -> Any:  # type: ignore
        if self.warmup_iter > 0:
            self.warmup_iter -= 1
            return self.fn(*args, **kwargs)

        if self.graph is None:
            self.graph = torch.cuda.CUDAGraph()
            if self.debug_dump_path is not None:
                self.graph.enable_debug_mode()
            recording_kwargs = {}
            if "capture_error_mode" in torch.cuda.graph.__init__.__annotations__:
                # In PyTorch 2.1+ and nightlies from late Aug 2023,
                # we can do this to maybe avoid watchdog-related crashes
                recording_kwargs["capture_error_mode"] = "thread_local"
            with torch.cuda.graph(self.graph, **recording_kwargs):
                self.result = self.fn(*args, **kwargs)
            torch.cuda.synchronize()
            if self.debug_dump_path is not None:
                self.graph.debug_dump(self.debug_dump_path)

        assert self.graph is not None
        self.graph.replay()
        return self.result


def cudagraph_wrap(
    *args,
    warmup_iter: int = 1,
    debug_dump_path: str | None = None,
) -> Callable[[FN[T]], FN[T]]:
    def wrapper(fn: FN[T]) -> FN[T]:
        graph_wrapper = CUDAGraphWrapper(
            fn, warmup_iter=warmup_iter, debug_dump_path=debug_dump_path
        )

        @functools.wraps(fn)
        def call_wrapper(*inner_args, **inner_kwargs):
            return graph_wrapper(*inner_args, **inner_kwargs)

        return call_wrapper

    # @cudagraph_wrap
    # def fn(...):
    #     ...
    #
    # - or -
    #
    # fast_fn = cudagraph_wrap(slow_fn, warmup_iter=2)
    if len(args) == 1 and callable(args[0]):
        return wrapper(args[0])

    # @cudagraph_wrap(warmup_iter=3)
    # def fn(...):
    #     ...
    def decorator(fn: FN[T]) -> FN[T]:
        return wrapper(fn)

    return decorator
chameleon/inference/generation.py
ADDED
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation.streamers import BaseStreamer

from chameleon.inference.alignment import AlignPromptLeft, PromptAlignment
from chameleon.inference.model_adapter import ModelAdapter
from chameleon.inference.stopping_criteria import StoppingCriteria, StoppingCriteriaList
from chameleon.inference.token_selector import MultinomialTokenSelector, TokenSelector


class ChameleonGenerator:
    @dataclass
    class Token:
        id: torch.LongTensor
        logits: torch.Tensor | None

    def __init__(
        self,
        model: ModelAdapter,
        input_ids: list[list[int]],
        stopping_criteria: StoppingCriteriaList | list[StoppingCriteria] | None = None,
        logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
        probability_processors: LogitsProcessorList
        | list[LogitsProcessor]
        | None = None,
        token_selector: TokenSelector | None = None,
        alignment: PromptAlignment = AlignPromptLeft(),
    ):
        assert model.supports_alignment(alignment)

        self.model = model

        self.stopping_criteria = stopping_criteria
        self.logits_processors = logits_processors
        self.probability_processors = probability_processors
        self.token_selector: TokenSelector = (
            token_selector or MultinomialTokenSelector()
        )

        self.alignment = alignment

        self.model.initialize(input_ids)

        self._inputs = self.alignment.prepare_inputs(
            input_ids
        )  # inputs.shape = [batch, seq-len]

        self._idx = 0
        self._start_idx = self.alignment.start_index(input_ids)

        self._original_inputs = self._inputs.clone()
        self._inputs = self._inputs[:, : self._start_idx]

    def __iter__(self):
        return self

    @torch.inference_mode()
    def __next__(self) -> Token:
        # Are we done?
        if self.stopping_criteria(self._inputs, None):
            raise StopIteration

        # Emit initial tokens.
        # Model is not run for these.
        # If you want the logits, you can do a separate forward pass outside generation.
        if self._idx < self._start_idx:
            idx, self._idx = self._idx, self._idx + 1
            return ChameleonGenerator.Token(id=self._inputs[:, idx], logits=None)

        # Run the model for the next token.
        self._inputs = self._inputs.contiguous()
        outputs = self.model(self._inputs)  # outputs.shape = [batch, seq-len, vocab]

        # Pull out and process the logits.
        logits = outputs[:, -1, :]  # logits.shape = [batch, vocab]
        logits = self.logits_processors(self._inputs, logits)
        probs = logits.softmax(dim=1)  # probs.shape = [batch, vocab]
        probs = self.probability_processors(self._inputs, probs)

        # Select a token and add it to the inputs.
        next_tokens = self.token_selector(
            self._inputs, probs
        )  # next_tokens.shape = [batch]
        self._inputs = torch.cat([self._inputs, next_tokens[:, None]], dim=1)

        # Run alignment specific postprocessing.
        self._inputs = self.alignment.postprocess_inputs(
            self._inputs, self._original_inputs
        )

        # Return the next step result.
        return ChameleonGenerator.Token(id=self._inputs[:, -1], logits=logits)

    @property
    def stopping_criteria(self) -> StoppingCriteriaList:
        return self._stopping_criteria

    @stopping_criteria.setter
    def stopping_criteria(
        self, value: StoppingCriteriaList | list[StoppingCriteria] | None
    ):
        self._stopping_criteria = StoppingCriteriaList(value or [])

    @property
    def logits_processors(self) -> LogitsProcessorList:
        return self._logits_processors

    @logits_processors.setter
    def logits_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._logits_processors = LogitsProcessorList(value or [])

    @property
    def probability_processors(self) -> LogitsProcessorList:
        return self._probability_processors

    @probability_processors.setter
    def probability_processors(
        self, value: LogitsProcessorList | list[LogitsProcessor] | None
    ):
        self._probability_processors = LogitsProcessorList(value or [])


def run_generation(
    model: torch.nn.Module,
    input_ids: list[list[int]],
    stopping_criteria: StoppingCriteriaList | list[StoppingCriteria],
    logits_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    probability_processors: LogitsProcessorList | list[LogitsProcessor] | None = None,
    token_selector: TokenSelector | None = None,
    alignment: PromptAlignment = AlignPromptLeft(),
    streamer: BaseStreamer | None = None,
) -> torch.LongTensor:
    result = torch.empty((len(input_ids), 0), dtype=int)
    for tok in ChameleonGenerator(
        model=model,
        input_ids=input_ids,
        stopping_criteria=stopping_criteria,
        logits_processors=logits_processors,
        probability_processors=probability_processors,
        token_selector=token_selector,
        alignment=alignment,
    ):
        if streamer is not None:
            streamer.put(tok.id)
        result = torch.cat([result, tok.id.view(-1, 1)], dim=1)

    if streamer is not None:
        streamer.end()

    return result
chameleon/inference/image_tokenizer.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import PIL
import torch
import yaml
from PIL import Image

from chameleon.inference.vqgan import VQModel


class ImageTokenizer:
    def __init__(
        self,
        cfg_path: str,
        ckpt_path: str,
        device: str | torch.device | None = None,
    ):
        with open(cfg_path) as f:
            config = yaml.safe_load(f)

        params = config["model"]["params"]
        if "lossconfig" in params:
            del params["lossconfig"]
        params["ckpt_path"] = ckpt_path

        self._vq_model = VQModel(**params)
        self._vq_model.eval()

        if device is None:
            devices = {p.device for p in self._vq_model.parameters()}
            assert len(devices) == 1
            device = devices.pop()
        else:
            self._vq_model.to(device)
        self._device = device

        dtypes = {p.dtype for p in self._vq_model.parameters()}
        assert len(dtypes) == 1
        self._dtype = dtypes.pop()

    def _whiten_transparency(self, img: PIL.Image) -> PIL.Image:
        # Check if it's already in RGB format.
        if img.mode == "RGB":
            return img

        vals_rgba = np.array(img.convert("RGBA"))

        # If there is no transparency layer, simply convert and return.
        if not (vals_rgba[:, :, 3] < 255).any():
            return img.convert("RGB")

        # There is a transparency layer, blend it with a white background.

        # Calculate the alpha proportion for blending.
        alpha = vals_rgba[:, :, 3] / 255.0
        # Blend with white background.
        vals_rgb = (1 - alpha[:, :, np.newaxis]) * 255 + alpha[
            :, :, np.newaxis
        ] * vals_rgba[:, :, :3]
        return PIL.Image.fromarray(vals_rgb.astype("uint8"), "RGB")

    def _vqgan_input_from(self, img: PIL.Image, target_image_size=512) -> torch.Tensor:
        # Resize with aspect ratio preservation.
        s = min(img.size)
        scale = target_image_size / s
        new_size = (round(scale * img.size[0]), round(scale * img.size[1]))
        img = img.resize(new_size, PIL.Image.LANCZOS)

        # Center crop.
        x0 = (img.width - target_image_size) // 2
        y0 = (img.height - target_image_size) // 2
        img = img.crop((x0, y0, x0 + target_image_size, y0 + target_image_size))

        # Convert to tensor.
        np_img = np.array(img) / 255.0  # Normalize to [0, 1]
        np_img = np_img * 2 - 1  # Scale to [-1, 1]
        tensor_img = (
            torch.from_numpy(np_img).permute(2, 0, 1).float()
        )  # (Channels, Height, Width) format.

        # Add batch dimension.
        return tensor_img.unsqueeze(0)

    def img_tokens_from_pil(self, image: PIL.Image) -> list[int]:
        image = self._whiten_transparency(image)
        vqgan_input = self._vqgan_input_from(image).to(self._device).to(self._dtype)
        _, _, [_, _, img_toks] = self._vq_model.encode(vqgan_input)
        return img_toks

    def _pil_from_chw_tensor(self, chw_tensor: torch.Tensor) -> PIL.Image:
        # Ensure detachment and move tensor to CPU.
        detached_chw_tensor = chw_tensor.detach().cpu()

        # Normalize tensor to [0, 1] range from [-1, 1] range.
        normalized_chw_tensor = (
            torch.clamp(detached_chw_tensor, -1.0, 1.0) + 1.0
        ) / 2.0

        # Permute CHW tensor to HWC format and convert to NumPy array.
        hwc_array = normalized_chw_tensor.permute(1, 2, 0).numpy()

        # Convert to an 8-bit unsigned integer format.
        image_array_uint8 = (hwc_array * 255).astype(np.uint8)

        # Convert NumPy array to PIL Image.
        pil_image = Image.fromarray(image_array_uint8)

        # Convert image to RGB if it is not already.
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")

        return pil_image

    def pil_from_img_toks(self, img_tensor: torch.Tensor) -> PIL.Image:
        emb_dim = self._vq_model.quantize.embedding.weight.shape[-1]
        codebook_entry = self._vq_model.quantize.get_codebook_entry(
            img_tensor, (1, 32, 32, emb_dim)
        )
        pixels = self._vq_model.decode(codebook_entry)
        return self._pil_from_chw_tensor(pixels[0])
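A rough round-trip sketch of the tokenizer above (paths and filenames are hypothetical; this mirrors how TokenManager in chameleon.py uses it): encode a PIL image to its 1024 VQGAN codebook ids (a 32x32 latent grid over a 512x512 crop), then decode them back to pixels.

import PIL.Image
import torch

from chameleon.inference.image_tokenizer import ImageTokenizer

tokenizer = ImageTokenizer(
    cfg_path="./data/tokenizer/vqgan.yaml",
    ckpt_path="./data/tokenizer/vqgan.ckpt",
    device="cuda",
)

img_toks = tokenizer.img_tokens_from_pil(PIL.Image.open("example.png"))  # 1024 codebook ids
reconstruction = tokenizer.pil_from_img_toks(torch.as_tensor(img_toks))  # back to a 512x512 PIL image
reconstruction.save("reconstruction.png")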
chameleon/inference/loader.py
ADDED
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import glob
import inspect
import json
from pathlib import Path

import torch

from chameleon.inference.transformer import ModelArgs, Transformer


def _convert(model_args: ModelArgs, consolidated_path: Path) -> Transformer:
    old_default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.bfloat16)

    model = Transformer(model_args)

    transfer_results = model.load_state_dict(
        torch.load(str(consolidated_path), map_location='cuda'),
        strict=False,
    )

    # TODO: More generally, assert missing or unexpected keys are buffers.
    assert transfer_results.missing_keys == []
    assert transfer_results.unexpected_keys == ["rope.freqs"]

    model.eval()

    torch.set_default_dtype(old_default_dtype)
    return model


def _get_checkpoint_path(src_dir: Path, rank: int | None) -> Path:
    base_path = src_dir / "consolidated.pth"
    if not rank and base_path.exists():
        return base_path

    alt_path = src_dir / f"consolidated.{rank:02}.pth"
    if alt_path.exists():
        return alt_path

    raise ValueError("Consolidated checkpoint not found.")


def load_model(path: str, rank: int | None = None) -> Transformer:
    src_dir = Path(path)

    with open(src_dir / "params.json", "r") as f:
        params = json.loads(f.read())
    with open(src_dir / "consolidate_params.json", "r") as f:
        consolidate_params = json.loads(f.read())
    params = {**params, **params["model"], **consolidate_params}

    known_param = inspect.signature(ModelArgs.__init__).parameters
    filtered_params = {k: v for k, v in params.items() if k in known_param}

    return _convert(
        ModelArgs(**filtered_params),
        _get_checkpoint_path(src_dir, rank),
    )


def detect_shard_count(path: str) -> int:
    src_dir = Path(path)
    if (src_dir / "consolidated.pth").exists():
        return 1
    return len(glob.glob(str(src_dir / "consolidated.*.pth")))
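A short, hypothetical sketch of how `load_model` and `detect_shard_count` are typically combined (the checkpoint directory is a placeholder; it must contain `params.json`, `consolidate_params.json`, and the consolidated weights):

```python
from chameleon.inference.loader import detect_shard_count, load_model

ckpt_dir = "./data/models/7b"  # placeholder path
if detect_shard_count(ckpt_dir) == 1:
    model = load_model(ckpt_dir)          # single consolidated.pth
else:
    model = load_model(ckpt_dir, rank=0)  # one consolidated.XX.pth per model-parallel rank
```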
chameleon/inference/logits_processor.py
ADDED
@@ -0,0 +1,336 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from transformers import LogitsProcessor


class TopPProbabilityProcessor(LogitsProcessor):
    # Modified version of TopPLogitsWarper to act on probabilities.
    # Changes:
    # * filter_value changed from -inf to 0
    # * removed softmax
    # * renormalize L1

    def __init__(
        self,
        top_p: float,
        min_tokens_to_keep: int = 1,
    ):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(
                f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}"
            )

        self.top_p = top_p
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq-len]
        # probs.shape=[batch, vocab]
        sorted_probs, sorted_indices = torch.sort(probs, descending=False)
        cumulative_probs = sorted_probs.cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        probs = probs.masked_fill(indices_to_remove, 0.0)
        probs = probs / probs.sum(dim=-1, keepdim=True)
        return probs


class DisallowTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            logits[:, self.token_ids] = -math.inf
        return logits


class DisallowTokensLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class DisallowTokensAtIndexLogitsProcessor(DisallowTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class DisallowTokensAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class DisallowTokensAtOrAfterIndexLogitsProcessor(
    DisallowTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)


class DisallowTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )
        # The following will fail if the mask is all False.
        # logits[mask, self.token_ids] = -math.inf
        logits[torch.where(mask)[0].unsqueeze(1), self.token_ids] = -math.inf
        return logits


class DisallowTokensAtBatchIndexLogitsProcessor(
    DisallowTokensInBatchIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], batch_index: list[int]):
        super().__init__(token_ids, batch_index, [i + 1 for i in batch_index])


class AllowOnlyTokensInIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self, token_ids: list[int], start_index: int, end_index: int | None = None
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_index = start_index
        self.end_index = end_index if end_index is not None else math.inf

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        current_index = input_ids.shape[1]
        if self.start_index <= current_index < self.end_index:
            replacement = torch.full_like(logits, -math.inf)
            replacement[:, self.token_ids] = logits[:, self.token_ids]
            logits[:] = replacement
        return logits


class AllowOnlyTokensLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int]):
        super().__init__(token_ids, 0)


class AllowOnlyTokensAtIndexLogitsProcessor(AllowOnlyTokensInIndexRangeLogitsProcessor):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index, index + 1)


class AllowOnlyTokensAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index + 1)


class AllowOnlyTokensAtOrAfterIndexLogitsProcessor(
    AllowOnlyTokensInIndexRangeLogitsProcessor
):
    def __init__(self, token_ids: list[int], index: int):
        super().__init__(token_ids, index)


class AllowOnlyTokensInBatchIndexRangeLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        token_ids: list[int],
        start_indices: list[int],
        end_indices: list[int] | None = None,
    ):
        self.token_ids = torch.tensor(token_ids)
        self.start_indices = torch.tensor(start_indices)
        self.end_indices = (
            torch.tensor(end_indices)
            if end_indices is not None
            else torch.full_like(self.start_indices, math.inf, dtype=torch.float)
        )

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape = [batch, seq_len]
        # logits.shape = [batch, vocab]
        current_index = input_ids.shape[1]
        mask = (self.start_indices <= current_index) & (
            current_index < self.end_indices
        )

        valid_batch_indices = torch.where(mask)[0].unsqueeze(1)
        full_mask = torch.full_like(logits, -math.inf)
        full_mask[valid_batch_indices, self.token_ids] = logits[
            valid_batch_indices, self.token_ids
        ]

        logits[:] = torch.where(full_mask != -math.inf, full_mask, logits)
        return logits


class AllowOnlyTokensAtRelativeOffsetLogitsProcessor(LogitsProcessor):
    def __init__(
        self, trigger_token_id: int, subsequent_token_ids: list[int], offset: int
    ):
        self.trigger_token_id = trigger_token_id
        self.subsequent_token_ids = torch.tensor(subsequent_token_ids)
        self.offset = offset

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        if input_ids.shape[1] < self.offset:
            return logits

        trigger_positions = (
            input_ids[:, -self.offset] == self.trigger_token_id
        ).unsqueeze(-1)

        disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
        disallowed_tokens_mask[:, self.subsequent_token_ids] = False

        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )


class AllowOnlyTokensInRelativeWindowLogitsProcessor(LogitsProcessor):
    def __init__(self, trigger_token_id: int, allowed_token_ids: list[int], width: int):
        self.trigger_token_id = trigger_token_id
        self.allowed_token_ids = torch.tensor(allowed_token_ids).unsqueeze(
            0
        )  # shape: [1, num_allowed_tokens]
        self.width = width

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # logits.shape=[batch, vocab]
        width = min(self.width, input_ids.shape[1])
        trigger_positions = (
            (input_ids[:, -width:] == self.trigger_token_id).any(dim=1).unsqueeze(-1)
        )

        disallowed_tokens_mask = torch.ones_like(logits, dtype=bool)
        disallowed_tokens_mask[:, self.allowed_token_ids] = False

        return logits.masked_fill_(
            disallowed_tokens_mask & trigger_positions,
            -math.inf,
        )


class CFGLogitsProcessor(LogitsProcessor):
    def __init__(
        self,
        guidance_scale: float,
        unconditional_ids: torch.LongTensor,
        model,
    ):
        self.guidance_scale = guidance_scale
        self.unconditional_ids = unconditional_ids
        self.model = model

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        conditioned_logits = logits

        self.unconditional_ids = torch.cat(
            [self.unconditional_ids, input_ids[:, -1:]], dim=1
        )
        unconditioned_outputs = self.model(self.unconditional_ids)
        unconditioned_logits = unconditioned_outputs[:, -1, :]
        return (
            self.guidance_scale * (conditioned_logits - unconditioned_logits)
            + unconditioned_logits
        )


class InBatchCFGLogitsProcessor(LogitsProcessor):
    def __init__(self, guidance_scale: float):
        self.guidance_scale = guidance_scale

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[2*batch, seq-len]
        # logits.shape=[2*batch, vocab]
        conditioned_logits, unconditioned_logits = torch.chunk(logits, chunks=2, dim=0)
        mixed_logits = unconditioned_logits + self.guidance_scale * (
            conditioned_logits - unconditioned_logits
        )
        return mixed_logits.repeat(2, 1)


class InBatchInstructCFGLogitsProcessor(LogitsProcessor):
    # See https://arxiv.org/abs/2211.09800

    def __init__(self, guidance_scale_text: float, guidance_scale_image: float):
        self.guidance_scale_text = guidance_scale_text
        self.guidance_scale_image = guidance_scale_image

    def __call__(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[3*batch, seq-len]
        # logits.shape=[3*batch, vocab]
        (
            full_conditioned_logits,
            image_conditioned_logits,
            unconditioned_logits,
        ) = logits.chunk(3)
        mixed_logits = (
            unconditioned_logits
            + self.guidance_scale_image
            * (image_conditioned_logits - unconditioned_logits)
            + self.guidance_scale_text
            * (full_conditioned_logits - image_conditioned_logits)
        )
        return mixed_logits.repeat(3, 1)
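These processors follow the HuggingFace `LogitsProcessor` calling convention and are meant to be chained; a hypothetical sketch of the pattern (token ids and indices are placeholders):

```python
from chameleon.inference.logits_processor import (
    AllowOnlyTokensInIndexRangeLogitsProcessor,
    InBatchCFGLogitsProcessor,
)

# Hypothetical sketch: restrict a span of positions to a token subset and apply in-batch CFG.
processors = [
    AllowOnlyTokensInIndexRangeLogitsProcessor(
        token_ids=[5, 6, 7], start_index=10, end_index=1034  # placeholder values
    ),
    InBatchCFGLogitsProcessor(guidance_scale=3.0),
]

def process(input_ids, logits):
    for p in processors:
        logits = p(input_ids, logits)
    return logits
```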
chameleon/inference/model_adapter.py
ADDED
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import math
from abc import ABC, abstractmethod

import torch

from chameleon.inference import transformer
from chameleon.inference.alignment import (
    AlignPromptLeft,
    AlignPromptRight,
    PromptAlignment,
)
from chameleon.inference.cudagraph import cudagraph_wrap


class ModelAdapter(ABC):
    @abstractmethod
    def initialize(self, prompt_tokens: list[list[int]]):
        ...

    @abstractmethod
    def supports_alignment(self, alignment: PromptAlignment) -> bool:
        ...

    @abstractmethod
    @torch.inference_mode()
    def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
        ...


class ChameleonModelAdapter(ModelAdapter):
    """Adapter for Chameleon-style model that handles state, such as cache."""

    def __init__(
        self,
        model: transformer.Transformer,
        max_seq_len: int,
        dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self._args = model.args
        self._model = model
        self._max_seq_len = max_seq_len
        self._dtype = dtype or next(model.parameters()).data.dtype

    def initialize(self, prompt_tokens: list[list[int]]):
        self._prompt_lengths = [len(toks) for toks in prompt_tokens]
        batch_size = len(prompt_tokens)

        self._cache = transformer.make_cache(
            args=self._args,
            length=batch_size * self._max_seq_len,
            dtype=self._dtype,
        )

        self._local_inputs = torch.zeros([batch_size], dtype=int, device="cuda")

        self._forward = cudagraph_wrap(self._model.forward_with_attn_bias)

        self._first_pass = True

    def supports_alignment(self, alignment: PromptAlignment) -> bool:
        return isinstance(alignment, AlignPromptLeft) or isinstance(
            alignment, AlignPromptRight
        )

    def __call__(self, inputs: torch.LongTensor) -> torch.FloatTensor:
        # inputs.shape=[batch, seq-len]
        batch_size, seq_len = inputs.shape

        if self._first_pass:
            attn_seqlen = [min(pl, seq_len) for pl in self._prompt_lengths]
            self._bias = transformer.AttnBias.from_seqlens(
                q_seqlen=attn_seqlen,
                kv_seqlen=attn_seqlen,
                kv_padding=self._max_seq_len,
            )

            mask = torch.zeros_like(inputs, dtype=torch.bool)
            for i, k in enumerate(self._prompt_lengths):
                mask[i, -k:] = True

            flat_outputs: torch.Tensor = self._forward(  # type: ignore
                token_values=inputs[mask],
                attn_bias=self._bias,
                cache=self._cache,
            )
            self._local_outputs = torch.full(
                (inputs.shape[0], inputs.shape[1], flat_outputs.shape[-1]),
                -math.inf,
            )
            self._local_outputs[mask] = flat_outputs

            self._vocab_size = self._local_outputs.shape[-1]

            self._bias.q_seqinfo.seqstart.copy_(
                torch.arange(batch_size + 1, dtype=torch.int)
            )
            self._bias.q_seqinfo.max_seqlen = 1
            self._bias.q_seqinfo.seqstart_py = self._bias.q_seqinfo.seqstart.tolist()

            self._first_pass = False

        else:
            self._local_inputs.copy_(inputs[:, -1])  # type: ignore

            self._local_outputs = self._forward(  # type: ignore
                token_values=self._local_inputs,
                attn_bias=self._bias,
                cache=self._cache,
            )

        self._bias.k_seqinfo.seqlen.add_(1)
        return self._local_outputs.view(batch_size, -1, self._vocab_size)
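The adapter is stateful: the first call runs the full (right-aligned) prompt batch and fills the KV cache, while every later call reads only the last token per sequence. A hypothetical sketch of that calling convention (the sequence length and tensors are placeholders):

```python
from chameleon.inference.model_adapter import ChameleonModelAdapter

adapter = ChameleonModelAdapter(model, max_seq_len=4096)  # placeholder length
adapter.initialize(prompt_tokens)    # list[list[int]], one prompt per batch entry
logits = adapter(prompt_tensor)      # first pass: whole prompts, returns [batch, seq, vocab]
logits = adapter(tokens_so_far)      # later passes: only the last column is consumed
```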
chameleon/inference/stopping_criteria.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import torch


class StoppingCriteria:
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        raise NotImplementedError("StoppingCriteria needs to be subclassed")


class StoppingCriteriaList(list):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        return any(criteria(input_ids, scores, **kwargs) for criteria in self)


class MaxLengthCriteria(StoppingCriteria):
    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        cur_len = input_ids.shape[-1]
        return cur_len >= self.max_length


class StopOnEOS(StoppingCriteria):
    def __init__(self, eos_id: int):
        self._eos_id = eos_id

    def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
        # input_ids.shape=[batch, seq_len]
        return (input_ids == self._eos_id).sum(dim=1).all()


class StopOnEOSAfterBatchIndex(StoppingCriteria):
    def __init__(self, eos_id: int, batch_index: list[int]):
        self._eos_id = eos_id
        self.batch_index = torch.tensor(batch_index, dtype=torch.long).unsqueeze(1)

    def __call__(self, input_ids: torch.LongTensor, _: torch.FloatTensor) -> bool:
        # input_ids.shape=[batch, seq_len]
        eos_mask = input_ids == self._eos_id
        consider_eos_mask = (
            torch.arange(input_ids.shape[1]).unsqueeze(0) >= self.batch_index
        )
        valid_eos = eos_mask & consider_eos_mask
        return valid_eos.sum(dim=1).all()
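A hypothetical sketch of combining criteria, mirroring the HuggingFace-style interface these classes follow (the eos id is a placeholder):

```python
from chameleon.inference.stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
    StopOnEOS,
)

stopping = StoppingCriteriaList([
    MaxLengthCriteria(max_length=4096),
    StopOnEOS(eos_id=2),  # placeholder id
])
done = stopping(input_ids, None)  # True once every sequence has emitted eos or hit the cap
```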
chameleon/inference/token_selector.py
ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import torch


class TokenSelector:
    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.FloatTensor:
        # input_ids.shape=[batch, seq_len]
        # probs.shape=[batch, vocab]
        ...


class ArgmaxTokenSelector(TokenSelector):
    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs.shape=[batch, vocab]
        return probs.argmax(dim=1)


class MultinomialTokenSelector(TokenSelector):
    def __call__(
        self, _: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # probs.shape=[batch, vocab]
        return probs.multinomial(num_samples=1).squeeze(1)


class ReplicatedInputTokenSelector(TokenSelector):
    def __init__(self, token_selector: TokenSelector, n: int):
        self.token_selector = token_selector
        self.n = n

    def __call__(
        self, input_ids: torch.LongTensor, probs: torch.FloatTensor
    ) -> torch.LongTensor:
        # input_ids.shape=[n*batch, seq_len]
        # probs.shape=[n*batch, vocab]
        primary_input_ids = torch.chunk(input_ids, chunks=self.n, dim=0)[0]
        primary_probs = torch.chunk(probs, chunks=self.n, dim=0)[0]
        tokens = self.token_selector(primary_input_ids, primary_probs)
        return tokens.repeat(self.n)
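A hypothetical sketch of sampling with a replicated (e.g. CFG) batch, where only the primary chunk is sampled and the result is copied to the other replicas:

```python
from chameleon.inference.token_selector import (
    MultinomialTokenSelector,
    ReplicatedInputTokenSelector,
)

selector = ReplicatedInputTokenSelector(MultinomialTokenSelector(), n=2)
next_tokens = selector(input_ids, probs)  # shape [2*batch]; both halves get the same tokens
```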
chameleon/inference/transformer.py
ADDED
@@ -0,0 +1,421 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass

import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F
from xformers.ops import RMSNorm, fmha, rope_padded
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalWithOffsetPaddedKeysMask as AttnBias,
)


@dataclass
class ModelArgs:
    model_parallel_size: int = 1
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: int | None = None
    vocab_size: int = -1
    ffn_dim_multiplier: float | None = None
    multiple_of: int = 256
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    qk_normalization: bool = False
    swin_norm: bool = False


LayerCache = tuple[torch.Tensor, torch.Tensor]


class Attention(nn.Module):
    def __init__(
        self,
        model_parallel_size: int,
        dim: int,
        head_dim: int,
        n_heads: int,
        n_kv_heads: int,
        rope_theta: float,
        qk_normalization: bool = False,
    ):
        super().__init__()

        self.model_parallel_size = model_parallel_size

        self.head_dim = head_dim
        self.rope_theta = rope_theta

        self.n_local_heads = n_heads // model_parallel_size
        self.n_local_kv_heads = n_kv_heads // model_parallel_size

        self.wqkv = nn.Linear(
            dim,
            (self.n_local_heads + 2 * self.n_local_kv_heads) * head_dim,
            bias=False,
            dtype=torch.bfloat16,
        )
        self.wo = nn.Linear(
            self.n_local_heads * head_dim,
            dim,
            bias=False,
            dtype=torch.bfloat16,
        )

        self.qk_normalization = qk_normalization
        if qk_normalization:
            self.q_normalization = torch.nn.LayerNorm(head_dim)
            self.k_normalization = torch.nn.LayerNorm(head_dim)

        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where wq, wk, and wv are
    # not fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        # x.shape is (sum(seq_lens), dim)
        #
        # Since we support heterogeneous sequence
        # lengths, the hidden states are all
        # concatenated together along the usual
        # sequence dimension. The attention below
        # finds out where sequences start & end
        # using the provided attention bias.
        xqkv = self.wqkv(x)
        xq = xqkv[:, : (self.n_local_heads * self.head_dim)]
        xkv = xqkv[:, (self.n_local_heads * self.head_dim) :]
        xk, xv = xkv.chunk(2, 1)

        if self.qk_normalization:
            xq = xq.view(-1, self.n_local_heads, self.head_dim)
            xq = self.q_normalization(xq)
            xq = xq.view(-1, self.n_local_heads * self.head_dim)

            xk = xk.view(-1, self.n_local_kv_heads, self.head_dim)
            xk = self.k_normalization(xk)
            xk = xk.view(-1, self.n_local_kv_heads * self.head_dim)

        output_shape = xq.shape
        xq = xq.view(1, xq.shape[0], self.n_local_heads, self.head_dim)
        xk = xk.view(1, xk.shape[0], self.n_local_kv_heads, self.head_dim)
        xv = xv.view(1, xv.shape[0], self.n_local_kv_heads, self.head_dim)
        cache_k, cache_v = cache

        xq = rope_padded(
            xq=xq,
            xk=xk,
            xv=xv,
            cache_k=cache_k,
            cache_v=cache_v,
            attn_bias=attn_bias,
            theta=self.rope_theta,
        )

        # Handle GQA
        # Q shape: [B, M, Hkv, Hq // Hkv, K]
        heads_per_group = self.n_local_heads // self.n_local_kv_heads
        cache_k = cache_k.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
        cache_v = cache_v.unsqueeze(3).expand(-1, -1, -1, heads_per_group, -1)
        xq = xq.reshape(
            [*xq.shape[:2], self.n_local_kv_heads, heads_per_group, xq.shape[-1]]
        )

        # rope_padded() updated the caches, so we
        # call attention directly
        output = fmha.memory_efficient_attention_forward(
            xq, cache_k, cache_v, attn_bias
        )

        output = self.wo(output.reshape(output_shape))
        if self.model_parallel_size > 1:
            dist.all_reduce(output, group=group)

        return output


class FeedForward(nn.Module):
    def __init__(
        self,
        model_parallel_size: int,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float | None,
    ):
        super().__init__()

        self.model_parallel_size = model_parallel_size

        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        assert hidden_dim % model_parallel_size == 0

        self.w13 = nn.Linear(
            dim,
            2 * hidden_dim // model_parallel_size,
            bias=False,
        )
        self.w2 = nn.Linear(
            hidden_dim // model_parallel_size,
            dim,
            bias=False,
        )
        self._register_load_state_dict_pre_hook(self.load_hook)

    # This adapter makes sure we can load vanilla
    # Llama checkpoints where w1 and w3 are not
    # fused in a single parameter
    def load_hook(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        if prefix + "w1.weight" in state_dict:
            w1 = state_dict.pop(prefix + "w1.weight")
            w3 = state_dict.pop(prefix + "w3.weight")
            state_dict[prefix + "w13.weight"] = torch.cat([w1, w3])

    def forward(
        self, x: torch.Tensor, group: dist.ProcessGroup | None = None
    ) -> torch.Tensor:
        x13 = self.w13(x)
        x1, x3 = x13.chunk(2, -1)
        output = self.w2(F.silu(x1) * x3)
        if self.model_parallel_size > 1:
            dist.all_reduce(output, group=group)
        return output


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        assert args.dim % args.n_heads == 0
        head_dim = args.dim // args.n_heads
        if args.n_kv_heads is not None:
            n_kv_heads = args.n_kv_heads
        else:
            n_kv_heads = args.n_heads

        model_parallel_size = args.model_parallel_size
        assert args.n_heads % n_kv_heads == 0
        assert args.n_heads % model_parallel_size == 0
        assert n_kv_heads % model_parallel_size == 0

        self.attention = Attention(
            model_parallel_size=model_parallel_size,
            dim=args.dim,
            head_dim=head_dim,
            n_heads=args.n_heads,
            n_kv_heads=n_kv_heads,
            rope_theta=args.rope_theta,
            qk_normalization=args.qk_normalization,
        )
        self.feed_forward = FeedForward(
            model_parallel_size=model_parallel_size,
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.swin_norm = args.swin_norm

    def forward(
        self,
        x: torch.Tensor,
        cache: LayerCache,
        attn_bias: AttnBias,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        if self.swin_norm:
            h = x + self.attention_norm(
                self.attention.forward(
                    x,
                    cache,
                    attn_bias,
                    group=group,
                )
            )
            out = h + self.ffn_norm(self.feed_forward(h, group=group))
        else:
            h = x + self.attention.forward(
                self.attention_norm(x),
                cache,
                attn_bias,
                group=group,
            )
            out = h + self.feed_forward(self.ffn_norm(h), group=group)
        return out


class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.model_parallel_size = args.model_parallel_size
        assert args.dim % self.model_parallel_size == 0
        assert args.vocab_size > 0
        assert args.vocab_size % self.model_parallel_size == 0

        self.tok_embeddings = nn.Embedding(
            num_embeddings=args.vocab_size,
            embedding_dim=args.dim // self.model_parallel_size,
        )

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)

        self.output = nn.Linear(
            args.dim,
            args.vocab_size // self.model_parallel_size,
            bias=False,
        )

    @torch.no_grad()
    def forward_with_attn_bias(
        self,
        token_values: torch.Tensor,
        attn_bias: AttnBias,
        cache: list[LayerCache],
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        h = self.tok_embeddings(token_values)
        if self.model_parallel_size > 1:
            gather = [torch.empty_like(h) for _ in range(self.model_parallel_size)]
            dist.all_gather(gather, h, group=group)
            h = torch.cat(gather, dim=-1)

        for i, layer in enumerate(self.layers):
            h = layer(h, cache[i], attn_bias, group=group)

        logits = self.output(self.norm(h))
        if self.model_parallel_size > 1:
            gather = [torch.empty_like(logits) for _ in range(self.model_parallel_size)]
            dist.all_gather(gather, logits, group=group)
            logits = torch.cat(gather, dim=-1)
        return logits.float()

    def forward(
        self,
        token_values: torch.Tensor,
        token_lengths: torch.Tensor,
        start_pos: torch.Tensor,
        cache: list[LayerCache],
        kv_padding: int,
        group: dist.ProcessGroup | None = None,
    ) -> torch.Tensor:
        attn_bias = AttnBias.from_seqlens(
            q_seqlen=token_lengths.tolist(),
            kv_seqlen=(start_pos + token_lengths).tolist(),
            kv_padding=kv_padding,
        )
        return self.forward_with_attn_bias(token_values, attn_bias, cache, group=group)


def make_cache(
    args: ModelArgs,
    length: int,
    device: str | torch.device | None = None,
    n_layers: int | None = None,
    dtype: torch.dtype | None = None,
) -> list[LayerCache]:
    """
    Allocate a cache to be used with the Transformer module.

    Args:
        args (ModelArgs): the model configuration.
        length (int): per layer cache size.
            It is usually budgeted as ``max_batch * max_seq``
        device (torch.device, optional): the device on which
            the cache should be allocated.
        n_layers (int, optional): the number of layers to
            allocate a cache for (defaults to the model
            settings).
        dtype (torch.dtype, optional): the dtype to use for
            cache entries (defaults to the default dtype).

    Returns:
        The cache object to pass to ``Transformer.forward``.
    """

    head_dim = args.dim // args.n_heads
    n_kv_heads = args.n_kv_heads
    if n_kv_heads is None:
        n_kv_heads = args.n_heads
    n_local_kv_heads = n_kv_heads // args.model_parallel_size

    if n_layers is None:
        n_layers = args.n_layers

    shape = (1, length, n_local_kv_heads, head_dim)
    return [
        (
            torch.zeros(shape, device=device, dtype=dtype),
            torch.zeros(shape, device=device, dtype=dtype),
        )
        for _ in range(n_layers)
    ]


def cache_prefix(cache: list[LayerCache], length: int) -> list[LayerCache]:
    """
    Take a prefix view of a larger cache.

    The original cache object remains of identical size and valid
    after the shrunk alias has been used. This function is useful
    when a cache was allocated for a larger batch size than what is
    necessary.

    Args:
        cache: the cache to take a view in.
        length (int): the desired length

    Returns:
        A view in the input cache object.
    """

    if len(cache) > 0:
        assert cache[0][0].shape[1] >= length

    return [(ck[:, :length], cv[:, :length]) for ck, cv in cache]
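A hypothetical sketch of allocating and reusing a KV cache with the helpers above (the model sizes and lengths are placeholder values, not the shipped configuration):

```python
import torch
from chameleon.inference.transformer import ModelArgs, Transformer, cache_prefix, make_cache

args = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=65536)  # placeholder config
model = Transformer(args)
# Budget the per-layer cache as max_batch * max_seq positions.
cache = make_cache(args, length=4 * 4096, device="cuda", dtype=torch.bfloat16)
small_cache = cache_prefix(cache, length=2 * 4096)  # view reused for a smaller batch
```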
chameleon/inference/utils.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

import socket
from typing import Generator, Generic, Iterator, TypeVar

T = TypeVar("T")


class DynamicGenerator(Generic[T]):
    def __init__(self, gen: Generator[T, None, None]):
        self.gen = gen

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        return next(self.gen)


def advance(iterator: Iterator[T], steps: int):
    try:
        for _ in range(steps):
            next(iterator)
    except StopIteration:
        pass


def random_unused_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]
chameleon/inference/vocab.py
ADDED
@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.

from functools import cached_property

import torch


class VocabInfo:
    def __init__(self, vocab_map: dict[str, int]):
        self.name2val = vocab_map

        self.bos_id = vocab_map.get("<s>")
        self.eos_id = vocab_map.get("</s>")
        self.boi_id = vocab_map.get("<racm3:break>")
        self.eoi_id = vocab_map.get("<eoss>")
        self.pad_id = vocab_map.get("<pad>")
        self.eot_id = vocab_map.get("<reserved08706>")

    @property
    def begin_sequence(self) -> int:
        return self.bos_id

    @property
    def end_sequence(self) -> int:
        return self.eos_id

    @property
    def begin_image(self) -> int:
        return self.boi_id

    @property
    def end_image(self) -> int:
        return self.eoi_id

    @property
    def padding(self) -> int:
        return self.pad_id

    @property
    def end_turn(self) -> int:
        return self.eot_id

    @cached_property
    def val2name(self) -> dict[int, str]:
        return {v: k for k, v in self.name2val.items()}

    @cached_property
    def all_tokens(self) -> list[int]:
        return sorted(self.name2val.values())

    @cached_property
    def image_tokens(self) -> list[int]:
        return sorted(
            [val for name, val in self.name2val.items() if name.startswith("IMGIMG")]
        )

    @cached_property
    def special_tokens(self) -> list[int]:
        return sorted(
            [
                val
                for name, val in self.name2val.items()
                if name.startswith("<") and name != "<"
            ]
        )

    @cached_property
    def text_tokens(self) -> list[int]:
        return sorted(
            set(self.all_tokens) - set(self.image_tokens) - set(self.special_tokens)
        )


class VocabTranslation:
    def __init__(self, vocab_info: VocabInfo, device: str | None = None):
        self._vocab = vocab_info
        self._device = device

    @cached_property
    def bpe2img(self) -> dict[int, int]:
        # vocab id => codebook id, i.e. [4:8195] => [0:8191]
        img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}  # A-J: 0-9

        def remap(old_name: str) -> str:
            # e.g.: IMGIMGFDZ => FD => 53
            return "".join(
                img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]  # last chr is 'Z'
            )

        return {
            tok: int(remap(self._vocab.val2name[tok]))
            for tok in self._vocab.image_tokens  # the token starts with 'IMGIMG', value: [4: 8195]
        }

    @cached_property
    def img2bpe(self) -> dict[int, int]:
        # codebook id => vocab id, i.e. [0:8191] => [4:8195]
        return {v: k for k, v in self.bpe2img.items()}

    @cached_property
    def bpe2img_search_tensors(self) -> tuple[torch.Tensor, torch.Tensor]:
        sorted_bpe = torch.tensor(sorted(self.bpe2img.keys()), device=self._device)
        sorted_img = torch.tensor(sorted(self.bpe2img.values()), device=self._device)
        return sorted_bpe, sorted_img

    @cached_property
    def img2bpe_mapping_tensor(self) -> torch.LongTensor:
        mapping = torch.zeros(
            max(self.img2bpe.keys()) + 1,
            dtype=torch.int,
            device=self._device,
        )
        for k, v in self.img2bpe.items():
            mapping[k] = v
        return mapping

    def convert_bpe2img(self, bpe_batch: torch.Tensor) -> torch.Tensor:
        bpe_tok, img_tok = self.bpe2img_search_tensors
        return img_tok[torch.searchsorted(bpe_tok, bpe_batch)]

    def convert_img2bp2(self, img_batch: torch.Tensor) -> torch.Tensor:
        return self.img2bpe_mapping_tensor[img_batch]
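A hypothetical sketch of translating between BPE vocabulary ids and VQGAN codebook ids (it assumes `vocab_info` is a `VocabInfo` built from the tokenizer's vocabulary map):

```python
from chameleon.inference.vocab import VocabTranslation

translation = VocabTranslation(vocab_info, device="cuda")
codebook_ids = translation.convert_bpe2img(bpe_image_tokens)  # vocab ids -> [0, 8191]
bpe_ids = translation.convert_img2bp2(codebook_ids)           # codebook ids -> vocab ids
```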
chameleon/inference/vqgan.py
ADDED
@@ -0,0 +1,675 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""
|
7 |
+
Contents of this file are taken from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/models/vqgan.py
|
8 |
+
[with minimal dependencies]
|
9 |
+
|
10 |
+
This implementation is inference-only -- training steps and optimizer components
|
11 |
+
introduce significant additional dependencies
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
class VectorQuantizer2(nn.Module):
|
21 |
+
"""
|
22 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
23 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
24 |
+
"""
|
25 |
+
|
26 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
27 |
+
# backwards compatibility we use the buggy version by default, but you can
|
28 |
+
# specify legacy=False to fix it.
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
n_e,
|
32 |
+
e_dim,
|
33 |
+
beta,
|
34 |
+
remap=None,
|
35 |
+
unknown_index="random",
|
36 |
+
sane_index_shape=False,
|
37 |
+
legacy=True,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
self.n_e = n_e
|
41 |
+
self.e_dim = e_dim
|
42 |
+
self.beta = beta
|
43 |
+
self.legacy = legacy
|
44 |
+
|
45 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
46 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
47 |
+
|
48 |
+
self.remap = remap
|
49 |
+
if self.remap is not None:
|
50 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
51 |
+
self.re_embed = self.used.shape[0]
|
52 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
53 |
+
if self.unknown_index == "extra":
|
54 |
+
self.unknown_index = self.re_embed
|
55 |
+
self.re_embed = self.re_embed + 1
|
56 |
+
print(
|
57 |
+
f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
58 |
+
f"Using {self.unknown_index} for unknown indices."
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
self.re_embed = n_e
|
62 |
+
|
63 |
+
self.sane_index_shape = sane_index_shape
|
64 |
+
|
65 |
+
def remap_to_used(self, inds):
|
66 |
+
ishape = inds.shape
|
67 |
+
assert len(ishape) > 1
|
68 |
+
inds = inds.reshape(ishape[0], -1)
|
69 |
+
used = self.used.to(inds)
|
70 |
+
match = (inds[:, :, None] == used[None, None, ...]).long()
|
71 |
+
new = match.argmax(-1)
|
72 |
+
unknown = match.sum(2) < 1
|
73 |
+
if self.unknown_index == "random":
|
74 |
+
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
|
75 |
+
device=new.device
|
76 |
+
)
|
77 |
+
else:
|
78 |
+
new[unknown] = self.unknown_index
|
79 |
+
return new.reshape(ishape)
|
80 |
+
|
81 |
+
def unmap_to_all(self, inds):
|
82 |
+
ishape = inds.shape
|
83 |
+
assert len(ishape) > 1
|
84 |
+
inds = inds.reshape(ishape[0], -1)
|
85 |
+
used = self.used.to(inds)
|
86 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
87 |
+
inds[inds >= self.used.shape[0]] = 0 # simply set to zero
|
88 |
+
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
|
89 |
+
return back.reshape(ishape)
|
90 |
+
|
91 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
92 |
+
assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
|
93 |
+
assert rescale_logits is False, "Only for interface compatible with Gumbel"
|
94 |
+
assert return_logits is False, "Only for interface compatible with Gumbel"
|
95 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
96 |
+
z = z.permute(0, 2, 3, 1).contiguous()
|
97 |
+
z_flattened = z.view(-1, self.e_dim)
|
98 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
99 |
+
|
100 |
+
d = (
|
101 |
+
torch.sum(z_flattened**2, dim=1, keepdim=True)
|
102 |
+
+ torch.sum(self.embedding.weight**2, dim=1)
|
103 |
+
- 2
|
104 |
+
* torch.einsum(
|
105 |
+
"bd,dn->bn", z_flattened, self.embedding.weight.transpose(0, 1)
|
106 |
+
)
|
107 |
+
)
|
108 |
+
|
109 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
110 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
111 |
+
perplexity = None
|
112 |
+
min_encodings = None
|
113 |
+
|
114 |
+
# compute loss for embedding
|
115 |
+
if not self.legacy:
|
116 |
+
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
|
117 |
+
(z_q - z.detach()) ** 2
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
|
121 |
+
(z_q - z.detach()) ** 2
|
122 |
+
)
|
123 |
+
|
124 |
+
# preserve gradients
|
125 |
+
z_q = z + (z_q - z).detach()
|
126 |
+
|
127 |
+
# reshape back to match original input shape
|
128 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
129 |
+
|
130 |
+
if self.remap is not None:
|
131 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
132 |
+
z.shape[0], -1
|
133 |
+
) # add batch axis
|
134 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
135 |
+
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
|
136 |
+
|
137 |
+
if self.sane_index_shape:
|
138 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
139 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3]
|
140 |
+
)
|
141 |
+
|
142 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
143 |
+
|
144 |
+
def get_codebook_entry(self, indices, shape):
|
145 |
+
# shape specifying (batch, height, width, channel)
|
146 |
+
if self.remap is not None:
|
147 |
+
indices = indices.reshape(shape[0], -1) # add batch axis
|
148 |
+
indices = self.unmap_to_all(indices)
|
149 |
+
indices = indices.reshape(-1) # flatten again
|
150 |
+
|
151 |
+
# get quantized latent vectors
|
152 |
+
z_q = self.embedding(indices)
|
153 |
+
|
154 |
+
if shape is not None:
|
155 |
+
z_q = z_q.view(shape)
|
156 |
+
# reshape back to match original input shape
|
157 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
158 |
+
|
159 |
+
return z_q
|
160 |
+
|
161 |
+
|
162 |
+
# Alias
|
163 |
+
VectorQuantizer = VectorQuantizer2
|
164 |
+
|
165 |
+
|
166 |
+
def nonlinearity(x):
|
167 |
+
# swish
|
168 |
+
return x * torch.sigmoid(x)
|
169 |
+
|
170 |
+
|
171 |
+
def Normalize(in_channels, num_groups=32):
|
172 |
+
return torch.nn.GroupNorm(
|
173 |
+
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
174 |
+
)
|
175 |
+
|
176 |
+
|
177 |
+
class Upsample(nn.Module):
|
178 |
+
def __init__(self, in_channels, with_conv):
|
179 |
+
super().__init__()
|
180 |
+
self.with_conv = with_conv
|
181 |
+
if self.with_conv:
|
182 |
+
self.conv = torch.nn.Conv2d(
|
183 |
+
in_channels, in_channels, kernel_size=3, stride=1, padding=1
|
184 |
+
)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
188 |
+
if self.with_conv:
|
189 |
+
x = self.conv(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
class Downsample(nn.Module):
|
194 |
+
def __init__(self, in_channels, with_conv):
|
195 |
+
super().__init__()
|
196 |
+
self.with_conv = with_conv
|
197 |
+
if self.with_conv:
|
198 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
199 |
+
self.conv = torch.nn.Conv2d(
|
200 |
+
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
if self.with_conv:
|
205 |
+
pad = (0, 1, 0, 1)
|
206 |
+
x = F.pad(x, pad, mode="constant", value=0)
|
207 |
+
x = self.conv(x)
|
208 |
+
else:
|
209 |
+
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
210 |
+
return x
|
211 |
+
|
212 |
+
|
213 |
+
class ResnetBlock(nn.Module):
|
214 |
+
def __init__(
|
215 |
+
self,
|
216 |
+
*,
|
217 |
+
in_channels,
|
218 |
+
out_channels=None,
|
219 |
+
conv_shortcut=False,
|
220 |
+
dropout,
|
221 |
+
temb_channels=512,
|
222 |
+
):
|
223 |
+
super().__init__()
|
224 |
+
self.in_channels = in_channels
|
225 |
+
out_channels = in_channels if out_channels is None else out_channels
|
226 |
+
self.out_channels = out_channels
|
227 |
+
self.use_conv_shortcut = conv_shortcut
|
228 |
+
|
229 |
+
self.norm1 = Normalize(in_channels)
|
230 |
+
self.conv1 = torch.nn.Conv2d(
|
231 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
232 |
+
)
|
233 |
+
if temb_channels > 0:
|
234 |
+
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
235 |
+
self.norm2 = Normalize(out_channels)
|
236 |
+
self.dropout = torch.nn.Dropout(dropout)
|
237 |
+
self.conv2 = torch.nn.Conv2d(
|
238 |
+
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
239 |
+
)
|
240 |
+
if self.in_channels != self.out_channels:
|
241 |
+
if self.use_conv_shortcut:
|
242 |
+
self.conv_shortcut = torch.nn.Conv2d(
|
243 |
+
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
self.nin_shortcut = torch.nn.Conv2d(
|
247 |
+
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
248 |
+
)
|
249 |
+
|
250 |
+
def forward(self, x, temb):
|
251 |
+
h = x
|
252 |
+
h = self.norm1(h)
|
253 |
+
h = nonlinearity(h)
|
254 |
+
h = self.conv1(h)
|
255 |
+
|
256 |
+
if temb is not None:
|
257 |
+
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
258 |
+
|
259 |
+
h = self.norm2(h)
|
260 |
+
h = nonlinearity(h)
|
261 |
+
h = self.dropout(h)
|
262 |
+
h = self.conv2(h)
|
263 |
+
|
264 |
+
if self.in_channels != self.out_channels:
|
265 |
+
if self.use_conv_shortcut:
|
266 |
+
x = self.conv_shortcut(x)
|
267 |
+
else:
|
268 |
+
x = self.nin_shortcut(x)
|
269 |
+
|
270 |
+
return x + h
|
271 |
+
|
272 |
+
|
273 |
+
class AttnBlock(nn.Module):
|
274 |
+
def __init__(self, in_channels):
|
275 |
+
super().__init__()
|
276 |
+
self.in_channels = in_channels
|
277 |
+
|
278 |
+
self.norm = Normalize(in_channels)
|
279 |
+
self.q = torch.nn.Conv2d(
|
280 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
281 |
+
)
|
282 |
+
self.k = torch.nn.Conv2d(
|
283 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
284 |
+
)
|
285 |
+
self.v = torch.nn.Conv2d(
|
286 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
287 |
+
)
|
288 |
+
self.proj_out = torch.nn.Conv2d(
|
289 |
+
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
290 |
+
)
|
291 |
+
|
292 |
+
def forward(self, x):
|
293 |
+
h_ = x
|
294 |
+
h_ = self.norm(h_)
|
295 |
+
q = self.q(h_)
|
296 |
+
k = self.k(h_)
|
297 |
+
v = self.v(h_)
|
298 |
+
|
299 |
+
# compute attention
|
300 |
+
b, c, h, w = q.shape
|
301 |
+
q = q.reshape(b, c, h * w)
|
302 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
303 |
+
k = k.reshape(b, c, h * w) # b,c,hw
|
304 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
305 |
+
w_ = w_ * (int(c) ** (-0.5))
|
306 |
+
w_ = F.softmax(w_, dim=2)
|
307 |
+
|
308 |
+
# attend to values
|
309 |
+
v = v.reshape(b, c, h * w)
|
310 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
311 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
312 |
+
h_ = h_.reshape(b, c, h, w)
|
313 |
+
|
314 |
+
h_ = self.proj_out(h_)
|
315 |
+
|
316 |
+
return x + h_
|
317 |
+
|
318 |
+
|
319 |
+
def make_attn(in_channels, attn_type="vanilla"):
|
320 |
+
assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
|
321 |
+
# print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
322 |
+
if attn_type == "vanilla":
|
323 |
+
return AttnBlock(in_channels)
|
324 |
+
elif attn_type == "none":
|
325 |
+
return nn.Identity(in_channels)
|
326 |
+
else:
|
327 |
+
raise ValueError("Unexpected attention type")
|
328 |
+
|
329 |
+
|
330 |
+
class Encoder(nn.Module):
|
331 |
+
def __init__(
|
332 |
+
self,
|
333 |
+
*,
|
334 |
+
ch,
|
335 |
+
out_ch,
|
336 |
+
ch_mult=(1, 2, 4, 8),
|
337 |
+
num_res_blocks,
|
338 |
+
attn_resolutions,
|
339 |
+
dropout=0.0,
|
340 |
+
resamp_with_conv=True,
|
341 |
+
in_channels,
|
342 |
+
resolution,
|
343 |
+
z_channels,
|
344 |
+
double_z=True,
|
345 |
+
use_linear_attn=False,
|
346 |
+
attn_type="vanilla",
|
347 |
+
**ignore_kwargs,
|
348 |
+
):
|
349 |
+
super().__init__()
|
350 |
+
if use_linear_attn:
|
351 |
+
attn_type = "linear"
|
352 |
+
self.ch = ch
|
353 |
+
self.temb_ch = 0
|
354 |
+
self.num_resolutions = len(ch_mult)
|
355 |
+
self.num_res_blocks = num_res_blocks
|
356 |
+
self.resolution = resolution
|
357 |
+
self.in_channels = in_channels
|
358 |
+
|
359 |
+
# downsampling
|
360 |
+
self.conv_in = torch.nn.Conv2d(
|
361 |
+
in_channels, self.ch, kernel_size=3, stride=1, padding=1
|
362 |
+
)
|
363 |
+
|
364 |
+
curr_res = resolution
|
365 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
366 |
+
self.in_ch_mult = in_ch_mult
|
367 |
+
self.down = nn.ModuleList()
|
368 |
+
for i_level in range(self.num_resolutions):
|
369 |
+
block = nn.ModuleList()
|
370 |
+
attn = nn.ModuleList()
|
371 |
+
block_in = ch * in_ch_mult[i_level]
|
372 |
+
block_out = ch * ch_mult[i_level]
|
373 |
+
for i_block in range(self.num_res_blocks):
|
374 |
+
block.append(
|
375 |
+
ResnetBlock(
|
376 |
+
in_channels=block_in,
|
377 |
+
out_channels=block_out,
|
378 |
+
temb_channels=self.temb_ch,
|
379 |
+
dropout=dropout,
|
380 |
+
)
|
381 |
+
)
|
382 |
+
block_in = block_out
|
383 |
+
if curr_res in attn_resolutions:
|
384 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
385 |
+
down = nn.Module()
|
386 |
+
down.block = block
|
387 |
+
down.attn = attn
|
388 |
+
if i_level != self.num_resolutions - 1:
|
389 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
390 |
+
curr_res = curr_res // 2
|
391 |
+
self.down.append(down)
|
392 |
+
|
393 |
+
# middle
|
394 |
+
self.mid = nn.Module()
|
395 |
+
self.mid.block_1 = ResnetBlock(
|
396 |
+
in_channels=block_in,
|
397 |
+
out_channels=block_in,
|
398 |
+
temb_channels=self.temb_ch,
|
399 |
+
dropout=dropout,
|
400 |
+
)
|
401 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
402 |
+
self.mid.block_2 = ResnetBlock(
|
403 |
+
in_channels=block_in,
|
404 |
+
out_channels=block_in,
|
405 |
+
temb_channels=self.temb_ch,
|
406 |
+
dropout=dropout,
|
407 |
+
)
|
408 |
+
|
409 |
+
# end
|
410 |
+
self.norm_out = Normalize(block_in)
|
411 |
+
self.conv_out = torch.nn.Conv2d(
|
412 |
+
block_in,
|
413 |
+
2 * z_channels if double_z else z_channels,
|
414 |
+
kernel_size=3,
|
415 |
+
stride=1,
|
416 |
+
padding=1,
|
417 |
+
)
|
418 |
+
|
419 |
+
def forward(self, x):
|
420 |
+
# timestep embedding
|
421 |
+
temb = None
|
422 |
+
|
423 |
+
# downsampling
|
424 |
+
hs = [self.conv_in(x)]
|
425 |
+
for i_level in range(self.num_resolutions):
|
426 |
+
for i_block in range(self.num_res_blocks):
|
427 |
+
h = self.down[i_level].block[i_block](hs[-1], temb)
|
428 |
+
if len(self.down[i_level].attn) > 0:
|
429 |
+
h = self.down[i_level].attn[i_block](h)
|
430 |
+
hs.append(h)
|
431 |
+
if i_level != self.num_resolutions - 1:
|
432 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
433 |
+
|
434 |
+
# middle
|
435 |
+
h = hs[-1]
|
436 |
+
h = self.mid.block_1(h, temb)
|
437 |
+
h = self.mid.attn_1(h)
|
438 |
+
h = self.mid.block_2(h, temb)
|
439 |
+
|
440 |
+
# end
|
441 |
+
h = self.norm_out(h)
|
442 |
+
h = nonlinearity(h)
|
443 |
+
h = self.conv_out(h)
|
444 |
+
return h
|
445 |
+
|
446 |
+
|
447 |
+
class Decoder(nn.Module):
|
448 |
+
def __init__(
|
449 |
+
self,
|
450 |
+
*,
|
451 |
+
ch,
|
452 |
+
out_ch,
|
453 |
+
ch_mult=(1, 2, 4, 8),
|
454 |
+
num_res_blocks,
|
455 |
+
attn_resolutions,
|
456 |
+
dropout=0.0,
|
457 |
+
resamp_with_conv=True,
|
458 |
+
in_channels,
|
459 |
+
resolution,
|
460 |
+
z_channels,
|
461 |
+
give_pre_end=False,
|
462 |
+
tanh_out=False,
|
463 |
+
use_linear_attn=False,
|
464 |
+
attn_type="vanilla",
|
465 |
+
**ignorekwargs,
|
466 |
+
):
|
467 |
+
super().__init__()
|
468 |
+
if use_linear_attn:
|
469 |
+
attn_type = "linear"
|
470 |
+
self.ch = ch
|
471 |
+
self.temb_ch = 0
|
472 |
+
self.num_resolutions = len(ch_mult)
|
473 |
+
self.num_res_blocks = num_res_blocks
|
474 |
+
self.resolution = resolution
|
475 |
+
self.in_channels = in_channels
|
476 |
+
self.give_pre_end = give_pre_end
|
477 |
+
self.tanh_out = tanh_out
|
478 |
+
|
479 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
480 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
481 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
482 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
483 |
+
|
484 |
+
# z to block_in
|
485 |
+
self.conv_in = torch.nn.Conv2d(
|
486 |
+
z_channels, block_in, kernel_size=3, stride=1, padding=1
|
487 |
+
)
|
488 |
+
|
489 |
+
# middle
|
490 |
+
self.mid = nn.Module()
|
491 |
+
self.mid.block_1 = ResnetBlock(
|
492 |
+
in_channels=block_in,
|
493 |
+
out_channels=block_in,
|
494 |
+
temb_channels=self.temb_ch,
|
495 |
+
dropout=dropout,
|
496 |
+
)
|
497 |
+
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
498 |
+
self.mid.block_2 = ResnetBlock(
|
499 |
+
in_channels=block_in,
|
500 |
+
out_channels=block_in,
|
501 |
+
temb_channels=self.temb_ch,
|
502 |
+
dropout=dropout,
|
503 |
+
)
|
504 |
+
|
505 |
+
# upsampling
|
506 |
+
self.up = nn.ModuleList()
|
507 |
+
for i_level in reversed(range(self.num_resolutions)):
|
508 |
+
block = nn.ModuleList()
|
509 |
+
attn = nn.ModuleList()
|
510 |
+
block_out = ch * ch_mult[i_level]
|
511 |
+
for i_block in range(self.num_res_blocks + 1):
|
512 |
+
block.append(
|
513 |
+
ResnetBlock(
|
514 |
+
in_channels=block_in,
|
515 |
+
out_channels=block_out,
|
516 |
+
temb_channels=self.temb_ch,
|
517 |
+
dropout=dropout,
|
518 |
+
)
|
519 |
+
)
|
520 |
+
block_in = block_out
|
521 |
+
if curr_res in attn_resolutions:
|
522 |
+
attn.append(make_attn(block_in, attn_type=attn_type))
|
523 |
+
up = nn.Module()
|
524 |
+
up.block = block
|
525 |
+
up.attn = attn
|
526 |
+
if i_level != 0:
|
527 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
528 |
+
curr_res = curr_res * 2
|
529 |
+
self.up.insert(0, up) # prepend to get consistent order
|
530 |
+
|
531 |
+
# end
|
532 |
+
self.norm_out = Normalize(block_in)
|
533 |
+
self.conv_out = torch.nn.Conv2d(
|
534 |
+
block_in, out_ch, kernel_size=3, stride=1, padding=1
|
535 |
+
)
|
536 |
+
|
537 |
+
def forward(self, z):
|
538 |
+
# assert z.shape[1:] == self.z_shape[1:]
|
539 |
+
self.last_z_shape = z.shape
|
540 |
+
|
541 |
+
# timestep embedding
|
542 |
+
temb = None
|
543 |
+
|
544 |
+
# z to block_in
|
545 |
+
h = self.conv_in(z)
|
546 |
+
|
547 |
+
# middle
|
548 |
+
h = self.mid.block_1(h, temb)
|
549 |
+
h = self.mid.attn_1(h)
|
550 |
+
h = self.mid.block_2(h, temb)
|
551 |
+
|
552 |
+
# upsampling
|
553 |
+
for i_level in reversed(range(self.num_resolutions)):
|
554 |
+
for i_block in range(self.num_res_blocks + 1):
|
555 |
+
h = self.up[i_level].block[i_block](h, temb)
|
556 |
+
if len(self.up[i_level].attn) > 0:
|
557 |
+
h = self.up[i_level].attn[i_block](h)
|
558 |
+
if i_level != 0:
|
559 |
+
h = self.up[i_level].upsample(h)
|
560 |
+
|
561 |
+
# end
|
562 |
+
if self.give_pre_end:
|
563 |
+
return h
|
564 |
+
|
565 |
+
h = self.norm_out(h)
|
566 |
+
h = nonlinearity(h)
|
567 |
+
h = self.conv_out(h)
|
568 |
+
if self.tanh_out:
|
569 |
+
h = torch.tanh(h)
|
570 |
+
return h
|
571 |
+
|
572 |
+
|
573 |
+
class VQModel(nn.Module):
|
574 |
+
def __init__(
|
575 |
+
self,
|
576 |
+
ddconfig,
|
577 |
+
n_embed,
|
578 |
+
embed_dim,
|
579 |
+
ckpt_path=None,
|
580 |
+
ignore_keys=[],
|
581 |
+
image_key="image",
|
582 |
+
colorize_nlabels=None,
|
583 |
+
monitor=None,
|
584 |
+
scheduler_config=None,
|
585 |
+
lr_g_factor=1.0,
|
586 |
+
remap=None,
|
587 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
588 |
+
):
|
589 |
+
super().__init__()
|
590 |
+
self.image_key = image_key
|
591 |
+
self.encoder = Encoder(**ddconfig)
|
592 |
+
self.decoder = Decoder(**ddconfig)
|
593 |
+
self.quantize = VectorQuantizer(
|
594 |
+
n_embed,
|
595 |
+
embed_dim,
|
596 |
+
beta=0.25,
|
597 |
+
remap=remap,
|
598 |
+
sane_index_shape=sane_index_shape,
|
599 |
+
)
|
600 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
601 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
602 |
+
if ckpt_path is not None:
|
603 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
604 |
+
self.image_key = image_key
|
605 |
+
if colorize_nlabels is not None:
|
606 |
+
assert isinstance(colorize_nlabels, int)
|
607 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
608 |
+
if monitor is not None:
|
609 |
+
self.monitor = monitor
|
610 |
+
self.scheduler_config = scheduler_config
|
611 |
+
self.lr_g_factor = lr_g_factor
|
612 |
+
|
613 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
614 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
615 |
+
keys = list(sd.keys())
|
616 |
+
for k in keys:
|
617 |
+
for ik in ignore_keys:
|
618 |
+
if k.startswith(ik):
|
619 |
+
print("Deleting key {} from state_dict.".format(k))
|
620 |
+
del sd[k]
|
621 |
+
self.load_state_dict(sd, strict=False)
|
622 |
+
print(f"VQModel loaded from {path}")
|
623 |
+
|
624 |
+
def encode(self, x):
|
625 |
+
h = self.encoder(x)
|
626 |
+
h = self.quant_conv(h)
|
627 |
+
quant, emb_loss, info = self.quantize(h)
|
628 |
+
return quant, emb_loss, info
|
629 |
+
|
630 |
+
def decode(self, quant):
|
631 |
+
quant = self.post_quant_conv(quant)
|
632 |
+
dec = self.decoder(quant)
|
633 |
+
return dec
|
634 |
+
|
635 |
+
def decode_code(self, code_b):
|
636 |
+
quant_b = self.quantize.embed_code(code_b)
|
637 |
+
dec = self.decode(quant_b)
|
638 |
+
return dec
|
639 |
+
|
640 |
+
def forward(self, input):
|
641 |
+
quant, diff, _ = self.encode(input)
|
642 |
+
dec = self.decode(quant)
|
643 |
+
return dec, diff
|
644 |
+
|
645 |
+
def get_input(self, batch, k):
|
646 |
+
x = batch[k]
|
647 |
+
if len(x.shape) == 3:
|
648 |
+
x = x[..., None]
|
649 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
|
650 |
+
return x.float()
|
651 |
+
|
652 |
+
def get_last_layer(self):
|
653 |
+
return self.decoder.conv_out.weight
|
654 |
+
|
655 |
+
def log_images(self, batch, **kwargs):
|
656 |
+
log = dict()
|
657 |
+
x = self.get_input(batch, self.image_key)
|
658 |
+
x = x.to(self.device)
|
659 |
+
xrec, _ = self(x)
|
660 |
+
if x.shape[1] > 3:
|
661 |
+
# colorize with random projection
|
662 |
+
assert xrec.shape[1] > 3
|
663 |
+
x = self.to_rgb(x)
|
664 |
+
xrec = self.to_rgb(xrec)
|
665 |
+
log["inputs"] = x
|
666 |
+
log["reconstructions"] = xrec
|
667 |
+
return log
|
668 |
+
|
669 |
+
def to_rgb(self, x):
|
670 |
+
assert self.image_key == "segmentation"
|
671 |
+
if not hasattr(self, "colorize"):
|
672 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
673 |
+
x = F.conv2d(x, weight=self.colorize)
|
674 |
+
x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
|
675 |
+
return x
|
chameleon/miniviewer/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
chameleon/miniviewer/__main__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
+
+from chameleon.miniviewer.miniviewer import main
+
+if __name__ == "__main__":
+    main()
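
With this module entry point in place, the viewer is launched via python -m; the flags match the click options defined in miniviewer.py further down (the data path is wherever the downloaded checkpoints live, 7b or 30b selects the model):

python -m chameleon.miniviewer --data-path ./data --model-size 7b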
chameleon/miniviewer/miniviewer.html
ADDED
@@ -0,0 +1,409 @@
1 |
+
<!-- Copyright (c) Meta Platforms, Inc. and affiliates. -->
|
2 |
+
|
3 |
+
<!-- This source code is licensed under the Chameleon License found in the -->
|
4 |
+
<!-- LICENSE file in the root directory of this source tree. -->
|
5 |
+
<h1>
|
6 |
+
<div id="connection-status"></div>
|
7 |
+
MiniViewer:
|
8 |
+
</h1>
|
9 |
+
<div class="container">
|
10 |
+
<div class="sidebar with-padding">
|
11 |
+
<h4>Input Controls</h4>
|
12 |
+
<div class="input-controls-container">
|
13 |
+
<button class="button" onclick="addInput('text')">Add text input</button>
|
14 |
+
<button class="button" onclick="addInput('image')">
|
15 |
+
Add image input
|
16 |
+
</button>
|
17 |
+
<button class="button" onclick="addInput('<END-OF-TURN>')">
|
18 |
+
Add end-of-turn token
|
19 |
+
</button>
|
20 |
+
</div>
|
21 |
+
<hr />
|
22 |
+
<h4>General Options</h4>
|
23 |
+
<div class="option">
|
24 |
+
<label for="seed">seed</label>
|
25 |
+
<input type="number" id="seed" value="0" />
|
26 |
+
</div>
|
27 |
+
<div class="option">
|
28 |
+
<label for="max-seq-len">max sequence length</label>
|
29 |
+
<input type="number" id="max-seq-len" value="4096" />
|
30 |
+
</div>
|
31 |
+
<div class="option">
|
32 |
+
<label for="max-gen-len">max generation length</label>
|
33 |
+
<input type="number" id="max-gen-len" value="4096" />
|
34 |
+
</div>
|
35 |
+
<h4>
|
36 |
+
<input type="checkbox" id="enable-text" name="enable-text" checked />
|
37 |
+
<label for="enable-text">Text Decoder Options</label>
|
38 |
+
</h4>
|
39 |
+
<div class="option">
|
40 |
+
<label for="text-rep-penalty">repetition penalty</label>
|
41 |
+
<input type="number" id="text-rep-penalty" value="1.2" step="0.01" />
|
42 |
+
</div>
|
43 |
+
<div class="option">
|
44 |
+
<label for="text-temp">temperature</label>
|
45 |
+
<input type="number" id="text-temp" value="0.7" step="0.01" />
|
46 |
+
</div>
|
47 |
+
<div class="option">
|
48 |
+
<label for="text-top-p">top-p</label>
|
49 |
+
<input type="number" id="text-top-p" value="0.9" step="0.01" />
|
50 |
+
</div>
|
51 |
+
<h4>
|
52 |
+
<input type="checkbox" id="enable-image" name="enable-image" checked />
|
53 |
+
<label for="enable-image">Image Decoder Options</label>
|
54 |
+
</h4>
|
55 |
+
<div class="option">
|
56 |
+
<label for="img-cfg-gstext">cfg text</label>
|
57 |
+
<input type="number" id="img-cfg-gstext" value="3.0" step="0.01" />
|
58 |
+
</div>
|
59 |
+
<div class="option">
|
60 |
+
<label for="img-cfg-gsimage">cfg image</label>
|
61 |
+
<input type="number" id="img-cfg-gsimage" value="1.2" step="0.01" />
|
62 |
+
</div>
|
63 |
+
<div class="option">
|
64 |
+
<label for="img-temp">temperature</label>
|
65 |
+
<input type="number" id="img-temp" value="0.7" step="0.01" />
|
66 |
+
</div>
|
67 |
+
<div class="option">
|
68 |
+
<label for="img-top-p">top-p</label>
|
69 |
+
<input type="number" id="img-top-p" value="0.9" step="0.01" />
|
70 |
+
</div>
|
71 |
+
</div>
|
72 |
+
<div class="content with-padding">
|
73 |
+
<div class="input-wrapper">
|
74 |
+
Inputs:
|
75 |
+
<div id="inputs" class="with-padding"></div>
|
76 |
+
</div>
|
77 |
+
<h4>
|
78 |
+
<button id="generate" class="button" onclick="generate()">
|
79 |
+
Generate
|
80 |
+
</button>
|
81 |
+
<button
|
82 |
+
id="cancel"
|
83 |
+
class="button"
|
84 |
+
onclick="cancel()"
|
85 |
+
style="display: none"
|
86 |
+
>
|
87 |
+
Cancel
|
88 |
+
</button>
|
89 |
+
</h4>
|
90 |
+
Results:
|
91 |
+
<pre id="results" class="with-padding"></pre>
|
92 |
+
<div id="timing" class="with-padding"></div>
|
93 |
+
<div id="queue" class="with-padding"></div>
|
94 |
+
</div>
|
95 |
+
</div>
|
96 |
+
|
97 |
+
<style>
|
98 |
+
.container {
|
99 |
+
display: inline-flex;
|
100 |
+
}
|
101 |
+
|
102 |
+
.sidebar {
|
103 |
+
flex: 0 0 200px;
|
104 |
+
border-right: 2px solid #ddd;
|
105 |
+
}
|
106 |
+
|
107 |
+
#connection-status {
|
108 |
+
width: 20px;
|
109 |
+
height: 20px;
|
110 |
+
border-radius: 10px;
|
111 |
+
background-color: grey;
|
112 |
+
display: inline-block;
|
113 |
+
}
|
114 |
+
|
115 |
+
.input-controls-container {
|
116 |
+
display: inline-grid;
|
117 |
+
}
|
118 |
+
|
119 |
+
.option {
|
120 |
+
display: flex;
|
121 |
+
margin-bottom: 5px;
|
122 |
+
}
|
123 |
+
|
124 |
+
.option label {
|
125 |
+
white-space: nowrap;
|
126 |
+
margin-right: 10px;
|
127 |
+
}
|
128 |
+
|
129 |
+
.option input {
|
130 |
+
flex-grow: 1;
|
131 |
+
text-align: right;
|
132 |
+
}
|
133 |
+
|
134 |
+
.content {
|
135 |
+
width: 100%;
|
136 |
+
}
|
137 |
+
|
138 |
+
.with-padding {
|
139 |
+
padding: 10px;
|
140 |
+
}
|
141 |
+
|
142 |
+
.input-wrapper {
|
143 |
+
border: dotted;
|
144 |
+
}
|
145 |
+
|
146 |
+
.input-container {
|
147 |
+
display: flex;
|
148 |
+
align-items: center;
|
149 |
+
}
|
150 |
+
|
151 |
+
.input-controls {
|
152 |
+
display: inline-flex;
|
153 |
+
padding: 2px;
|
154 |
+
}
|
155 |
+
|
156 |
+
#results {
|
157 |
+
background: lightgray;
|
158 |
+
}
|
159 |
+
|
160 |
+
button {
|
161 |
+
text-align: left;
|
162 |
+
}
|
163 |
+
|
164 |
+
img {
|
165 |
+
width: 200px;
|
166 |
+
height: 200px;
|
167 |
+
}
|
168 |
+
</style>
|
169 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.6.0/socket.io.min.js"></script>
|
170 |
+
<script>
|
171 |
+
var active_key;
|
172 |
+
var socket;
|
173 |
+
|
174 |
+
function createButton(text, onClick) {
|
175 |
+
var button = document.createElement("button");
|
176 |
+
button.textContent = text;
|
177 |
+
button.onclick = onClick;
|
178 |
+
return button;
|
179 |
+
}
|
180 |
+
|
181 |
+
function removeInput(evt) {
|
182 |
+
var inputWrapper = evt.target.parentNode.parentNode;
|
183 |
+
inputWrapper.parentNode.removeChild(inputWrapper);
|
184 |
+
}
|
185 |
+
|
186 |
+
function moveInputUp(evt) {
|
187 |
+
var inputWrapper = evt.target.parentNode.parentNode;
|
188 |
+
var prev = inputWrapper.previousElementSibling;
|
189 |
+
if (prev) {
|
190 |
+
inputWrapper.parentNode.insertBefore(inputWrapper, prev);
|
191 |
+
}
|
192 |
+
}
|
193 |
+
|
194 |
+
function moveInputDown(evt) {
|
195 |
+
var inputWrapper = evt.target.parentNode.parentNode;
|
196 |
+
var next = inputWrapper.nextElementSibling;
|
197 |
+
if (next) {
|
198 |
+
inputWrapper.parentNode.insertBefore(next, inputWrapper);
|
199 |
+
}
|
200 |
+
}
|
201 |
+
|
202 |
+
function readFileAsync(file) {
|
203 |
+
return new Promise((resolve, reject) => {
|
204 |
+
let reader = new FileReader();
|
205 |
+
reader.onload = () => resolve(reader.result);
|
206 |
+
reader.onerror = reject;
|
207 |
+
reader.readAsDataURL(file);
|
208 |
+
});
|
209 |
+
}
|
210 |
+
|
211 |
+
async function loadImageSource(dataTransfer) {
|
212 |
+
if (dataTransfer.files.length > 0) {
|
213 |
+
return await readFileAsync(dataTransfer.files[0]);
|
214 |
+
}
|
215 |
+
|
216 |
+
let htmlContent = dataTransfer.getData("text/html");
|
217 |
+
if (htmlContent) {
|
218 |
+
const div = document.createElement("div");
|
219 |
+
div.innerHTML = htmlContent;
|
220 |
+
return div.querySelector("img").src;
|
221 |
+
}
|
222 |
+
|
223 |
+
return (
|
224 |
+
dataTransfer.getData("text/uri-list") ||
|
225 |
+
dataTransfer.getData("text/plain")
|
226 |
+
);
|
227 |
+
}
|
228 |
+
|
229 |
+
async function showPreview(evt) {
|
230 |
+
var wrapper = evt.target.parentElement;
|
231 |
+
wrapper.querySelector("img").src = await loadImageSource(evt.target);
|
232 |
+
wrapper.querySelector("img").style.display = "block";
|
233 |
+
wrapper.querySelector("p").style.display = "none";
|
234 |
+
}
|
235 |
+
|
236 |
+
async function handleDrop(evt) {
|
237 |
+
evt.preventDefault();
|
238 |
+
var wrapper = evt.target.parentElement;
|
239 |
+
var file = evt.dataTransfer.files[0];
|
240 |
+
var fileInput = wrapper.querySelector('input[type="file"]');
|
241 |
+
fileInput.files = evt.dataTransfer.files;
|
242 |
+
wrapper.querySelector("img").src = await loadImageSource(evt.dataTransfer);
|
243 |
+
wrapper.querySelector("img").style.display = "block";
|
244 |
+
wrapper.querySelector("p").style.display = "none";
|
245 |
+
}
|
246 |
+
|
247 |
+
function addInput(input_kind) {
|
248 |
+
var inputs_div = document.getElementById("inputs");
|
249 |
+
var wrapper = document.createElement("div");
|
250 |
+
wrapper.kind = input_kind;
|
251 |
+
wrapper.className = "input-container";
|
252 |
+
|
253 |
+
var new_inputs = [];
|
254 |
+
if (input_kind === "text") {
|
255 |
+
new_inputs.push(document.createElement("textarea"));
|
256 |
+
} else if (input_kind === "image") {
|
257 |
+
wrapper.setAttribute("draggable", true);
|
258 |
+
wrapper.ondragover = (evt) => evt.preventDefault();
|
259 |
+
wrapper.ondrop = handleDrop;
|
260 |
+
|
261 |
+
var hiddenImageFromFile = document.createElement("input");
|
262 |
+
hiddenImageFromFile.type = "file";
|
263 |
+
hiddenImageFromFile.accept = "image/*";
|
264 |
+
hiddenImageFromFile.addEventListener("change", showPreview);
|
265 |
+
hiddenImageFromFile.style.display = "none";
|
266 |
+
wrapper.onclick = function () {
|
267 |
+
hiddenImageFromFile.click();
|
268 |
+
};
|
269 |
+
new_inputs.push(hiddenImageFromFile);
|
270 |
+
|
271 |
+
var description = document.createElement("p");
|
272 |
+
description.textContent =
|
273 |
+
"Drag and drop your image here, or click to select.";
|
274 |
+
new_inputs.push(description);
|
275 |
+
|
276 |
+
var preview = document.createElement("img");
|
277 |
+
preview.style.display = "none";
|
278 |
+
new_inputs.push(preview);
|
279 |
+
} else {
|
280 |
+
var span = document.createElement("span");
|
281 |
+
span.textContent = input_kind;
|
282 |
+
new_inputs.push(span);
|
283 |
+
}
|
284 |
+
|
285 |
+
const input_controls = document.createElement("div");
|
286 |
+
input_controls.className = "input-controls";
|
287 |
+
input_controls.appendChild(createButton("-", removeInput));
|
288 |
+
input_controls.appendChild(createButton("↓", moveInputDown));
|
289 |
+
input_controls.appendChild(createButton("↑", moveInputUp));
|
290 |
+
|
291 |
+
wrapper.appendChild(input_controls);
|
292 |
+
for (var new_input of new_inputs) {
|
293 |
+
wrapper.appendChild(new_input);
|
294 |
+
}
|
295 |
+
wrapper.appendChild(document.createElement("br"));
|
296 |
+
|
297 |
+
inputs_div.appendChild(wrapper);
|
298 |
+
}
|
299 |
+
|
300 |
+
async function generate() {
|
301 |
+
document.getElementById("generate").style.display = "none";
|
302 |
+
document.getElementById("cancel").style.display = "block";
|
303 |
+
document.getElementById("results").innerHTML = "";
|
304 |
+
document.getElementById("timing").innerHTML = "";
|
305 |
+
document.getElementById("queue").innerHTML = "";
|
306 |
+
|
307 |
+
active_key = `key_${Math.random()
|
308 |
+
.toString(36)
|
309 |
+
.substring(2, 11)}_${Date.now()}`;
|
310 |
+
|
311 |
+
const user_options = {};
|
312 |
+
for (const option of document.getElementsByClassName("option")) {
|
313 |
+
const input = option.querySelector("input");
|
314 |
+
user_options[input.id] = Number(input.value);
|
315 |
+
}
|
316 |
+
|
317 |
+
user_options["enable-text"] =
|
318 |
+
document.getElementById("enable-text").checked;
|
319 |
+
user_options["enable-image"] =
|
320 |
+
document.getElementById("enable-image").checked;
|
321 |
+
|
322 |
+
const user_inputs = [];
|
323 |
+
const inputs_div = document.getElementById("inputs");
|
324 |
+
|
325 |
+
const input_elems = Array.from(inputs_div.children).map((wrapper) =>
|
326 |
+
wrapper.querySelector("textarea, input, span")
|
327 |
+
);
|
328 |
+
|
329 |
+
const image_promises = Array.from(inputs_div.children)
|
330 |
+
.filter((wrapper) => wrapper.kind === "image")
|
331 |
+
.map((wrapper) => {
|
332 |
+
const file_input = wrapper.querySelector('input[type="file"]');
|
333 |
+
return file_input.files[0]
|
334 |
+
? readFileAsync(file_input.files[0])
|
335 |
+
: Promise.resolve(null);
|
336 |
+
});
|
337 |
+
|
338 |
+
const images = await Promise.all(image_promises);
|
339 |
+
|
340 |
+
for (const wrapper of inputs_div.children) {
|
341 |
+
if (wrapper.kind === "text") {
|
342 |
+
user_inputs.push({
|
343 |
+
type: "text",
|
344 |
+
value: wrapper.querySelector("textarea").value,
|
345 |
+
});
|
346 |
+
} else if (wrapper.kind === "image") {
|
347 |
+
user_inputs.push({ type: "image", value: images.shift() });
|
348 |
+
} else {
|
349 |
+
user_inputs.push({ type: "sentinel", value: wrapper.kind });
|
350 |
+
}
|
351 |
+
}
|
352 |
+
|
353 |
+
socket.emit("generate", active_key, user_options, user_inputs);
|
354 |
+
}
|
355 |
+
|
356 |
+
function cancel() {
|
357 |
+
document.getElementById("generate").style.display = "block";
|
358 |
+
document.getElementById("cancel").style.display = "none";
|
359 |
+
document.getElementById("queue").innerHTML = "";
|
360 |
+
socket.emit("cancel", active_key);
|
361 |
+
active_key = null;
|
362 |
+
}
|
363 |
+
|
364 |
+
function connectSocket() {
|
365 |
+
socket = io();
|
366 |
+
|
367 |
+
socket.on("connect", function() {
|
368 |
+
document.getElementById("connection-status").style.backgroundColor = 'green';
|
369 |
+
});
|
370 |
+
|
371 |
+
socket.on("disconnect", function(reason) {
|
372 |
+
cancel();
|
373 |
+
document.getElementById("connection-status").style.backgroundColor = 'red';
|
374 |
+
});
|
375 |
+
|
376 |
+
socket.on("progress", function (data) {
|
377 |
+
if (data.key != active_key) {
|
378 |
+
return;
|
379 |
+
}
|
380 |
+
|
381 |
+
document.getElementById("queue").innerHTML = "";
|
382 |
+
if (data.type == "queue") {
|
383 |
+
document.getElementById(
|
384 |
+
"queue"
|
385 |
+
).innerHTML = `queue position ${data.value}`;
|
386 |
+
}
|
387 |
+
|
388 |
+
if (data.type == "text") {
|
389 |
+
document.getElementById("results").innerHTML += data.value;
|
390 |
+
} else if (data.type == "image_start") {
|
391 |
+
document.getElementById("results").appendChild(new Image());
|
392 |
+
} else if (data.type == "image") {
|
393 |
+
document.getElementById("results").lastElementChild.src = data.value;
|
394 |
+
} else if (data.type == "image_end") {
|
395 |
+
} else if (data.type == "done") {
|
396 |
+
document.getElementById(
|
397 |
+
"timing"
|
398 |
+
).innerHTML = `Generation time: ${data.value.toFixed(2)} sec`;
|
399 |
+
document.getElementById("generate").style.display = "block";
|
400 |
+
document.getElementById("cancel").style.display = "none";
|
401 |
+
active_key = null;
|
402 |
+
}
|
403 |
+
});
|
404 |
+
}
|
405 |
+
|
406 |
+
window.onload = (evt) => {
|
407 |
+
connectSocket();
|
408 |
+
};
|
409 |
+
</script>
|
chameleon/miniviewer/miniviewer.py
ADDED
@@ -0,0 +1,254 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import base64
|
7 |
+
import os
|
8 |
+
import threading
|
9 |
+
import time
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from enum import Enum
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import click
|
15 |
+
import torch
|
16 |
+
from flask import Flask, request
|
17 |
+
from flask_socketio import SocketIO
|
18 |
+
|
19 |
+
from chameleon.inference.chameleon import ChameleonInferenceModel, Options, TokenManager
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class Request:
|
24 |
+
room: str
|
25 |
+
key: str
|
26 |
+
options: dict[str, int | float | bool]
|
27 |
+
prompt_ui: list[dict]
|
28 |
+
|
29 |
+
|
30 |
+
def convert_options(ui_options: dict) -> Options:
|
31 |
+
txt = None
|
32 |
+
if ui_options["enable-text"]:
|
33 |
+
txt = Options.Text(
|
34 |
+
repetition_penalty=ui_options["text-rep-penalty"],
|
35 |
+
temp=ui_options["text-temp"],
|
36 |
+
top_p=ui_options["text-top-p"],
|
37 |
+
)
|
38 |
+
img = None
|
39 |
+
if ui_options["enable-image"]:
|
40 |
+
img = Options.Image(
|
41 |
+
cfg=Options.Image.CFG(
|
42 |
+
guidance_scale_image=ui_options["img-cfg-gsimage"],
|
43 |
+
guidance_scale_text=ui_options["img-cfg-gstext"],
|
44 |
+
),
|
45 |
+
temp=ui_options["img-temp"],
|
46 |
+
top_p=ui_options["img-top-p"],
|
47 |
+
)
|
48 |
+
return Options(
|
49 |
+
max_seq_len=ui_options["max-seq-len"],
|
50 |
+
max_gen_len=ui_options["max-gen-len"],
|
51 |
+
seed=ui_options["seed"],
|
52 |
+
txt=txt,
|
53 |
+
img=img,
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
class UIDecoder:
|
58 |
+
class State(Enum):
|
59 |
+
TXT = 1
|
60 |
+
IMG = 2
|
61 |
+
IMG_END = 3
|
62 |
+
|
63 |
+
def __init__(self, token_manager: TokenManager):
|
64 |
+
self.token_manager = token_manager
|
65 |
+
self.state = UIDecoder.State.TXT
|
66 |
+
self.image_builder = []
|
67 |
+
self.image_yield_every_n = 32
|
68 |
+
self.image_has_updated = False
|
69 |
+
|
70 |
+
def _image_progress(self) -> dict:
|
71 |
+
self.image_has_updated = False
|
72 |
+
png = self.token_manager.png_from_bpe_tokens(torch.cat(self.image_builder))
|
73 |
+
return {
|
74 |
+
"type": "image",
|
75 |
+
"value": "data:image/png;base64," + base64.b64encode(png).decode(),
|
76 |
+
}
|
77 |
+
|
78 |
+
def next(self, gpu_token: torch.LongTensor) -> dict | None:
|
79 |
+
if self.state == UIDecoder.State.TXT:
|
80 |
+
cpu_tok = gpu_token.item()
|
81 |
+
|
82 |
+
if cpu_tok == self.token_manager.vocab.begin_image:
|
83 |
+
self.state = UIDecoder.State.IMG
|
84 |
+
return {"type": "image_start"}
|
85 |
+
|
86 |
+
return {
|
87 |
+
"type": "text",
|
88 |
+
"value": self.token_manager.tokenizer.decode([cpu_tok]),
|
89 |
+
}
|
90 |
+
|
91 |
+
elif self.state == UIDecoder.State.IMG:
|
92 |
+
self.image_builder.append(gpu_token)
|
93 |
+
self.image_has_updated = True
|
94 |
+
if len(self.image_builder) == 1024:
|
95 |
+
self.state = UIDecoder.State.IMG_END
|
96 |
+
if len(self.image_builder) % self.image_yield_every_n == 0:
|
97 |
+
return self._image_progress()
|
98 |
+
|
99 |
+
elif self.state == UIDecoder.State.IMG_END:
|
100 |
+
# assert gpu_token == end_image
|
101 |
+
self.state = UIDecoder.State.TXT
|
102 |
+
progress = self._image_progress() if self.image_has_updated else None
|
103 |
+
self.image_builder = []
|
104 |
+
return progress
|
105 |
+
|
106 |
+
|
107 |
+
@dataclass
|
108 |
+
class State:
|
109 |
+
room_keys: dict[str, set[str]]
|
110 |
+
pending_requests: list[Request]
|
111 |
+
cond: threading.Condition
|
112 |
+
|
113 |
+
def __enter__(self, *args, **kwargs):
|
114 |
+
self.cond.__enter__(*args, **kwargs)
|
115 |
+
return self
|
116 |
+
|
117 |
+
def __exit__(self, *args, **kwargs):
|
118 |
+
self.cond.__exit__(*args, **kwargs)
|
119 |
+
return self
|
120 |
+
|
121 |
+
|
122 |
+
GlobalState = State(room_keys={}, pending_requests=[], cond=threading.Condition())
|
123 |
+
|
124 |
+
app = Flask(__name__)
|
125 |
+
socketio = SocketIO(app, max_http_buffer_size=16 * 1024 * 1024)
|
126 |
+
|
127 |
+
|
128 |
+
@app.route("/")
|
129 |
+
def index():
|
130 |
+
with open(Path(__file__).parent / "miniviewer.html") as f:
|
131 |
+
return f.read()
|
132 |
+
|
133 |
+
|
134 |
+
@socketio.on("disconnect")
|
135 |
+
def handle_disconnect():
|
136 |
+
with GlobalState as state:
|
137 |
+
try:
|
138 |
+
del state.room_keys[request.sid]
|
139 |
+
except KeyError:
|
140 |
+
pass
|
141 |
+
|
142 |
+
|
143 |
+
@socketio.on("cancel")
|
144 |
+
def handle_cancel(key):
|
145 |
+
with GlobalState as state:
|
146 |
+
try:
|
147 |
+
state.room_keys[request.sid].remove(key)
|
148 |
+
except KeyError:
|
149 |
+
pass
|
150 |
+
|
151 |
+
|
152 |
+
@socketio.on("generate")
|
153 |
+
def handle_generate(key, options, prompt_ui):
|
154 |
+
with GlobalState as state:
|
155 |
+
if request.sid not in state.room_keys:
|
156 |
+
state.room_keys[request.sid] = set()
|
157 |
+
state.room_keys[request.sid].add(key)
|
158 |
+
state.pending_requests.append(Request(request.sid, key, options, prompt_ui))
|
159 |
+
state.cond.notify_all()
|
160 |
+
|
161 |
+
|
162 |
+
def generation_thread(model: ChameleonInferenceModel):
|
163 |
+
while True:
|
164 |
+
with GlobalState as state:
|
165 |
+
state.cond.wait_for(lambda: state.pending_requests)
|
166 |
+
req = state.pending_requests.pop(0)
|
167 |
+
|
168 |
+
start = time.time()
|
169 |
+
ui_decoder = UIDecoder(model.token_manager)
|
170 |
+
options = convert_options(req.options)
|
171 |
+
|
172 |
+
if not options.txt:
|
173 |
+
progress = ui_decoder.next(
|
174 |
+
torch.tensor([model.token_manager.vocab.begin_image])
|
175 |
+
)
|
176 |
+
socketio.emit(
|
177 |
+
"progress",
|
178 |
+
{"key": req.key, **progress},
|
179 |
+
room=req.room,
|
180 |
+
)
|
181 |
+
|
182 |
+
for token in model.stream(
|
183 |
+
prompt_ui=req.prompt_ui,
|
184 |
+
options=options,
|
185 |
+
):
|
186 |
+
with GlobalState as state:
|
187 |
+
if req.key not in state.room_keys.get(req.room, {}):
|
188 |
+
break
|
189 |
+
|
190 |
+
if progress := ui_decoder.next(token.id):
|
191 |
+
socketio.emit(
|
192 |
+
"progress",
|
193 |
+
{"key": req.key, **progress},
|
194 |
+
room=req.room,
|
195 |
+
)
|
196 |
+
|
197 |
+
timing = time.time() - start
|
198 |
+
socketio.emit(
|
199 |
+
"progress",
|
200 |
+
{"key": req.key, "type": "done", "value": timing},
|
201 |
+
room=req.room,
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
def queue_position_thread():
|
206 |
+
local_pending_requests = []
|
207 |
+
while True:
|
208 |
+
with GlobalState as state:
|
209 |
+
state.cond.wait_for(
|
210 |
+
lambda: local_pending_requests != state.pending_requests
|
211 |
+
)
|
212 |
+
local_pending_requests = state.pending_requests[:]
|
213 |
+
|
214 |
+
for i, req in enumerate(local_pending_requests):
|
215 |
+
progress = {
|
216 |
+
"type": "queue",
|
217 |
+
"key": req.key,
|
218 |
+
"value": i + 1,
|
219 |
+
}
|
220 |
+
socketio.emit("progress", progress, room=req.room)
|
221 |
+
|
222 |
+
|
223 |
+
@click.command()
|
224 |
+
@click.option("--data-path", type=click.Path(), default="./data")
|
225 |
+
@click.option(
|
226 |
+
"--model-size", type=click.Choice(["7b", "30b"], case_sensitive=False), default="7b"
|
227 |
+
)
|
228 |
+
def main(data_path, model_size):
|
229 |
+
data_path = Path(data_path)
|
230 |
+
|
231 |
+
model_path = str(data_path / "models" / model_size)
|
232 |
+
tokenizer_path = str(data_path / "tokenizer/text_tokenizer.json")
|
233 |
+
vqgan_cfg_path = str(data_path / "tokenizer/vqgan.yaml")
|
234 |
+
vqgan_ckpt_path = str(data_path / "tokenizer/vqgan.ckpt")
|
235 |
+
|
236 |
+
if not os.path.exists(model_path):
|
237 |
+
raise ValueError(
|
238 |
+
"Model not found. Did you run python -m chameleon.download_data {PRESIGNED_URL}"
|
239 |
+
)
|
240 |
+
|
241 |
+
cm3v2_inference_model = ChameleonInferenceModel(
|
242 |
+
model_path, tokenizer_path, vqgan_cfg_path, vqgan_ckpt_path
|
243 |
+
)
|
244 |
+
threading.Thread(
|
245 |
+
target=generation_thread,
|
246 |
+
args=(cm3v2_inference_model,),
|
247 |
+
daemon=True,
|
248 |
+
).start()
|
249 |
+
threading.Thread(target=queue_position_thread, daemon=True).start()
|
250 |
+
socketio.run(app, debug=False)
|
251 |
+
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
main()
|
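
generation_thread above just streams raw model tokens; UIDecoder is what turns that flat stream into interleaved text and image events for the browser. A stripped-down, model-free sketch of the same state machine, using hypothetical sentinel ids and an explicit end marker instead of the real begin_image token and fixed 1024-token image blocks:

from enum import Enum

BEGIN_IMAGE, END_IMAGE = -1, -2  # hypothetical sentinel token ids


class State(Enum):
    TXT = 1
    IMG = 2


def split_stream(tokens):
    """Yield ('text', token) and ('image', [tokens]) events from a flat stream."""
    state, image = State.TXT, []
    for tok in tokens:
        if state is State.TXT:
            if tok == BEGIN_IMAGE:
                state = State.IMG
            else:
                yield ("text", tok)
        else:  # collecting image tokens until the closing sentinel
            if tok == END_IMAGE:
                yield ("image", image)
                state, image = State.TXT, []
            else:
                image.append(tok)


print(list(split_stream([5, 6, BEGIN_IMAGE, 1, 2, 3, END_IMAGE, 7])))
# [('text', 5), ('text', 6), ('image', [1, 2, 3]), ('text', 7)]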
chameleon/viewer/backend/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
chameleon/viewer/backend/data_types.py
ADDED
@@ -0,0 +1,90 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Literal
|
8 |
+
|
9 |
+
from pydantic import BaseModel, Extra, Field
|
10 |
+
|
11 |
+
from chameleon.viewer.backend.models.abstract_model import (
|
12 |
+
DEFAULT_MULTIMODAL_CFG_IMAGE,
|
13 |
+
DEFAULT_MULTIMODAL_CFG_TEXT,
|
14 |
+
)
|
15 |
+
|
16 |
+
|
17 |
+
class WSMessageType(str, Enum):
|
18 |
+
GENERATE_IMAGE = "GENERATE_IMAGE"
|
19 |
+
GENERATE_TEXT = "GENERATE_TEXT"
|
20 |
+
GENERATE_MULTIMODAL = "GENERATE_MULTIMODAL"
|
21 |
+
PARTIAL_OUTPUT = "PARTIAL_OUTPUT"
|
22 |
+
FULL_OUTPUT = "FULL_OUTPUT"
|
23 |
+
COMPLETE = "COMPLETE"
|
24 |
+
ERROR = "ERROR"
|
25 |
+
QUEUE_STATUS = "QUEUE_STATUS"
|
26 |
+
|
27 |
+
|
28 |
+
class ContentType(str, Enum):
|
29 |
+
TEXT = "TEXT"
|
30 |
+
IMAGE = "IMAGE"
|
31 |
+
|
32 |
+
|
33 |
+
class Content(BaseModel):
|
34 |
+
content_type: ContentType
|
35 |
+
content: str
|
36 |
+
|
37 |
+
class Config:
|
38 |
+
extra = Extra.forbid
|
39 |
+
|
40 |
+
|
41 |
+
class NoOptionsForPartial(BaseModel):
|
42 |
+
message_type: Literal[WSMessageType.PARTIAL_OUTPUT] = WSMessageType.PARTIAL_OUTPUT
|
43 |
+
|
44 |
+
|
45 |
+
class NoOptionsForFull(BaseModel):
|
46 |
+
message_type: Literal[WSMessageType.FULL_OUTPUT] = WSMessageType.FULL_OUTPUT
|
47 |
+
|
48 |
+
|
49 |
+
class NoOptionsForComplete(BaseModel):
|
50 |
+
message_type: Literal[WSMessageType.COMPLETE] = WSMessageType.COMPLETE
|
51 |
+
|
52 |
+
|
53 |
+
class NoOptionsForError(BaseModel):
|
54 |
+
message_type: Literal[WSMessageType.ERROR] = WSMessageType.ERROR
|
55 |
+
|
56 |
+
|
57 |
+
class NoOptionsForQueueStatus(BaseModel):
|
58 |
+
message_type: Literal[WSMessageType.QUEUE_STATUS] = WSMessageType.QUEUE_STATUS
|
59 |
+
|
60 |
+
|
61 |
+
class MultimodalGeneratorOptions(BaseModel):
|
62 |
+
message_type: Literal[
|
63 |
+
WSMessageType.GENERATE_MULTIMODAL
|
64 |
+
] = WSMessageType.GENERATE_MULTIMODAL
|
65 |
+
temp: float = 0.7
|
66 |
+
top_p: float = 0.9
|
67 |
+
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE
|
68 |
+
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT
|
69 |
+
yield_every_n: int = 32
|
70 |
+
max_gen_tokens: int = 4096
|
71 |
+
repetition_penalty: float = 1.2
|
72 |
+
suffix_tokens: list[str] | None = None
|
73 |
+
seed: int | None = None
|
74 |
+
|
75 |
+
class Config:
|
76 |
+
extra = Extra.forbid
|
77 |
+
|
78 |
+
|
79 |
+
class WSMultimodalMessage(BaseModel):
|
80 |
+
message_type: WSMessageType
|
81 |
+
content: list[Content]
|
82 |
+
options: (
|
83 |
+
MultimodalGeneratorOptions
|
84 |
+
| NoOptionsForPartial
|
85 |
+
| NoOptionsForFull
|
86 |
+
| NoOptionsForError
|
87 |
+
| NoOptionsForComplete
|
88 |
+
| NoOptionsForQueueStatus
|
89 |
+
) = Field(..., discriminator="message_type")
|
90 |
+
debug_info: dict[str, str] = {}
|
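
As a quick sanity check of the discriminated union, the message below only validates because the options variant's own message_type agrees with the GENERATE_MULTIMODAL discriminator (pydantic v1-style models, judging by the Extra/Field usage):

from chameleon.viewer.backend.data_types import (
    Content,
    ContentType,
    MultimodalGeneratorOptions,
    WSMessageType,
    WSMultimodalMessage,
)

msg = WSMultimodalMessage(
    message_type=WSMessageType.GENERATE_MULTIMODAL,
    content=[Content(content_type=ContentType.TEXT, content="a red bicycle")],
    options=MultimodalGeneratorOptions(temp=0.7, top_p=0.9, seed=0),
)
print(msg.json(indent=2))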
chameleon/viewer/backend/model_viewer.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import hydra
|
7 |
+
import torch
|
8 |
+
from omegaconf import DictConfig
|
9 |
+
|
10 |
+
from chameleon.inference import loader
|
11 |
+
from chameleon.viewer.backend.models.chameleon_distributed import (
|
12 |
+
ChameleonDistributedGenerator,
|
13 |
+
)
|
14 |
+
from chameleon.viewer.backend.models.chameleon_local import ChameleonLocalGenerator
|
15 |
+
from chameleon.viewer.backend.models.service import serve
|
16 |
+
from chameleon.viewer.backend.utils import configure_rich_logging, get_logger
|
17 |
+
|
18 |
+
logger = get_logger(__name__)
|
19 |
+
|
20 |
+
VERSION = "2.0"
|
21 |
+
SEED = 42
|
22 |
+
|
23 |
+
|
24 |
+
def create_chameleon_generator(cfg: DictConfig):
|
25 |
+
world_size = loader.detect_shard_count(cfg.model_path)
|
26 |
+
if world_size > 1:
|
27 |
+
torch.multiprocessing.set_start_method("spawn")
|
28 |
+
generator = ChameleonDistributedGenerator(
|
29 |
+
model_path=cfg.model_path,
|
30 |
+
tokenizer_path=cfg.tokenizer_path,
|
31 |
+
vqgan_config_path=cfg.vqgan_config_path,
|
32 |
+
vqgan_ckpt_path=cfg.vqgan_ckpt_path,
|
33 |
+
additional_eos_tokens=cfg.additional_eos_tokens,
|
34 |
+
world_size=world_size,
|
35 |
+
master_address=cfg.distributed.master_address,
|
36 |
+
master_port=cfg.distributed.master_port,
|
37 |
+
redis_port=cfg.redis_port,
|
38 |
+
)
|
39 |
+
else:
|
40 |
+
generator = ChameleonLocalGenerator(
|
41 |
+
model_path=cfg.model_path,
|
42 |
+
tokenizer_path=cfg.tokenizer_path,
|
43 |
+
vqgan_config_path=cfg.vqgan_config_path,
|
44 |
+
vqgan_ckpt_path=cfg.vqgan_ckpt_path,
|
45 |
+
additional_eos_tokens=cfg.additional_eos_tokens,
|
46 |
+
)
|
47 |
+
return generator
|
48 |
+
|
49 |
+
|
50 |
+
@hydra.main("../../../config", config_name="model_viewer", version_base="1.3.2")
|
51 |
+
def main(cfg: DictConfig) -> None:
|
52 |
+
configure_rich_logging()
|
53 |
+
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
54 |
+
logger.info("Starting viewer server with hydra cfg: %s", cfg)
|
55 |
+
|
56 |
+
serve(
|
57 |
+
create_chameleon_generator(cfg),
|
58 |
+
cfg.host,
|
59 |
+
cfg.port,
|
60 |
+
debug=cfg.debug,
|
61 |
+
redis_port=cfg.redis_port,
|
62 |
+
)
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
main()
|
chameleon/viewer/backend/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
chameleon/viewer/backend/models/abstract_model.py
ADDED
@@ -0,0 +1,67 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import abc
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from typing import Generator
|
9 |
+
|
10 |
+
import PIL.Image
|
11 |
+
|
12 |
+
# images, joined retrieval queries, retrieval images
|
13 |
+
MixedTokenType = str | PIL.Image.Image
|
14 |
+
MixedSequenceType = list[MixedTokenType]
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class StreamingImage:
|
19 |
+
image: PIL.Image.Image
|
20 |
+
final: bool
|
21 |
+
|
22 |
+
|
23 |
+
DEFAULT_MULTIMODAL_CFG_IMAGE = 1.2
|
24 |
+
DEFAULT_MULTIMODAL_CFG_TEXT = 3.0
|
25 |
+
DEFAULT_IMAGE_CFG_IMAGE = 3.0
|
26 |
+
DEFAULT_IMAGE_CFG_TEXT = 3.0
|
27 |
+
|
28 |
+
|
29 |
+
class AbstractMultimodalGenerator(abc.ABC):
|
30 |
+
@abc.abstractmethod
|
31 |
+
def generate_text_streaming(
|
32 |
+
self,
|
33 |
+
prompts: list[MixedSequenceType],
|
34 |
+
temp: float = 1.0,
|
35 |
+
top_p: float = 0.8,
|
36 |
+
seed: int | None = None,
|
37 |
+
) -> Generator[list[str], None, None]:
|
38 |
+
pass
|
39 |
+
|
40 |
+
@abc.abstractmethod
|
41 |
+
def generate_image_streaming(
|
42 |
+
self,
|
43 |
+
prompt: MixedSequenceType,
|
44 |
+
temp: float = 1.0,
|
45 |
+
top_p: float = 0.8,
|
46 |
+
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
|
47 |
+
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
|
48 |
+
yield_every_n: int = 32,
|
49 |
+
seed: int | None = None,
|
50 |
+
) -> Generator[PIL.Image.Image, None, None]:
|
51 |
+
pass
|
52 |
+
|
53 |
+
@abc.abstractmethod
|
54 |
+
def generate_multimodal_streaming(
|
55 |
+
self,
|
56 |
+
prompt: MixedSequenceType,
|
57 |
+
temp: float = 1.0,
|
58 |
+
top_p: float = 0.8,
|
59 |
+
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
|
60 |
+
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
|
61 |
+
yield_every_n: int = 32,
|
62 |
+
max_gen_tokens: int = 4096,
|
63 |
+
repetition_penalty: float = 1.2,
|
64 |
+
suffix_tokens: list[str] | None = None,
|
65 |
+
seed: int | None = None,
|
66 |
+
) -> Generator[MixedSequenceType, None, None]:
|
67 |
+
pass
|
chameleon/viewer/backend/models/chameleon_distributed.py
ADDED
@@ -0,0 +1,827 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import asyncio
|
7 |
+
import json
|
8 |
+
import multiprocessing
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import sys
|
12 |
+
import threading
|
13 |
+
import time
|
14 |
+
import traceback
|
15 |
+
from functools import partial
|
16 |
+
from typing import Any, Generator, TypeVar
|
17 |
+
|
18 |
+
import redis
|
19 |
+
import redis.asyncio as async_redis
|
20 |
+
import torch
|
21 |
+
from tokenizers import Tokenizer
|
22 |
+
|
23 |
+
from chameleon.inference.image_tokenizer import ImageTokenizer
|
24 |
+
from chameleon.inference.loader import load_model
|
25 |
+
from chameleon.inference.vocab import VocabInfo
|
26 |
+
from chameleon.viewer.backend.data_types import WSMessageType
|
27 |
+
from chameleon.viewer.backend.models.abstract_model import (
|
28 |
+
DEFAULT_IMAGE_CFG_IMAGE,
|
29 |
+
DEFAULT_IMAGE_CFG_TEXT,
|
30 |
+
DEFAULT_MULTIMODAL_CFG_IMAGE,
|
31 |
+
DEFAULT_MULTIMODAL_CFG_TEXT,
|
32 |
+
AbstractMultimodalGenerator,
|
33 |
+
MixedSequenceType,
|
34 |
+
StreamingImage,
|
35 |
+
)
|
36 |
+
from chameleon.viewer.backend.models.chameleon_local import (
|
37 |
+
ChameleonForwardMixin,
|
38 |
+
ChameleonTokenizationMixin,
|
39 |
+
)
|
40 |
+
from chameleon.viewer.backend.utils import get_logger
|
41 |
+
|
42 |
+
logger = get_logger(__name__)
|
43 |
+
|
44 |
+
START = "START"
|
45 |
+
|
46 |
+
T = TypeVar("T")
|
47 |
+
|
48 |
+
|
49 |
+
def find_any(queue_by_id: dict[str, list]) -> str | None:
|
50 |
+
for candidate_queue_id, candidate_queue in queue_by_id.items():
|
51 |
+
if len(candidate_queue) > 0:
|
52 |
+
return candidate_queue_id
|
53 |
+
return None
|
54 |
+
|
55 |
+
|
56 |
+
class RedisQueue:
|
57 |
+
def __init__(self, redis_client: redis.Redis, name: str, interval: float = 0.1):
|
58 |
+
self.redis_client = redis_client
|
59 |
+
self.name = name
|
60 |
+
self.interval = interval
|
61 |
+
self.lock = redis.lock.Lock(redis_client, f"lock_for_{name}")
|
62 |
+
|
63 |
+
def reset(self):
|
64 |
+
self.redis_client.set(self.name, json.dumps({}))
|
65 |
+
try:
|
66 |
+
self.lock.release()
|
67 |
+
except redis.lock.LockError:
|
68 |
+
pass
|
69 |
+
|
70 |
+
def size(self) -> int:
|
71 |
+
maybe_queue_by_id = self.redis_client.get(self.name)
|
72 |
+
if maybe_queue_by_id is None:
|
73 |
+
return 0
|
74 |
+
else:
|
75 |
+
return len(json.loads(maybe_queue_by_id))
|
76 |
+
|
77 |
+
def clear(self, queue_id: str):
|
78 |
+
with self.lock:
|
79 |
+
maybe_queue_by_id = self.redis_client.get(self.name)
|
80 |
+
if maybe_queue_by_id is None:
|
81 |
+
queue_by_id: dict[str, list] = {}
|
82 |
+
else:
|
83 |
+
queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
|
84 |
+
queue_by_id[queue_id] = []
|
85 |
+
self.redis_client.set(self.name, json.dumps(queue_by_id))
|
86 |
+
|
87 |
+
def put(self, queue_id: str, value: T):
|
88 |
+
logger.debug(
|
89 |
+
"Thread %s: Starting PUT(%s) for %s",
|
90 |
+
threading.get_ident(),
|
91 |
+
self.name,
|
92 |
+
queue_id,
|
93 |
+
)
|
94 |
+
with self.lock:
|
95 |
+
maybe_queue_by_id = self.redis_client.get(self.name)
|
96 |
+
if maybe_queue_by_id is None:
|
97 |
+
queue_by_id: dict[str, list[T]] = {}
|
98 |
+
else:
|
99 |
+
queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
|
100 |
+
|
101 |
+
if queue_id not in queue_by_id:
|
102 |
+
queue_by_id[queue_id] = []
|
103 |
+
queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
|
104 |
+
self.redis_client.set(self.name, json.dumps(queue_by_id))
|
105 |
+
|
106 |
+
logger.debug(
|
107 |
+
"Thread %s: Finished PUT(%s) for %s",
|
108 |
+
threading.get_ident(),
|
109 |
+
self.name,
|
110 |
+
queue_id,
|
111 |
+
)
|
112 |
+
|
113 |
+
def get(self, queue_id: str | None) -> tuple[str, T]:
|
114 |
+
"""
|
115 |
+
Get the next value in the queue.
|
116 |
+
|
117 |
+
if queue_id is None, will get a value from any queue
|
118 |
+
|
119 |
+
if queue_id is not None, will wait to get a value from that specific queue
|
120 |
+
"""
|
121 |
+
logger.debug(
|
122 |
+
"Thread %s: Starting GET(%s) for %s",
|
123 |
+
threading.get_ident(),
|
124 |
+
self.name,
|
125 |
+
queue_id,
|
126 |
+
)
|
127 |
+
while True:
|
128 |
+
with self.lock:
|
129 |
+
# Initialization hasn't happened, so wait for it to happen
|
130 |
+
maybe_queue_by_id = self.redis_client.get(self.name)
|
131 |
+
if maybe_queue_by_id is None:
|
132 |
+
continue
|
133 |
+
queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
|
134 |
+
if queue_id is None:
|
135 |
+
queue_id = find_any(queue_by_id)
|
136 |
+
|
137 |
+
# Ensure a queue_id was found or that it already existed
|
138 |
+
if queue_id is not None and queue_id in queue_by_id:
|
139 |
+
queue = queue_by_id[queue_id]
|
140 |
+
if len(queue) == 0:
|
141 |
+
continue
|
142 |
+
value = queue.pop(-1)
|
143 |
+
# queue is mutated and queue_by_id references it, so this works
|
144 |
+
self.redis_client.set(self.name, json.dumps(queue_by_id))
|
145 |
+
logger.debug(
|
146 |
+
"Thread %s: Finished GET(%s) for %s",
|
147 |
+
threading.get_ident(),
|
148 |
+
self.name,
|
149 |
+
queue_id,
|
150 |
+
)
|
151 |
+
return queue_id, value
|
152 |
+
time.sleep(self.interval)
|
153 |
+
|
154 |
+
|
155 |
+
class AsyncRedisQueue:
|
156 |
+
def __init__(
|
157 |
+
self, redis_client: async_redis.Redis, name: str, interval: float = 0.1
|
158 |
+
) -> None:
|
159 |
+
self.redis_client = redis_client
|
160 |
+
self.name = name
|
161 |
+
self.interval = interval
|
162 |
+
self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")
|
163 |
+
|
164 |
+
async def reset(self):
|
165 |
+
await self.redis_client.set(self.name, json.dumps({}))
|
166 |
+
try:
|
167 |
+
await self.lock.release()
|
168 |
+
except async_redis.lock.LockError:
|
169 |
+
pass
|
170 |
+
|
171 |
+
async def size(self) -> int:
|
172 |
+
maybe_queue_by_id = await self.redis_client.get(self.name)
|
173 |
+
if maybe_queue_by_id is None:
|
174 |
+
return 0
|
175 |
+
else:
|
176 |
+
return len(json.loads(maybe_queue_by_id))
|
177 |
+
|
178 |
+
async def clear(self, queue_id: str):
|
179 |
+
logger.debug(
|
180 |
+
"ASYNC Thread %s: Starting CLEAR(%s) for %s",
|
181 |
+
threading.get_ident(),
|
182 |
+
self.name,
|
183 |
+
queue_id,
|
184 |
+
)
|
185 |
+
async with self.lock:
|
186 |
+
maybe_queue_by_id = await self.redis_client.get(self.name)
|
187 |
+
if maybe_queue_by_id is None:
|
188 |
+
queue_by_id: dict[str, list] = {}
|
189 |
+
else:
|
190 |
+
queue_by_id: dict[str, list] = json.loads(maybe_queue_by_id)
|
191 |
+
queue_by_id[queue_id] = []
|
192 |
+
await self.redis_client.set(self.name, json.dumps(queue_by_id))
|
193 |
+
|
194 |
+
logger.debug(
|
195 |
+
"ASYNC Thread %s: Finished CLEAR(%s) for %s",
|
196 |
+
threading.get_ident(),
|
197 |
+
self.name,
|
198 |
+
queue_id,
|
199 |
+
)
|
200 |
+
|
201 |
+
async def put(self, queue_id: str, value: T):
|
202 |
+
logger.debug(
|
203 |
+
"ASYNC Thread %s: Starting PUT(%s) for %s",
|
204 |
+
threading.get_ident(),
|
205 |
+
self.name,
|
206 |
+
queue_id,
|
207 |
+
)
|
208 |
+
|
209 |
+
async with self.lock:
|
210 |
+
maybe_queue_by_id = await self.redis_client.get(self.name)
|
211 |
+
if maybe_queue_by_id is None:
|
212 |
+
queue_by_id: dict[str, list[T]] = {}
|
213 |
+
else:
|
214 |
+
queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
|
215 |
+
|
216 |
+
if queue_id not in queue_by_id:
|
217 |
+
queue_by_id[queue_id] = []
|
218 |
+
queue_by_id[queue_id] = [value] + queue_by_id[queue_id]
|
219 |
+
await self.redis_client.set(self.name, json.dumps(queue_by_id))
|
220 |
+
|
221 |
+
logger.debug(
|
222 |
+
"ASYNC Thread %s: Finished PUT(%s) for %s",
|
223 |
+
threading.get_ident(),
|
224 |
+
self.name,
|
225 |
+
queue_id,
|
226 |
+
)
|
227 |
+
|
228 |
+
async def get(self, queue_id: str | None):
|
229 |
+
"""
|
230 |
+
Get the next value in the queue.
|
231 |
+
|
232 |
+
if queue_id is None, will get a value from any queue
|
233 |
+
|
234 |
+
if queue_id is not None, will wait to get a value from that specific queue
|
235 |
+
"""
|
236 |
+
logger.debug(
|
237 |
+
"ASYNC Thread %s: Starting GET(%s) for %s",
|
238 |
+
threading.get_ident(),
|
239 |
+
self.name,
|
240 |
+
queue_id,
|
241 |
+
)
|
242 |
+
while True:
|
243 |
+
async with self.lock:
|
244 |
+
maybe_queue_by_id = await self.redis_client.get(self.name)
|
245 |
+
if maybe_queue_by_id is None:
|
246 |
+
continue
|
247 |
+
queue_by_id: dict[str, list[T]] = json.loads(maybe_queue_by_id)
|
248 |
+
if queue_id is None:
|
249 |
+
queue_id = find_any(queue_by_id)
|
250 |
+
|
251 |
+
# Ensure a queue_id was found or that it already existed
|
252 |
+
if queue_id is not None and queue_id in queue_by_id:
|
253 |
+
queue: list = queue_by_id[queue_id]
|
254 |
+
if len(queue) == 0:
|
255 |
+
continue
|
256 |
+
value = queue.pop(-1)
|
257 |
+
# queue is mutated and queue_by_id references it, so this works
|
258 |
+
await self.redis_client.set(self.name, json.dumps(queue_by_id))
|
259 |
+
logger.debug(
|
260 |
+
"ASYNC Thread %s: Finished GET(%s) for %s",
|
261 |
+
threading.get_ident(),
|
262 |
+
self.name,
|
263 |
+
queue_id,
|
264 |
+
)
|
265 |
+
return queue_id, value
|
266 |
+
await asyncio.sleep(self.interval)
|
267 |
+
|
268 |
+
|
269 |
+
class AsyncRedisCounter:
|
270 |
+
def __init__(self, redis_client: async_redis.Redis, name: str) -> None:
|
271 |
+
self.redis_client = redis_client
|
272 |
+
self.name = name
|
273 |
+
self.lock = async_redis.lock.Lock(redis_client, f"lock_for_{name}")
|
274 |
+
|
275 |
+
async def reset(self) -> int:
|
276 |
+
try:
|
277 |
+
await self.lock.release()
|
278 |
+
except async_redis.lock.LockError:
|
279 |
+
pass
|
280 |
+
await self.redis_client.set(self.name, 0)
|
281 |
+
|
282 |
+
async def add(self, n: int) -> int:
|
283 |
+
async with self.lock:
|
284 |
+
current_val = await self.redis_client.get(self.name)
|
285 |
+
if current_val is None:
|
286 |
+
current_val = 0
|
287 |
+
else:
|
288 |
+
current_val = int(current_val)
|
289 |
+
new_val = current_val + n
|
290 |
+
await self.redis_client.set(self.name, new_val)
|
291 |
+
return new_val
|
292 |
+
|
293 |
+
async def sub(self, n: int) -> int:
|
294 |
+
async with self.lock:
|
295 |
+
current_val = await self.redis_client.get(self.name)
|
296 |
+
if current_val is None:
|
297 |
+
raise ValueError("Invalid sub counter when counter does not exist")
|
298 |
+
current_val = int(current_val)
|
299 |
+
if current_val <= 0:
|
300 |
+
raise ValueError("Invalid sub counter to counter that is already zero")
|
301 |
+
new_val = current_val - n
|
302 |
+
await self.redis_client.set(self.name, new_val)
|
303 |
+
return new_val
|
304 |
+
|
305 |
+
async def count(self) -> int:
|
306 |
+
value = await self.redis_client.get(self.name)
|
307 |
+
if value is None:
|
308 |
+
return 0
|
309 |
+
else:
|
310 |
+
return int(value)
|
311 |
+
|
312 |
+
|
313 |
+
def distributed_workers(
|
314 |
+
model_args: dict,
|
315 |
+
master_address: str,
|
316 |
+
master_port: str,
|
317 |
+
world_size: int,
|
318 |
+
rank: int,
|
319 |
+
redis_port: int,
|
320 |
+
worker_queues: dict[int, multiprocessing.Queue],
|
321 |
+
) -> None:
|
322 |
+
redis_client = redis.Redis("redis", redis_port)
|
323 |
+
request_queue = RedisQueue(redis_client, "request")
|
324 |
+
response_queue = RedisQueue(redis_client, "response")
|
325 |
+
|
326 |
+
os.environ["MASTER_ADDR"] = master_address
|
327 |
+
os.environ["MASTER_PORT"] = str(master_port)
|
328 |
+
|
329 |
+
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
330 |
+
|
331 |
+
torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
|
332 |
+
assert rank == torch.distributed.get_rank()
|
333 |
+
|
334 |
+
torch.cuda.set_device(rank)
|
335 |
+
|
336 |
+
is_coord = rank == 0
|
337 |
+
|
338 |
+
worker = ChameleonWorker(
|
339 |
+
rank=rank,
|
340 |
+
model_path=model_args["model_path"],
|
341 |
+
tokenizer_path=model_args["tokenizer_path"],
|
342 |
+
additional_eos_tokens=model_args["additional_eos_tokens"],
|
343 |
+
)
|
344 |
+
worker_id = id(worker)
|
345 |
+
logger.info("Rank %s, master_port=%s worker=%s", rank, master_port, worker_id)
|
346 |
+
|
347 |
+
step = 0
|
348 |
+
while True:
|
349 |
+
step += 1
|
350 |
+
redis_client.set(f"status_rank_{rank}", "Pre-coordinator sync")
|
351 |
+
if is_coord:
|
352 |
+
distributed_objs = [request_queue.get(None)]
|
353 |
+
logger.info("Objects from queue: %s", distributed_objs)
|
354 |
+
for worker_rank in range(1, world_size):
|
355 |
+
worker_message = {"message": START, "src": rank, "dst": worker_rank}
|
356 |
+
logger.info("Rank %s Sending: %s", rank, worker_message)
|
357 |
+
worker_queues[worker_rank].put(worker_message)
|
358 |
+
else:
|
359 |
+
distributed_objs = [None]
|
360 |
+
logger.info("Rank %s worker %s waiting for rank 0", rank, worker_id)
|
361 |
+
message_from_rank_0 = worker_queues[rank].get()
|
362 |
+
logger.info(
|
363 |
+
"Received message from rank 0 in rank %s: %s", rank, message_from_rank_0
|
364 |
+
)
|
365 |
+
if message_from_rank_0["message"] != START:
|
366 |
+
raise ValueError(
|
367 |
+
f"Unexpected message from rank 0: {message_from_rank_0['message']}"
|
368 |
+
)
|
369 |
+
redis_client.set(f"status_rank_{rank}", "Post-coordinator sync")
|
370 |
+
|
371 |
+
try:
|
372 |
+
logger.info(
|
373 |
+
"Broadcast Starting: Rank %s, worker %s, step %s",
|
374 |
+
rank,
|
375 |
+
worker_id,
|
376 |
+
step,
|
377 |
+
)
|
378 |
+
redis_client.set(f"status_rank_{rank}", "Pre-torch sync")
|
379 |
+
torch.distributed.broadcast_object_list(distributed_objs, src=0)
|
380 |
+
redis_client.set(f"status_rank_{rank}", "Post-torch sync")
|
381 |
+
logger.info(
|
382 |
+
"Broadcast Complete: Rank %s, worker %s, step %s",
|
383 |
+
rank,
|
384 |
+
worker_id,
|
385 |
+
step,
|
386 |
+
)
|
387 |
+
except RuntimeError as e:
|
388 |
+
logger.error(
|
389 |
+
"Rank %s, worker %s, step %s, Error detected in torch broadcast: %s",
|
390 |
+
rank,
|
391 |
+
worker_id,
|
392 |
+
step,
|
393 |
+
str(e),
|
394 |
+
)
|
395 |
+
raise
|
396 |
+
|
397 |
+
logger.info("rank %s, objs %s", rank, distributed_objs)
|
398 |
+
queue_id, data = distributed_objs[0]
|
399 |
+
mode = data.pop("mode")
|
400 |
+
request_id = data.pop("request_id")
|
401 |
+
assert queue_id == request_id
|
402 |
+
tokenized_prompt = data.pop("tokenized_prompt")
|
403 |
+
try:
|
404 |
+
match mode:
|
405 |
+
case WSMessageType.GENERATE_TEXT:
|
406 |
+
generator_fn = partial(
|
407 |
+
worker._generate_text_streaming, tokenized_prompt, **data
|
408 |
+
)
|
409 |
+
case WSMessageType.GENERATE_IMAGE:
|
410 |
+
generator_fn = partial(
|
411 |
+
worker._generate_image_streaming, tokenized_prompt, **data
|
412 |
+
)
|
413 |
+
case WSMessageType.GENERATE_MULTIMODAL:
|
414 |
+
generator_fn = partial(
|
415 |
+
worker._generate_multimodal_streaming, tokenized_prompt, **data
|
416 |
+
)
|
417 |
+
case _:
|
418 |
+
logger.error(
|
419 |
+
"Encountered unknown mode, crashing the program: %s", mode
|
420 |
+
)
|
421 |
+
response_queue.put(
|
422 |
+
queue_id, {"error": True, "final": True, "message": mode}
|
423 |
+
)
|
424 |
+
raise ValueError("Unknown mode")
|
425 |
+
logger.info("Rank: %s, Processing request: %s", rank, request_id)
|
426 |
+
i = 0
|
427 |
+
redis_client.set(f"status_rank_{rank}", "Pre-generate")
|
428 |
+
for output in generator_fn():
|
429 |
+
i += 1
|
430 |
+
if is_coord:
|
431 |
+
response = {"final": False, "output": output, "error": False}
|
432 |
+
logger.info(
|
433 |
+
"Rank: %s, Adding to response queue: %.100s",
|
434 |
+
rank,
|
435 |
+
response,
|
436 |
+
)
|
437 |
+
redis_client.set(f"status_rank_{rank}", f"Generate Pre Put {i}")
|
438 |
+
response_queue.put(queue_id, response)
|
439 |
+
redis_client.set(f"status_rank_{rank}", f"Generate Post Put {i}")
|
440 |
+
else:
|
441 |
+
redis_client.set(f"status_rank_{rank}", f"Generate {i}")
|
442 |
+
redis_client.set(f"step_on_rank_{rank}", i)
|
443 |
+
redis_client.set(f"status_rank_{rank}", "Post-generate")
|
444 |
+
if is_coord:
|
445 |
+
logger.info("Rank: %s, Adding final result to output queue", rank)
|
446 |
+
response_queue.put(queue_id, {"final": True, "error": False})
|
447 |
+
except torch.cuda.OutOfMemoryError as e:
|
448 |
+
logger.error("Encountered OOM, crashing the program: %s", e)
|
449 |
+
response_queue.put(
|
450 |
+
queue_id, {"error": True, "final": True, "message": str(e)}
|
451 |
+
)
|
452 |
+
crash_program()
|
453 |
+
except RuntimeError as e:
|
454 |
+
message = str(e)
|
455 |
+
if "CUDA" in message:
|
456 |
+
logger.error("Encountered CUDA error, crashing the program: %s", e)
|
457 |
+
response_queue.put(
|
458 |
+
queue_id, {"error": True, "final": True, "message": str(e)}
|
459 |
+
)
|
460 |
+
crash_program()
|
461 |
+
else:
|
462 |
+
logger.error(
|
463 |
+
"Encountered unexpected runtime error, crashing the program: %s %s",
|
464 |
+
e,
|
465 |
+
traceback.format_exc(),
|
466 |
+
)
|
467 |
+
response_queue.put(
|
468 |
+
queue_id, {"error": True, "final": True, "message": str(e)}
|
469 |
+
)
|
470 |
+
crash_program()
|
471 |
+
except Exception as e:
|
472 |
+
logger.error(
|
473 |
+
"Encountered unexpected exception: %s %s",
|
474 |
+
str(e),
|
475 |
+
traceback.format_exc(),
|
476 |
+
)
|
477 |
+
response_queue.put(
|
478 |
+
queue_id, {"error": True, "final": True, "message": str(e)}
|
479 |
+
)
|
480 |
+
crash_program()
|
481 |
+
|
482 |
+
|
483 |
+
class ChameleonWorker(ChameleonForwardMixin):
|
484 |
+
def __init__(
|
485 |
+
self,
|
486 |
+
*,
|
487 |
+
rank: int,
|
488 |
+
model_path: str,
|
489 |
+
tokenizer_path: str,
|
490 |
+
additional_eos_tokens: list[str] | None,
|
491 |
+
) -> None:
|
492 |
+
self.rank = rank
|
493 |
+
self.model_path = model_path
|
494 |
+
self.additional_eos_tokens = additional_eos_tokens
|
495 |
+
torch.set_default_device(f"cuda:{rank}")
|
496 |
+
self.model = load_model(model_path, rank)
|
497 |
+
self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
498 |
+
self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
|
499 |
+
logger.info(
|
500 |
+
"Rank: %s, Model loaded in worker_obj: %s",
|
501 |
+
rank,
|
502 |
+
id(self),
|
503 |
+
)
|
504 |
+
|
505 |
+
|
506 |
+
def crash_program() -> None:
|
507 |
+
logger.error(
|
508 |
+
"Crashing the program as instructed, likely due to distributed worker failures"
|
509 |
+
)
|
510 |
+
sys.exit(1)
|
511 |
+
|
512 |
+
|
513 |
+
class ChameleonDistributedGenerator(AbstractMultimodalGenerator, ChameleonTokenizationMixin):
|
514 |
+
def __init__(
|
515 |
+
self,
|
516 |
+
*,
|
517 |
+
world_size: int,
|
518 |
+
model_path: str,
|
519 |
+
master_port: int,
|
520 |
+
tokenizer_path: str,
|
521 |
+
vqgan_config_path: str,
|
522 |
+
vqgan_ckpt_path: str | None = None,
|
523 |
+
master_address: str = "0.0.0.0",
|
524 |
+
additional_eos_tokens: list[str] | None = None,
|
525 |
+
redis_port: int | None = None,
|
526 |
+
) -> None:
|
527 |
+
self.master_port = master_port
|
528 |
+
self.master_address = master_address
|
529 |
+
self.additional_eos_tokens = additional_eos_tokens
|
530 |
+
logger.info("Loading tokenizer...")
|
531 |
+
tokenizer_path = tokenizer_path
|
532 |
+
self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
533 |
+
self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
|
534 |
+
|
535 |
+
logger.info("Loading VQGAN...")
|
536 |
+
self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
|
537 |
+
self.redis_port = redis_port
|
538 |
+
self.redis_pool = async_redis.ConnectionPool.from_url(
|
539 |
+
f"redis://redis:{redis_port}"
|
540 |
+
)
|
541 |
+
self.redis_client = async_redis.Redis.from_pool(self.redis_pool)
|
542 |
+
self.request_queue = AsyncRedisQueue(self.redis_client, "request")
|
543 |
+
self.response_queue = AsyncRedisQueue(self.redis_client, "response")
|
544 |
+
self.worker_queues: dict[int, multiprocessing.Queue] = {
|
545 |
+
rank: multiprocessing.Queue() for rank in range(world_size)
|
546 |
+
}
|
547 |
+
self.procs: list[multiprocessing.Process] = []
|
548 |
+
model_args = {
|
549 |
+
"model_path": model_path,
|
550 |
+
"master_address": master_address,
|
551 |
+
"master_port": master_port,
|
552 |
+
"tokenizer_path": tokenizer_path,
|
553 |
+
"additional_eos_tokens": additional_eos_tokens,
|
554 |
+
}
|
555 |
+
logger.info("Launching paralle model with world_size=%s", world_size)
|
556 |
+
for i in range(world_size):
|
557 |
+
proc = multiprocessing.Process(
|
558 |
+
target=distributed_workers,
|
559 |
+
args=(
|
560 |
+
model_args,
|
561 |
+
master_address,
|
562 |
+
master_port,
|
563 |
+
world_size,
|
564 |
+
i,
|
565 |
+
self.redis_port,
|
566 |
+
self.worker_queues,
|
567 |
+
),
|
568 |
+
daemon=True,
|
569 |
+
)
|
570 |
+
self.procs.append(proc)
|
571 |
+
proc.start()
|
572 |
+
|
573 |
+
def check_error(self, output: dict) -> None:
|
574 |
+
if output["error"]:
|
575 |
+
import sys
|
576 |
+
print(f"check_error({output})", file=sys.stderr)
|
577 |
+
self.kill_procs()
|
578 |
+
logger.error(
|
579 |
+
"COORDINATOR: Encountered error in managed processes, exiting: %s",
|
580 |
+
output,
|
581 |
+
)
|
582 |
+
crash_program()
|
583 |
+
|
584 |
+
def __del__(self) -> None:
|
585 |
+
self.kill_procs(error=False)
|
586 |
+
|
587 |
+
def kill_procs(self, error: bool = True) -> None:
|
588 |
+
if error:
|
589 |
+
log_fn = logger.error
|
590 |
+
else:
|
591 |
+
log_fn = logger.info
|
592 |
+
log_fn("Error encountered, killing worker procs: %s", self.procs)
|
593 |
+
for p in self.procs:
|
594 |
+
try:
|
595 |
+
log_fn("Killing: %s", p)
|
596 |
+
p.kill()
|
597 |
+
except:
|
598 |
+
log_fn("Encountered issue killing process and ignoring: %s", p)
|
599 |
+
|
600 |
+
# ALLOW_ANY(get_next_output.return)
|
601 |
+
async def get_next_output(self, request_id: str) -> Any:
|
602 |
+
logger.info("Waiting for response for request_id=%s", request_id)
|
603 |
+
queue_id, output = await self.response_queue.get(request_id)
|
604 |
+
assert queue_id == request_id
|
605 |
+
return output
|
606 |
+
|
607 |
+
async def generate_text_streaming(
|
608 |
+
self,
|
609 |
+
prompt: MixedSequenceType,
|
610 |
+
max_gen_tokens: int = 256,
|
611 |
+
temp: float = 1.0,
|
612 |
+
top_p: float = 0.8,
|
613 |
+
repetition_penalty: float = 1.2,
|
614 |
+
seed: int | None = None,
|
615 |
+
debug: dict | None = None,
|
616 |
+
) -> Generator[str, None, None]:
|
617 |
+
tokenized_prompt = self.tokens_from_inputs(prompt)
|
618 |
+
request_id = f"request_{random.randint(100_000, 200_000)}"
|
619 |
+
if seed is None:
|
620 |
+
seed = random.randint(1, 2048)
|
621 |
+
if debug is not None:
|
622 |
+
debug["seed"] = seed
|
623 |
+
if len(tokenized_prompt) > (4096 - 3):
|
624 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
|
625 |
+
return
|
626 |
+
assert not isinstance(tokenized_prompt, torch.Tensor)
|
627 |
+
request = {
|
628 |
+
"mode": WSMessageType.GENERATE_TEXT.value,
|
629 |
+
"request_id": request_id,
|
630 |
+
"tokenized_prompt": tokenized_prompt,
|
631 |
+
"max_gen_tokens": max_gen_tokens,
|
632 |
+
"temp": temp,
|
633 |
+
"top_p": top_p,
|
634 |
+
"repetition_penalty": repetition_penalty,
|
635 |
+
"seed": seed,
|
636 |
+
}
|
637 |
+
logger.info(
|
638 |
+
"Sending request_id=%s: %s",
|
639 |
+
request_id,
|
640 |
+
request,
|
641 |
+
)
|
642 |
+
await asyncio.gather(
|
643 |
+
self.request_queue.clear(request_id),
|
644 |
+
self.response_queue.clear(request_id),
|
645 |
+
)
|
646 |
+
logger.info("Cleared request/response queue for %s", request_id)
|
647 |
+
await self.request_queue.put(request_id, request)
|
648 |
+
logger.info("Sent request to coordinator %s", request_id)
|
649 |
+
try:
|
650 |
+
while True:
|
651 |
+
output = await self.get_next_output(request_id)
|
652 |
+
logger.info("Received response for %s", request_id)
|
653 |
+
self.check_error(output)
|
654 |
+
if output["final"]:
|
655 |
+
break
|
656 |
+
|
657 |
+
n_outs = len(output["output"])
|
658 |
+
if n_outs != 1:
|
659 |
+
logger.error(
|
660 |
+
"Encountered unexpected number of %s arguments in: %s",
|
661 |
+
n_outs,
|
662 |
+
output["output"],
|
663 |
+
)
|
664 |
+
tokens = output["output"]
|
665 |
+
assert not isinstance(tokens, torch.Tensor)
|
666 |
+
logger.info("output info: type=%s, value=%.20s", type(tokens), tokens)
|
667 |
+
yield self.tokenizer.decode(tokens)
|
668 |
+
finally:
|
669 |
+
logger.info("Cleaning up queues in request_id=%s", request_id)
|
670 |
+
await asyncio.gather(
|
671 |
+
self.request_queue.clear(request_id),
|
672 |
+
self.response_queue.clear(request_id),
|
673 |
+
)
|
674 |
+
logger.info("Completed cleaning for request_id=%s", request_id)
|
675 |
+
|
676 |
+
async def generate_image_streaming(
|
677 |
+
self,
|
678 |
+
prompt: MixedSequenceType,
|
679 |
+
temp: float = 1.0,
|
680 |
+
top_p: float = 0.8,
|
681 |
+
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
|
682 |
+
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
|
683 |
+
yield_every_n: int = 32,
|
684 |
+
debug: dict | None = None,
|
685 |
+
seed: int | None = None,
|
686 |
+
) -> Generator[StreamingImage, None, None]:
|
687 |
+
tokenized_prompt = self.tokens_from_inputs(prompt)
|
688 |
+
tokenized_prompt.append(self.vocab.begin_image)
|
689 |
+
assert not isinstance(tokenized_prompt, torch.Tensor)
|
690 |
+
request_id = f"request_{random.randint(100_000, 200_000)}"
|
691 |
+
if seed is None:
|
692 |
+
seed = random.randint(1, 2048)
|
693 |
+
if debug is not None:
|
694 |
+
debug["seed"] = seed
|
695 |
+
if len(tokenized_prompt) > (4096 - 3 - 1024):
|
696 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
|
697 |
+
return
|
698 |
+
request = {
|
699 |
+
"mode": WSMessageType.GENERATE_IMAGE.value,
|
700 |
+
"request_id": request_id,
|
701 |
+
"tokenized_prompt": tokenized_prompt,
|
702 |
+
"cfg_image_weight": cfg_image_weight,
|
703 |
+
"cfg_text_weight": cfg_text_weight,
|
704 |
+
"yield_every_n": yield_every_n,
|
705 |
+
"temp": temp,
|
706 |
+
"top_p": top_p,
|
707 |
+
"seed": seed,
|
708 |
+
}
|
709 |
+
logger.info(
|
710 |
+
"Sending request_id=%s: %s",
|
711 |
+
request_id,
|
712 |
+
request,
|
713 |
+
)
|
714 |
+
await asyncio.gather(
|
715 |
+
self.request_queue.clear(request_id),
|
716 |
+
self.response_queue.clear(request_id),
|
717 |
+
)
|
718 |
+
logger.info("Cleared request/response queue for %s", request_id)
|
719 |
+
await self.request_queue.put(request_id, request)
|
720 |
+
logger.info("Sent request to coordinator %s", request_id)
|
721 |
+
try:
|
722 |
+
while True:
|
723 |
+
output = await self.get_next_output(request_id)
|
724 |
+
logger.info("Received response for %s", request_id)
|
725 |
+
self.check_error(output)
|
726 |
+
if output["final"]:
|
727 |
+
break
|
728 |
+
n_outs = len(output["output"])
|
729 |
+
if n_outs != 2:
|
730 |
+
logger.error(
|
731 |
+
"Encountered unexpected number of %s arguments in: %s",
|
732 |
+
n_outs,
|
733 |
+
output["output"],
|
734 |
+
)
|
735 |
+
tokens, final = output["output"]
|
736 |
+
assert not isinstance(tokens, torch.Tensor)
|
737 |
+
yield StreamingImage(
|
738 |
+
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
|
739 |
+
)
|
740 |
+
finally:
|
741 |
+
logger.info("Cleaning up queues in request_id=%s", request_id)
|
742 |
+
await asyncio.gather(
|
743 |
+
self.request_queue.clear(request_id),
|
744 |
+
self.response_queue.clear(request_id),
|
745 |
+
)
|
746 |
+
logger.info("Completed cleaning for request_id=%s", request_id)
|
747 |
+
|
748 |
+
async def generate_multimodal_streaming(
|
749 |
+
self,
|
750 |
+
prompt: MixedSequenceType,
|
751 |
+
temp: float = 1.0,
|
752 |
+
top_p: float = 0.8,
|
753 |
+
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
|
754 |
+
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
|
755 |
+
yield_every_n: int = 32,
|
756 |
+
max_gen_tokens: int = 4096,
|
757 |
+
repetition_penalty: float = 1.2,
|
758 |
+
suffix_tokens: list[str] | None = None,
|
759 |
+
seed: int | None = None,
|
760 |
+
debug: dict | None = None,
|
761 |
+
) -> Generator[MixedSequenceType, None, None]:
|
762 |
+
tokenized_prompt = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
|
763 |
+
assert not isinstance(tokenized_prompt, torch.Tensor)
|
764 |
+
request_id = f"request_{random.randint(100_000, 200_000)}"
|
765 |
+
if seed is None:
|
766 |
+
seed = random.randint(1, 2048)
|
767 |
+
if debug is not None:
|
768 |
+
debug["seed"] = seed
|
769 |
+
if len(tokenized_prompt) > (4096 - 3):
|
770 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
|
771 |
+
return
|
772 |
+
|
773 |
+
request = {
|
774 |
+
"mode": WSMessageType.GENERATE_MULTIMODAL.value,
|
775 |
+
"request_id": request_id,
|
776 |
+
"tokenized_prompt": tokenized_prompt,
|
777 |
+
"cfg_image_weight": cfg_image_weight,
|
778 |
+
"cfg_text_weight": cfg_text_weight,
|
779 |
+
"repetition_penalty": repetition_penalty,
|
780 |
+
"yield_every_n": yield_every_n,
|
781 |
+
"max_gen_tokens": max_gen_tokens,
|
782 |
+
"temp": temp,
|
783 |
+
"top_p": top_p,
|
784 |
+
"seed": seed,
|
785 |
+
}
|
786 |
+
logger.info(
|
787 |
+
"Sending request_id=%s: %s",
|
788 |
+
request_id,
|
789 |
+
request,
|
790 |
+
)
|
791 |
+
await asyncio.gather(
|
792 |
+
self.request_queue.clear(request_id),
|
793 |
+
self.response_queue.clear(request_id),
|
794 |
+
)
|
795 |
+
logger.info("Cleared request/response queue for %s", request_id)
|
796 |
+
await self.request_queue.put(request_id, request)
|
797 |
+
logger.info("Sent request to coordinator %s", request_id)
|
798 |
+
try:
|
799 |
+
while True:
|
800 |
+
output = await self.get_next_output(request_id)
|
801 |
+
logger.info("Received response for %s", request_id)
|
802 |
+
self.check_error(output)
|
803 |
+
if output["final"]:
|
804 |
+
break
|
805 |
+
n_outs = len(output["output"])
|
806 |
+
if n_outs != 3:
|
807 |
+
logger.error(
|
808 |
+
"Encountered unexpected number of %s arguments in: %s",
|
809 |
+
n_outs,
|
810 |
+
output["output"],
|
811 |
+
)
|
812 |
+
token_type, tokens, image_is_final = output["output"]
|
813 |
+
assert not isinstance(tokens, torch.Tensor)
|
814 |
+
match token_type:
|
815 |
+
case "TEXT":
|
816 |
+
yield self.tokenizer.decode(tokens)
|
817 |
+
case "IMAGE":
|
818 |
+
yield StreamingImage(
|
819 |
+
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
|
820 |
+
final=image_is_final,
|
821 |
+
)
|
822 |
+
case _:
|
823 |
+
raise ValueError("Unknown token type")
|
824 |
+
finally:
|
825 |
+
logger.info("Cleaning up queues in request_id=%s", request_id)
|
826 |
+
await self.request_queue.clear(request_id)
|
827 |
+
await self.response_queue.clear(request_id)
|
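The RedisQueue / AsyncRedisQueue pair above gives the web process and the rank-0 coordinator a simple JSON-backed request/response channel over Redis. A minimal sketch of that round trip, assuming only a Redis server reachable at localhost:6379 (the "request"/"response" queue names mirror the ones used above; the echo worker stands in for the real coordinator loop):

import threading

import redis

from chameleon.viewer.backend.models.chameleon_distributed import RedisQueue

client = redis.Redis("localhost", 6379)
request_queue = RedisQueue(client, "request")
response_queue = RedisQueue(client, "response")
request_queue.reset()
response_queue.reset()

def echo_worker() -> None:
    # Coordinator side: block until any request arrives, then answer on the same queue_id.
    queue_id, payload = request_queue.get(None)
    response_queue.put(queue_id, {"final": True, "error": False, "echo": payload})

threading.Thread(target=echo_worker, daemon=True).start()

# Client side: enqueue a JSON-serializable request, then block until its response arrives.
request_queue.put("request_1", {"mode": "GENERATE_TEXT", "tokenized_prompt": [0, 1, 2]})
print(response_queue.get("request_1"))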
chameleon/viewer/backend/models/chameleon_local.py
ADDED
@@ -0,0 +1,642 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import io
|
7 |
+
import json
|
8 |
+
from typing import Generator
|
9 |
+
|
10 |
+
import PIL.Image
|
11 |
+
import torch
|
12 |
+
import transformers
|
13 |
+
from tokenizers import Tokenizer
|
14 |
+
from transformers import (
|
15 |
+
MaxLengthCriteria,
|
16 |
+
RepetitionPenaltyLogitsProcessor,
|
17 |
+
TemperatureLogitsWarper,
|
18 |
+
TopPLogitsWarper,
|
19 |
+
)
|
20 |
+
|
21 |
+
from chameleon.inference.alignment import AlignPromptRight
|
22 |
+
from chameleon.inference.generation import ChameleonGenerator
|
23 |
+
from chameleon.inference.image_tokenizer import ImageTokenizer
|
24 |
+
from chameleon.inference.loader import load_model
|
25 |
+
from chameleon.inference.logits_processor import (
|
26 |
+
AllowOnlyTokensAfterIndexLogitsProcessor,
|
27 |
+
AllowOnlyTokensLogitsProcessor,
|
28 |
+
InBatchInstructCFGLogitsProcessor,
|
29 |
+
)
|
30 |
+
from chameleon.inference.model_adapter import ChameleonModelAdapter
|
31 |
+
from chameleon.inference.stopping_criteria import StopOnEOS, StopOnEOSAfterBatchIndex
|
32 |
+
from chameleon.inference.token_selector import (
|
33 |
+
MultinomialTokenSelector,
|
34 |
+
ReplicatedInputTokenSelector,
|
35 |
+
)
|
36 |
+
from chameleon.inference.vocab import VocabInfo, VocabTranslation
|
37 |
+
from chameleon.viewer.backend.models.abstract_model import (
|
38 |
+
DEFAULT_IMAGE_CFG_IMAGE,
|
39 |
+
DEFAULT_IMAGE_CFG_TEXT,
|
40 |
+
DEFAULT_MULTIMODAL_CFG_IMAGE,
|
41 |
+
DEFAULT_MULTIMODAL_CFG_TEXT,
|
42 |
+
AbstractMultimodalGenerator,
|
43 |
+
MixedSequenceType,
|
44 |
+
StreamingImage,
|
45 |
+
)
|
46 |
+
from chameleon.viewer.backend.utils import get_logger
|
47 |
+
|
48 |
+
logger = get_logger(__name__)
|
49 |
+
|
50 |
+
|
51 |
+
def set_seed(seed: int) -> None:
|
52 |
+
transformers.enable_full_determinism(seed, warn_only=True)
|
53 |
+
|
54 |
+
|
55 |
+
def get_rank() -> int:
|
56 |
+
if torch.distributed.is_initialized():
|
57 |
+
return torch.distributed.get_rank()
|
58 |
+
else:
|
59 |
+
return 0
|
60 |
+
|
61 |
+
|
62 |
+
class ChameleonTokenizationMixin:
|
63 |
+
def png_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> bytes:
|
64 |
+
img = self.pillow_from_bpe_tokens(bpe_tokens)
|
65 |
+
|
66 |
+
img_io = io.BytesIO()
|
67 |
+
img.save(img_io, format="PNG")
|
68 |
+
return img_io.getvalue()
|
69 |
+
|
70 |
+
def pillow_from_bpe_tokens(self, bpe_tokens: torch.Tensor) -> PIL.Image.Image:
|
71 |
+
image_tensor = VocabTranslation(self.vocab).convert_bpe2img(bpe_tokens)
|
72 |
+
if image_tensor.shape[0] < 1024:
|
73 |
+
padding = (
|
74 |
+
torch.ones([1024 - image_tensor.shape[0]], dtype=int) * image_tensor[0]
|
75 |
+
)
|
76 |
+
image_tensor = torch.cat((image_tensor, padding)).unsqueeze(0)
|
77 |
+
|
78 |
+
return self.image_tokenizer.pil_from_img_toks(image_tensor)
|
79 |
+
|
80 |
+
def tokens_from_inputs(
|
81 |
+
self,
|
82 |
+
inputs: MixedSequenceType,
|
83 |
+
suffix_tokens: list[str] | None = None,
|
84 |
+
) -> list[int]:
|
85 |
+
tokens = [self.vocab.bos_id]
|
86 |
+
for input_ in inputs:
|
87 |
+
if isinstance(input_, str):
|
88 |
+
tokens.extend(self.tokenizer.encode(input_.strip()).ids)
|
89 |
+
elif isinstance(input_, PIL.Image.Image):
|
90 |
+
tokens.append(self.vocab.begin_image)
|
91 |
+
imgtoks = self.image_tokenizer.img_tokens_from_pil(input_)
|
92 |
+
tokens.extend(VocabTranslation(self.vocab).convert_img2bp2(imgtoks))
|
93 |
+
tokens.append(self.vocab.end_image)
|
94 |
+
else:
|
95 |
+
raise ValueError(f"Unknown input type: {type(input_)}")
|
96 |
+
|
97 |
+
if suffix_tokens is not None:
|
98 |
+
for t in suffix_tokens:
|
99 |
+
tokens.extend(self.tokenizer.encode(t).ids)
|
100 |
+
sanitized_tokens = []
|
101 |
+
for t in tokens:
|
102 |
+
if isinstance(t, torch.Tensor):
|
103 |
+
sanitized_tokens.append(t.item())
|
104 |
+
else:
|
105 |
+
sanitized_tokens.append(t)
|
106 |
+
return sanitized_tokens
|
107 |
+
|
108 |
+
|
109 |
+
class GeneratorWrapper:
|
110 |
+
def __init__(self, gen):
|
111 |
+
self.gen = gen
|
112 |
+
|
113 |
+
def __iter__(self):
|
114 |
+
return self
|
115 |
+
|
116 |
+
def __next__(self):
|
117 |
+
return next(self.gen)
|
118 |
+
|
119 |
+
|
120 |
+
class Decoder:
|
121 |
+
def __init__(
|
122 |
+
self,
|
123 |
+
chameleon_generator: "ChameleonLocalGenerator",
|
124 |
+
input_ids: list[int],
|
125 |
+
):
|
126 |
+
...
|
127 |
+
|
128 |
+
def __next__(self) -> tuple[list[int], dict | None, type["Decoder"] | None]:
|
129 |
+
...
|
130 |
+
|
131 |
+
|
132 |
+
class TextDecoder(Decoder):
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
chameleon_generator: "ChameleonLocalGenerator",
|
136 |
+
input_ids: list[int],
|
137 |
+
*,
|
138 |
+
temp: float,
|
139 |
+
top_p: float,
|
140 |
+
max_seq_len: int,
|
141 |
+
# TODO: Propagate setting upwards
|
142 |
+
repetition_penalty: float,
|
143 |
+
**kwargs,
|
144 |
+
):
|
145 |
+
self.chameleon_generator = chameleon_generator
|
146 |
+
assert chameleon_generator.vocab.eos_id is not None
|
147 |
+
|
148 |
+
stopping_criteria = [
|
149 |
+
StopOnEOS(chameleon_generator.vocab.eos_id),
|
150 |
+
MaxLengthCriteria(max_seq_len),
|
151 |
+
]
|
152 |
+
if chameleon_generator.additional_eos_tokens is not None:
|
153 |
+
for token in chameleon_generator.additional_eos_tokens:
|
154 |
+
stopping_criteria.append(
|
155 |
+
StopOnEOSAfterBatchIndex(
|
156 |
+
chameleon_generator.tokenizer.token_to_id(token), [len(input_ids)]
|
157 |
+
)
|
158 |
+
)
|
159 |
+
|
160 |
+
logits_processors = [
|
161 |
+
AllowOnlyTokensLogitsProcessor(
|
162 |
+
chameleon_generator.vocab.text_tokens
|
163 |
+
+ [chameleon_generator.vocab.eos_id, chameleon_generator.vocab.begin_image]
|
164 |
+
),
|
165 |
+
# Don't allow any more images near the end since there isn't enough room
|
166 |
+
AllowOnlyTokensAfterIndexLogitsProcessor(
|
167 |
+
chameleon_generator.vocab.text_tokens + [chameleon_generator.vocab.eos_id],
|
168 |
+
# TODO: Calculate exact
|
169 |
+
1024 * 3 - 3,
|
170 |
+
),
|
171 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
172 |
+
TemperatureLogitsWarper(temp),
|
173 |
+
TopPLogitsWarper(top_p),
|
174 |
+
]
|
175 |
+
|
176 |
+
self.gen = ChameleonGenerator(
|
177 |
+
model=ChameleonModelAdapter(chameleon_generator.model, max_seq_len=max_seq_len),
|
178 |
+
input_ids=[input_ids],
|
179 |
+
stopping_criteria=stopping_criteria,
|
180 |
+
logits_processors=logits_processors,
|
181 |
+
)
|
182 |
+
for _ in range(len(input_ids)):
|
183 |
+
next(self.gen)
|
184 |
+
|
185 |
+
def __next__(self) -> tuple[list[int], dict | None, type[Decoder] | None]:
|
186 |
+
gpu_tok = next(self.gen).id.item()
|
187 |
+
cpu_tok = gpu_tok
|
188 |
+
if cpu_tok == self.chameleon_generator.vocab.begin_image:
|
189 |
+
# return "TEXT", [cpu_tok], [], False, ImageDecoder
|
190 |
+
raise StopIteration()
|
191 |
+
|
192 |
+
return (
|
193 |
+
"TEXT",
|
194 |
+
[cpu_tok],
|
195 |
+
[cpu_tok],
|
196 |
+
False,
|
197 |
+
None,
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
class ImageDecoder(Decoder):
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
chameleon_generator: "ChameleonLocalGenerator",
|
205 |
+
input_ids: list[int],
|
206 |
+
*,
|
207 |
+
cfg_image_weight: float,
|
208 |
+
cfg_text_weight: float,
|
209 |
+
temp: float,
|
210 |
+
top_p: float,
|
211 |
+
yield_every_n: int,
|
212 |
+
**kwargs,
|
213 |
+
):
|
214 |
+
self.yield_every_n = yield_every_n
|
215 |
+
self.chameleon_generator = chameleon_generator
|
216 |
+
logits_processors = [
|
217 |
+
InBatchInstructCFGLogitsProcessor(cfg_text_weight, cfg_image_weight),
|
218 |
+
AllowOnlyTokensLogitsProcessor(chameleon_generator.vocab.image_tokens),
|
219 |
+
TemperatureLogitsWarper(temp),
|
220 |
+
TopPLogitsWarper(top_p),
|
221 |
+
]
|
222 |
+
|
223 |
+
image_conditioned_allowed = set(chameleon_generator.vocab.image_tokens) | {
|
224 |
+
chameleon_generator.vocab.bos_id,
|
225 |
+
chameleon_generator.vocab.begin_image,
|
226 |
+
chameleon_generator.vocab.end_image,
|
227 |
+
}
|
228 |
+
|
229 |
+
full_conditioned = input_ids
|
230 |
+
image_conditioned = [
|
231 |
+
in_id for in_id in input_ids if in_id in image_conditioned_allowed
|
232 |
+
]
|
233 |
+
unconditioned = [
|
234 |
+
chameleon_generator.vocab.bos_id,
|
235 |
+
chameleon_generator.vocab.begin_image,
|
236 |
+
]
|
237 |
+
|
238 |
+
self.gen = ChameleonGenerator(
|
239 |
+
model=ChameleonModelAdapter(
|
240 |
+
chameleon_generator.model, max_seq_len=len(input_ids) + 1024
|
241 |
+
),
|
242 |
+
input_ids=[full_conditioned, image_conditioned, unconditioned],
|
243 |
+
logits_processors=logits_processors,
|
244 |
+
alignment=AlignPromptRight(chameleon_generator.vocab.pad_id),
|
245 |
+
token_selector=ReplicatedInputTokenSelector(
|
246 |
+
MultinomialTokenSelector(), n=3
|
247 |
+
),
|
248 |
+
)
|
249 |
+
for _ in range(len(input_ids)):
|
250 |
+
next(self.gen)
|
251 |
+
self.image_builder: list[torch.LongTensor] = []
|
252 |
+
self.gpu_tok_batch: list[torch.LongTensor] = []
|
253 |
+
|
254 |
+
def __next__(self) -> tuple[list[int], dict | None, type[Decoder] | None]:
|
255 |
+
while True:
|
256 |
+
gpu_tok = next(self.gen)
|
257 |
+
gpu_tok = torch.chunk(gpu_tok, chunks=3, dim=0)[0]
|
258 |
+
|
259 |
+
self.image_builder.append(gpu_tok)
|
260 |
+
self.gpu_tok_batch.append(gpu_tok)
|
261 |
+
|
262 |
+
if len(self.image_builder) == 1024:
|
263 |
+
return (
|
264 |
+
"IMAGE",
|
265 |
+
torch.tensor(self.gpu_tok_batch).tolist()
|
266 |
+
+ [self.chameleon_generator.vocab.end_image],
|
267 |
+
torch.tensor(self.image_builder).tolist(),
|
268 |
+
True,
|
269 |
+
TextDecoder,
|
270 |
+
)
|
271 |
+
elif len(self.image_builder) % self.yield_every_n == 0:
|
272 |
+
cpu_toks = torch.tensor(self.gpu_tok_batch).tolist()
|
273 |
+
self.gpu_tok_batch = []
|
274 |
+
|
275 |
+
return (
|
276 |
+
"IMAGE",
|
277 |
+
cpu_toks,
|
278 |
+
torch.tensor(self.image_builder).tolist(),
|
279 |
+
False,
|
280 |
+
None,
|
281 |
+
)
|
282 |
+
|
283 |
+
|
284 |
+
class ChameleonForwardMixin:
|
285 |
+
@torch.inference_mode()
|
286 |
+
def _generate_text_streaming(
|
287 |
+
self,
|
288 |
+
input_ids: list[int],
|
289 |
+
max_gen_tokens: int = 256,
|
290 |
+
temp: float = 1.0,
|
291 |
+
top_p: float = 0.8,
|
292 |
+
repetition_penalty: float = 1.2,
|
293 |
+
seed: int | None = None,
|
294 |
+
) -> Generator[str, None, None]:
|
295 |
+
if seed is not None:
|
296 |
+
set_seed(seed)
|
297 |
+
logger.info(
|
298 |
+
"Rank: %s, set seed: %s",
|
299 |
+
get_rank(),
|
300 |
+
seed,
|
301 |
+
)
|
302 |
+
|
303 |
+
logits_processors = [
|
304 |
+
# Only allow text tokens and end-of-sequence.
|
305 |
+
AllowOnlyTokensLogitsProcessor(
|
306 |
+
self.vocab.text_tokens + [self.vocab.eos_id]
|
307 |
+
),
|
308 |
+
# Don't allow the first token to be end-of-sequence.
|
309 |
+
# DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
|
310 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
311 |
+
TemperatureLogitsWarper(temp),
|
312 |
+
TopPLogitsWarper(top_p),
|
313 |
+
]
|
314 |
+
|
315 |
+
stopping_criteria = [
|
316 |
+
StopOnEOS(self.vocab.eos_id),
|
317 |
+
MaxLengthCriteria(len(input_ids) + max_gen_tokens),
|
318 |
+
]
|
319 |
+
if self.additional_eos_tokens is not None:
|
320 |
+
for token in self.additional_eos_tokens:
|
321 |
+
stopping_criteria.append(
|
322 |
+
StopOnEOSAfterBatchIndex(
|
323 |
+
self.tokenizer.token_to_id(token), [len(input_ids)]
|
324 |
+
)
|
325 |
+
)
|
326 |
+
for tok in ChameleonGenerator(
|
327 |
+
model=ChameleonModelAdapter(
|
328 |
+
self.model,
|
329 |
+
max_seq_len=len(input_ids) + max_gen_tokens,
|
330 |
+
),
|
331 |
+
input_ids=[input_ids],
|
332 |
+
stopping_criteria=stopping_criteria,
|
333 |
+
logits_processors=logits_processors,
|
334 |
+
):
|
335 |
+
yield tok.tolist()
|
336 |
+
|
337 |
+
@torch.inference_mode()
|
338 |
+
def _generate_batched_text_streaming(
|
339 |
+
self,
|
340 |
+
batch: list[list[int]],
|
341 |
+
max_gen_tokens: int = 256,
|
342 |
+
temp: float = 1.0,
|
343 |
+
top_p: float = 0.8,
|
344 |
+
repetition_penalty: float = 1.2,
|
345 |
+
seed: int | None = None,
|
346 |
+
) -> Generator[list[str], None, None]:
|
347 |
+
if seed is not None:
|
348 |
+
set_seed(seed)
|
349 |
+
logits_processors = [
|
350 |
+
# Only allow text tokens and end-of-sequence.
|
351 |
+
AllowOnlyTokensLogitsProcessor(
|
352 |
+
self.vocab.text_tokens + [self.vocab.eos_id]
|
353 |
+
),
|
354 |
+
# Don't allow the first token to be end-of-sequence.
|
355 |
+
# DisallowTokensAtIndexLogitProcessor([self.vocab.eos_id], len()),
|
356 |
+
RepetitionPenaltyLogitsProcessor(repetition_penalty),
|
357 |
+
TemperatureLogitsWarper(temp),
|
358 |
+
TopPLogitsWarper(top_p),
|
359 |
+
]
|
360 |
+
|
361 |
+
max_batch_size = max(len(p) for p in batch)
|
362 |
+
stopping_criteria = [
|
363 |
+
StopOnEOS(self.vocab.eos_id),
|
364 |
+
MaxLengthCriteria(max_batch_size + max_gen_tokens),
|
365 |
+
]
|
366 |
+
if self.additional_eos_tokens is not None:
|
367 |
+
for token in self.additional_eos_tokens:
|
368 |
+
stopping_criteria.append(
|
369 |
+
StopOnEOSAfterBatchIndex(
|
370 |
+
self.tokenizer.token_to_id(token), [len(x) for x in batch]
|
371 |
+
)
|
372 |
+
)
|
373 |
+
for tok in ChameleonGenerator(
|
374 |
+
model=ChameleonModelAdapter(
|
375 |
+
self.model,
|
376 |
+
max_seq_len=max_batch_size + max_gen_tokens,
|
377 |
+
),
|
378 |
+
input_ids=batch,
|
379 |
+
stopping_criteria=stopping_criteria,
|
380 |
+
logits_processors=logits_processors,
|
381 |
+
):
|
382 |
+
yield tok.unsqueeze(1).tolist()
|
383 |
+
|
384 |
+
@torch.inference_mode()
|
385 |
+
def _generate_image_streaming(
|
386 |
+
self,
|
387 |
+
tokenized_prompt: list[int],
|
388 |
+
temp: float = 1.0,
|
389 |
+
top_p: float = 0.8,
|
390 |
+
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
|
391 |
+
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
|
392 |
+
yield_every_n: int = 32,
|
393 |
+
seed: int | None = None,
|
394 |
+
) -> Generator[tuple[list[int], bool], None, None]:
|
395 |
+
if seed is not None:
|
396 |
+
set_seed(seed)
|
397 |
+
logger.info(
|
398 |
+
"Rank: %s, set seed: %s",
|
399 |
+
get_rank(),
|
400 |
+
seed,
|
401 |
+
)
|
402 |
+
|
403 |
+
decoder = ImageDecoder(
|
404 |
+
self,
|
405 |
+
tokenized_prompt,
|
406 |
+
cfg_image_weight=cfg_image_weight,
|
407 |
+
cfg_text_weight=cfg_text_weight,
|
408 |
+
temp=temp,
|
409 |
+
top_p=top_p,
|
410 |
+
yield_every_n=yield_every_n,
|
411 |
+
)
|
412 |
+
|
413 |
+
for _, _, frontend_tokens, is_final, next_decoder in GeneratorWrapper(decoder):
|
414 |
+
if next_decoder is not None:
|
415 |
+
break
|
416 |
+
|
417 |
+
yield torch.tensor(frontend_tokens).tolist(), is_final
|
418 |
+
|
419 |
+
@torch.inference_mode()
|
420 |
+
def _generate_multimodal_streaming(
|
421 |
+
self,
|
422 |
+
input_ids: list[int],
|
423 |
+
temp: float = 1.0,
|
424 |
+
top_p: float = 0.8,
|
425 |
+
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
|
426 |
+
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
|
427 |
+
yield_every_n: int = 32,
|
428 |
+
max_gen_tokens: int = 4096,
|
429 |
+
repetition_penalty: float = 1.2,
|
430 |
+
seed: int | None = None,
|
431 |
+
) -> Generator[tuple[str, list[int], bool], None, None]:
|
432 |
+
if seed is not None:
|
433 |
+
set_seed(seed)
|
434 |
+
logger.info(
|
435 |
+
"Rank: %s, set seed: %s",
|
436 |
+
get_rank(),
|
437 |
+
seed,
|
438 |
+
)
|
439 |
+
max_seq_len = min(len(input_ids) + max_gen_tokens, 4096)
|
440 |
+
gen_wrapper = GeneratorWrapper(
|
441 |
+
TextDecoder(
|
442 |
+
self,
|
443 |
+
input_ids,
|
444 |
+
temp=temp,
|
445 |
+
top_p=top_p,
|
446 |
+
max_seq_len=max_seq_len,
|
447 |
+
repetition_penalty=repetition_penalty,
|
448 |
+
)
|
449 |
+
)
|
450 |
+
|
451 |
+
for (
|
452 |
+
message_type,
|
453 |
+
cpu_toks,
|
454 |
+
frontend_tokens,
|
455 |
+
is_final,
|
456 |
+
next_decoder,
|
457 |
+
) in gen_wrapper:
|
458 |
+
input_ids.extend(cpu_toks)
|
459 |
+
if len(frontend_tokens) > 0:
|
460 |
+
yield message_type, frontend_tokens, is_final
|
461 |
+
if next_decoder is not None:
|
462 |
+
gen_wrapper.gen = next_decoder(
|
463 |
+
self,
|
464 |
+
input_ids,
|
465 |
+
temp=temp,
|
466 |
+
top_p=top_p,
|
467 |
+
max_seq_len=max_seq_len,
|
468 |
+
cfg_image_weight=cfg_image_weight,
|
469 |
+
cfg_text_weight=cfg_text_weight,
|
470 |
+
yield_every_n=yield_every_n,
|
471 |
+
repetition_penalty=repetition_penalty,
|
472 |
+
)
|
473 |
+
|
474 |
+
|
475 |
+
class ChameleonLocalGenerator(
|
476 |
+
AbstractMultimodalGenerator, ChameleonForwardMixin, ChameleonTokenizationMixin
|
477 |
+
):
|
478 |
+
def __init__(
|
479 |
+
self,
|
480 |
+
model_path: str,
|
481 |
+
tokenizer_path: str,
|
482 |
+
vqgan_config_path: str,
|
483 |
+
vqgan_ckpt_path: str | None = None,
|
484 |
+
additional_eos_tokens: list[str] | None = None,
|
485 |
+
) -> None:
|
486 |
+
super().__init__()
|
487 |
+
logger.info("Loading model...")
|
488 |
+
self.model = load_model(model_path)
|
489 |
+
self.additional_eos_tokens = additional_eos_tokens
|
490 |
+
|
491 |
+
logger.info("Loading tokenizer...")
|
492 |
+
tokenizer_path = tokenizer_path
|
493 |
+
self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
494 |
+
self.vocab = VocabInfo(json.load(open(tokenizer_path))["model"]["vocab"])
|
495 |
+
|
496 |
+
logger.info("Loading VQGAN...")
|
497 |
+
self.image_tokenizer = ImageTokenizer(vqgan_config_path, vqgan_ckpt_path)
|
498 |
+
|
499 |
+
@torch.inference_mode()
|
500 |
+
def generate_batched_text(
|
501 |
+
self,
|
502 |
+
prompts: list[MixedSequenceType],
|
503 |
+
max_gen_tokens: int = 256,
|
504 |
+
temp: float = 1.0,
|
505 |
+
top_p: float = 0.8,
|
506 |
+
repetition_penalty: float = 1.2,
|
507 |
+
seed: int | None = None,
|
508 |
+
) -> list[str]:
|
509 |
+
outputs = [""] * len(prompts)
|
510 |
+
for vals in self.generate_batched_text_streaming(
|
511 |
+
prompts,
|
512 |
+
max_gen_tokens=max_gen_tokens,
|
513 |
+
temp=temp,
|
514 |
+
top_p=top_p,
|
515 |
+
repetition_penalty=repetition_penalty,
|
516 |
+
seed=seed,
|
517 |
+
):
|
518 |
+
for idx, val in enumerate(vals):
|
519 |
+
outputs[idx] += val
|
520 |
+
return outputs
|
521 |
+
|
522 |
+
@torch.inference_mode()
|
523 |
+
def generate_batched_text_streaming(
|
524 |
+
self,
|
525 |
+
prompts: list[MixedSequenceType],
|
526 |
+
max_gen_tokens: int = 256,
|
527 |
+
temp: float = 1.0,
|
528 |
+
top_p: float = 0.8,
|
529 |
+
repetition_penalty: float = 1.2,
|
530 |
+
seed: int | None = None,
|
531 |
+
) -> Generator[list[str], None, None]:
|
532 |
+
batch = []
|
533 |
+
for prompt in prompts:
|
534 |
+
batch.append(self.tokens_from_inputs(prompt))
|
535 |
+
|
536 |
+
for tok in self._generate_batched_text_streaming(
|
537 |
+
batch,
|
538 |
+
max_gen_tokens=max_gen_tokens,
|
539 |
+
temp=temp,
|
540 |
+
top_p=top_p,
|
541 |
+
repetition_penalty=repetition_penalty,
|
542 |
+
seed=seed,
|
543 |
+
):
|
544 |
+
yield self.tokenizer.decode_batch(tok)
|
545 |
+
|
546 |
+
@torch.inference_mode()
|
547 |
+
async def generate_text_streaming(
|
548 |
+
self,
|
549 |
+
prompt: MixedSequenceType,
|
550 |
+
max_gen_tokens: int = 256,
|
551 |
+
temp: float = 1.0,
|
552 |
+
top_p: float = 0.8,
|
553 |
+
repetition_penalty: float = 1.2,
|
554 |
+
seed: int | None = None,
|
555 |
+
debug: dict | None = None,
|
556 |
+
) -> Generator[str, None, None]:
|
557 |
+
tokenized_prompt = self.tokens_from_inputs(prompt)
|
558 |
+
if len(tokenized_prompt) > (4096 - 3):
|
559 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
|
560 |
+
return
|
561 |
+
for out in self.generate_batched_text_streaming(
|
562 |
+
[prompt],
|
563 |
+
max_gen_tokens=max_gen_tokens,
|
564 |
+
temp=temp,
|
565 |
+
top_p=top_p,
|
566 |
+
repetition_penalty=repetition_penalty,
|
567 |
+
seed=seed,
|
568 |
+
):
|
569 |
+
yield out[0]
|
570 |
+
|
571 |
+
@torch.inference_mode()
|
572 |
+
async def generate_image_streaming(
|
573 |
+
self,
|
574 |
+
prompt: MixedSequenceType,
|
575 |
+
temp: float = 1.0,
|
576 |
+
top_p: float = 0.8,
|
577 |
+
cfg_image_weight: float = DEFAULT_IMAGE_CFG_IMAGE,
|
578 |
+
cfg_text_weight: float = DEFAULT_IMAGE_CFG_TEXT,
|
579 |
+
yield_every_n: int = 32,
|
580 |
+
seed: int | None = None,
|
581 |
+
debug: dict | None = None,
|
582 |
+
) -> Generator[StreamingImage, None, None]:
|
583 |
+
assert isinstance(prompt, list)
|
584 |
+
tokenized_prompt = self.tokens_from_inputs(prompt)
|
585 |
+
tokenized_prompt.append(self.vocab.begin_image)
|
586 |
+
if len(tokenized_prompt) > (4096 - 3 - 1024):
|
587 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens whether in input or output."
|
588 |
+
return
|
589 |
+
for tokens, final in self._generate_image_streaming(
|
590 |
+
tokenized_prompt,
|
591 |
+
temp=temp,
|
592 |
+
top_p=top_p,
|
593 |
+
cfg_image_weight=cfg_image_weight,
|
594 |
+
cfg_text_weight=cfg_text_weight,
|
595 |
+
yield_every_n=yield_every_n,
|
596 |
+
seed=seed,
|
597 |
+
):
|
598 |
+
yield StreamingImage(
|
599 |
+
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)), final=final
|
600 |
+
)
|
601 |
+
|
602 |
+
@torch.inference_mode()
|
603 |
+
async def generate_multimodal_streaming(
|
604 |
+
self,
|
605 |
+
prompt: MixedSequenceType,
|
606 |
+
temp: float = 1.0,
|
607 |
+
top_p: float = 0.8,
|
608 |
+
cfg_image_weight: float = DEFAULT_MULTIMODAL_CFG_IMAGE,
|
609 |
+
cfg_text_weight: float = DEFAULT_MULTIMODAL_CFG_TEXT,
|
610 |
+
yield_every_n: int = 32,
|
611 |
+
max_gen_tokens: int = 4096,
|
612 |
+
repetition_penalty: float = 1.2,
|
613 |
+
suffix_tokens: list[str] | None = None,
|
614 |
+
seed: int | None = None,
|
615 |
+
debug: dict | None = None,
|
616 |
+
) -> Generator[MixedSequenceType, None, None]:
|
617 |
+
input_ids = self.tokens_from_inputs(prompt, suffix_tokens=suffix_tokens)
|
618 |
+
if len(input_ids) > (4096 - 3):
|
619 |
+
yield "ERROR: Your input exceeds the model's context length of 4096. Note that images consume 1024 tokens."
|
620 |
+
return
|
621 |
+
|
622 |
+
for token_type, tokens, is_final in self._generate_multimodal_streaming(
|
623 |
+
input_ids,
|
624 |
+
temp=temp,
|
625 |
+
top_p=top_p,
|
626 |
+
cfg_image_weight=cfg_image_weight,
|
627 |
+
cfg_text_weight=cfg_text_weight,
|
628 |
+
yield_every_n=yield_every_n,
|
629 |
+
max_gen_tokens=max_gen_tokens,
|
630 |
+
repetition_penalty=repetition_penalty,
|
631 |
+
seed=seed,
|
632 |
+
):
|
633 |
+
match token_type:
|
634 |
+
case "TEXT":
|
635 |
+
yield self.tokenizer.decode(tokens)
|
636 |
+
case "IMAGE":
|
637 |
+
yield StreamingImage(
|
638 |
+
image=self.pillow_from_bpe_tokens(torch.tensor(tokens)),
|
639 |
+
final=is_final,
|
640 |
+
)
|
641 |
+
case _:
|
642 |
+
raise ValueError("Unknown token type")
|
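ChameleonLocalGenerator above exposes the same streaming interface as the distributed generator, but runs the model in-process. A minimal sketch of driving its streaming text generation, with placeholder paths standing in for a real Chameleon checkpoint, tokenizer, and VQGAN config (generate_text_streaming is an async generator, so it is consumed with async for):

import asyncio

from chameleon.viewer.backend.models.chameleon_local import ChameleonLocalGenerator

async def main() -> None:
    generator = ChameleonLocalGenerator(
        model_path="/path/to/chameleon/model",          # placeholder path
        tokenizer_path="/path/to/text_tokenizer.json",  # placeholder path
        vqgan_config_path="/path/to/vqgan.yaml",        # placeholder path
        vqgan_ckpt_path="/path/to/vqgan.ckpt",          # placeholder path
    )
    # Each iteration yields the text decoded since the previous chunk.
    async for chunk in generator.generate_text_streaming(
        ["Describe a chameleon in one sentence."],
        max_gen_tokens=64,
        temp=1.0,
        top_p=0.8,
    ):
        print(chunk, end="", flush=True)

asyncio.run(main())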
chameleon/viewer/backend/models/service.py
ADDED
@@ -0,0 +1,300 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Chameleon License found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import base64
|
7 |
+
import io
|
8 |
+
import socket
|
9 |
+
import subprocess
|
10 |
+
import time
|
11 |
+
from functools import partial
|
12 |
+
|
13 |
+
import fastapi
|
14 |
+
import PIL
|
15 |
+
import pydantic
|
16 |
+
import redis.asyncio as async_redis
|
17 |
+
import uvicorn
|
18 |
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, WebSocketException
|
19 |
+
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
20 |
+
|
21 |
+
from chameleon.viewer.backend.data_types import (
|
22 |
+
Content,
|
23 |
+
ContentType,
|
24 |
+
NoOptionsForComplete,
|
25 |
+
NoOptionsForFull,
|
26 |
+
NoOptionsForPartial,
|
27 |
+
NoOptionsForQueueStatus,
|
28 |
+
WSMessageType,
|
29 |
+
WSMultimodalMessage,
|
30 |
+
)
|
31 |
+
from chameleon.viewer.backend.models.abstract_model import (
|
32 |
+
AbstractMultimodalGenerator,
|
33 |
+
StreamingImage,
|
34 |
+
)
|
35 |
+
from chameleon.viewer.backend.models.chameleon_distributed import AsyncRedisCounter
|
36 |
+
from chameleon.viewer.backend.utils import get_logger
|
37 |
+
|
38 |
+
logger = get_logger(__name__)
|
39 |
+
|
40 |
+
|
41 |
+
def nvidia_smi() -> str:
|
42 |
+
return subprocess.check_output(["nvidia-smi"], text=True)
|
43 |
+
|
44 |
+
|
45 |
+
async def await_generate_message(websocket: WebSocket) -> WSMultimodalMessage:
|
46 |
+
while True:
|
47 |
+
rec_message = await websocket.receive_json()
|
48 |
+
try:
|
49 |
+
maybe_message = WSMultimodalMessage.parse_obj(rec_message)
|
50 |
+
except pydantic.ValidationError:
|
51 |
+
maybe_message = None
|
52 |
+
logger.info("Got invalid message", maybe_message)
|
53 |
+
if maybe_message is not None:
|
54 |
+
return maybe_message
|
55 |
+
|
56 |
+
|
57 |
+
async def async_acquire_lock(
|
58 |
+
*,
|
59 |
+
websocket: WebSocket,
|
60 |
+
counter: AsyncRedisCounter,
|
61 |
+
lock: async_redis.lock.Lock,
|
62 |
+
interval=0.1,
|
63 |
+
status_interval=1,
|
64 |
+
hostname: str | None = None,
|
65 |
+
):
|
66 |
+
start = time.time()
|
67 |
+
await counter.add(1)
|
68 |
+
while True:
|
69 |
+
acquired = await lock.acquire(blocking_timeout=interval)
|
70 |
+
if acquired:
|
71 |
+
break
|
72 |
+
elapsed = time.time() - start
|
73 |
+
if elapsed > status_interval:
|
74 |
+
n_requests = await counter.count()
|
75 |
+
message = WSMultimodalMessage(
|
76 |
+
message_type=WSMessageType.QUEUE_STATUS,
|
77 |
+
content=[
|
78 |
+
Content(
|
79 |
+
content_type=ContentType.TEXT,
|
80 |
+
content=f"n_requests={n_requests}",
|
81 |
+
)
|
82 |
+
],
|
83 |
+
options=NoOptionsForQueueStatus(),
|
84 |
+
debug_info={"hostname": hostname},
|
85 |
+
).dict()
|
86 |
+
await websocket.send_json(message)
|
87 |
+
start = time.time()
|
88 |
+
await counter.sub(1)
|
89 |
+
|
90 |
+
|
91 |
+
COORDINATOR = "coordinator"
|
92 |
+
|
93 |
+
|
94 |
+
def web_app(
|
95 |
+
generator: AbstractMultimodalGenerator,
|
96 |
+
debug: bool = True,
|
97 |
+
redis_port: int | None = None,
|
98 |
+
) -> FastAPI:
|
99 |
+
app = FastAPI(debug=debug)
|
100 |
+
if redis_port is None:
|
101 |
+
redis_client = None
|
102 |
+
redis_lock = None
|
103 |
+
queue_counter = None
|
104 |
+
else:
|
105 |
+
redis_client = async_redis.Redis.from_url(f"redis://redis:{redis_port}")
|
106 |
+
redis_lock = async_redis.lock.Lock(redis_client, COORDINATOR)
|
107 |
+
queue_counter = AsyncRedisCounter(redis_client, "count_pending")
|
108 |
+
hostname = socket.gethostname()
|
109 |
+
|
110 |
+
@app.get("/api/2.0/status")
|
111 |
+
def alive() -> dict:
|
112 |
+
return {
|
113 |
+
"status": "alive",
|
114 |
+
"hostname": hostname,
|
115 |
+
"nvidia-smi": nvidia_smi(),
|
116 |
+
}
|
117 |
+
|
118 |
+
@app.websocket("/ws/chameleon/v2/{client_id}")
|
119 |
+
async def websocket_chameleon_v2(*, websocket: WebSocket, client_id: str):
|
120 |
+
logger.info("Requested client_id: %s", client_id)
|
121 |
+
await websocket.accept()
|
122 |
+
logger.info("Client opened %s with generator id %s", client_id, id(generator))
|
123 |
+
|
124 |
+
try:
|
125 |
+
while True:
|
126 |
+
generate_message = await await_generate_message(websocket)
|
127 |
+
logger.info("Got generate message: %s", str(generate_message)[:300])
|
128 |
+
parsed_prompt = []
|
129 |
+
for c in generate_message.content:
|
130 |
+
match c.content_type:
|
131 |
+
case ContentType.TEXT:
|
132 |
+
parsed_prompt.append(c.content)
|
133 |
+
case ContentType.IMAGE:
|
134 |
+
image_parts = c.content.split(",", 1)
|
135 |
+
if len(image_parts) < 2:
|
136 |
+
logger.error(
|
137 |
+
"Encountered invalid image: %s", image_parts
|
138 |
+
)
|
139 |
+
raise WebSocketException(
|
140 |
+
code=fastapi.status.WS_1008_POLICY_VIOLATION,
|
141 |
+
reason=f"Invalid image: {image_parts}",
|
142 |
+
)
|
143 |
+
image_data = image_parts[1]
|
144 |
+
base64_image = base64.b64decode(image_data)
|
145 |
+
image_file = io.BytesIO(base64_image)
|
146 |
+
parsed_prompt.append(PIL.Image.open(image_file))
|
147 |
+
case _:
|
148 |
+
raise ValueError("Unknown content type")
|
149 |
+
logger.info("Prompt: %s", parsed_prompt)
|
150 |
+
partial_outputs = []
|
151 |
+
final_contents: list[Content] = []
|
152 |
+
|
153 |
+
match generate_message.message_type:
|
154 |
+
case WSMessageType.GENERATE_TEXT:
|
155 |
+
output_generator = generator.generate_text_streaming
|
156 |
+
case WSMessageType.GENERATE_IMAGE:
|
157 |
+
output_generator = generator.generate_image_streaming
|
158 |
+
case WSMessageType.GENERATE_MULTIMODAL:
|
159 |
+
output_generator = generator.generate_multimodal_streaming
|
160 |
+
case _:
|
161 |
+
raise WebSocketException(
|
162 |
+
code=fastapi.status.WS_1008_POLICY_VIOLATION,
|
163 |
+
reason="Unknown message type",
|
164 |
+
)
|
165 |
+
|
166 |
+
logger.info(
|
167 |
+
"Acquiring lock for client %s generation with options: %s",
|
168 |
+
client_id,
|
169 |
+
generate_message.options,
|
170 |
+
)
|
171 |
+
option_args = generate_message.options.dict()
|
172 |
+
debug_info = {"hostname": hostname}
|
173 |
+
del option_args["message_type"]
|
174 |
+
output_generator = partial(
|
175 |
+
output_generator,
|
176 |
+
**option_args,
|
177 |
+
debug=debug_info,
|
178 |
+
)
|
179 |
+
if redis_lock is not None:
|
180 |
+
await async_acquire_lock(
|
181 |
+
websocket=websocket,
|
182 |
+
lock=redis_lock,
|
183 |
+
hostname=hostname,
|
184 |
+
counter=queue_counter,
|
185 |
+
)
|
186 |
+
await redis_client.set("has_lock", client_id)
|
187 |
+
|
188 |
+
logger.info(
|
189 |
+
"Starting locked generation for client %s with options: %s",
|
190 |
+
client_id,
|
191 |
+
generate_message.options,
|
192 |
+
)
|
193 |
+
try:
|
194 |
+
async for output_token in output_generator(parsed_prompt):
|
195 |
+
if isinstance(output_token, str):
|
196 |
+
content_type = ContentType.TEXT
|
197 |
+
content = output_token
|
198 |
+
message_type = WSMessageType.PARTIAL_OUTPUT
|
199 |
+
options = NoOptionsForPartial()
|
200 |
+
partial_outputs.extend(output_token)
|
201 |
+
elif isinstance(output_token, StreamingImage):
|
202 |
+
content_type = ContentType.IMAGE
|
203 |
+
image = output_token.image
|
204 |
+
img_io = io.BytesIO()
|
205 |
+
image.save(img_io, format="png")
|
206 |
+
content = (
|
207 |
+
"data:image/png;base64,"
|
208 |
+
+ base64.b64encode(img_io.getvalue()).decode()
|
209 |
+
)
|
210 |
+
if output_token.final:
|
211 |
+
message_type = WSMessageType.FULL_OUTPUT
|
212 |
+
options = NoOptionsForFull()
|
213 |
+
else:
|
214 |
+
message_type = WSMessageType.PARTIAL_OUTPUT
|
215 |
+
options = NoOptionsForPartial()
|
216 |
+
|
217 |
+
if output_token.final:
|
218 |
+
partial_outputs.append(output_token.image)
|
219 |
+
else:
|
220 |
+
raise ValueError(f"Invalid output_token: {output_token}")
|
221 |
+
|
222 |
+
message_content = Content(
|
223 |
+
content_type=content_type, content=content
|
224 |
+
)
|
225 |
+
match content_type:
|
226 |
+
case ContentType.TEXT:
|
227 |
+
final_contents.append(message_content)
|
228 |
+
case ContentType.IMAGE:
|
229 |
+
if message_type == WSMessageType.FULL_OUTPUT:
|
230 |
+
final_contents.append(message_content)
|
231 |
+
case _:
|
232 |
+
pass
|
233 |
+
|
234 |
+
message = WSMultimodalMessage(
|
235 |
+
message_type=message_type,
|
236 |
+
content=[message_content],
|
237 |
+
options=options,
|
238 |
+
debug_info=debug_info,
|
239 |
+
).dict()
|
240 |
+
await websocket.send_json(message)
|
241 |
+
finally:
|
242 |
+
if redis_lock is not None:
|
243 |
+
logger.info(
|
244 |
+
"Attempting release of lock for client %s generation with options: %s",
|
245 |
+
client_id,
|
246 |
+
generate_message.options,
|
247 |
+
)
|
248 |
+
owned = await redis_lock.owned()
|
249 |
+
if owned:
|
250 |
+
await redis_client.set("has_lock", "")
|
251 |
+
try:
|
252 |
+
await redis_lock.release()
|
253 |
+
except async_redis.lock.LockError:
|
254 |
+
pass
|
255 |
+
|
256 |
+
logger.info(
|
257 |
+
"Released lock for client %s generation with options: %s",
|
258 |
+
client_id,
|
259 |
+
generate_message.options,
|
260 |
+
)
|
261 |
+
await websocket.send_json(
|
262 |
+
WSMultimodalMessage(
|
263 |
+
message_type=WSMessageType.COMPLETE,
|
264 |
+
content=final_contents,
|
265 |
+
options=NoOptionsForComplete(),
|
266 |
+
debug_info=debug_info,
|
267 |
+
).dict()
|
268 |
+
)
|
269 |
+
except WebSocketDisconnect:
|
270 |
+
logger.info("Client disconnected %s", client_id)
|
271 |
+
except ConnectionClosedError:
|
272 |
+
logger.info("Client forced a close %s", client_id)
|
273 |
+
except ConnectionClosedOK:
|
274 |
+
logger.info("Connection closed ok %s", client_id)
|
275 |
+
finally:
|
276 |
+
if redis_lock is not None:
|
277 |
+
logger.info("Checking for client holding lock: %s", client_id)
|
278 |
+
owned = await redis_lock.owned()
|
279 |
+
if owned:
|
280 |
+
try:
|
281 |
+
logger.info("Attempted to release owned lock: %s", client_id)
|
282 |
+
await redis_lock.release()
|
283 |
+
except async_redis.lock.LockError:
|
284 |
+
pass
|
285 |
+
await redis_client.set("has_lock", "")
|
286 |
+
|
287 |
+
return app
|
288 |
+
|
289 |
+
|
290 |
+
def serve(
|
291 |
+
model: AbstractMultimodalGenerator,
|
292 |
+
host: str,
|
293 |
+
port: int,
|
294 |
+
debug: bool = True,
|
295 |
+
redis_port: int | None = None,
|
296 |
+
) -> None:
|
297 |
+
app = web_app(model, debug=debug, redis_port=redis_port)
|
298 |
+
# TODO: convert this to a subprocess call so enable more
|
299 |
+
# uvicorn features like multiple workers
|
300 |
+
uvicorn.run(app, host=host, port=port)
|
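The endpoint above speaks JSON over a WebSocket, with each message validated against `WSMultimodalMessage` from `chameleon/viewer/backend/data_types.py`. Below is a minimal client sketch using the pinned `websockets` package; the host, port, literal enum strings, and the empty `options` payload are assumptions for illustration and would need to match the actual schemas in `data_types.py`.

```python
import asyncio
import json

import websockets  # pinned in chameleon/viewer/backend/requirements.txt


async def main() -> None:
    # Host/port and the enum values below are assumptions, not part of this commit.
    uri = "ws://localhost:8000/ws/chameleon/v2/demo-client"
    async with websockets.connect(uri) as ws:
        request = {
            "message_type": "GENERATE_TEXT",
            "content": [{"content_type": "TEXT", "content": "Describe a chameleon."}],
            "options": {},  # placeholder; real options are defined in data_types.py
            "debug_info": {},
        }
        await ws.send(json.dumps(request))
        while True:
            # The server streams PARTIAL_OUTPUT / FULL_OUTPUT messages and
            # finishes each request with a COMPLETE message.
            reply = json.loads(await ws.recv())
            print(reply["message_type"], reply.get("content"))
            if reply["message_type"] == "COMPLETE":
                break


if __name__ == "__main__":
    asyncio.run(main())
```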
chameleon/viewer/backend/requirements.txt
ADDED
@@ -0,0 +1,35 @@
+# If black/isort/pytest change, then update `.circleci/config.yml`
+black==23.7.0
+isort==5.12.0
+pytest==7.4.0
+rich==13.5.*
+ipython
+
+# Do not change, python 3.11 needs this
+hydra-core==1.3.2
+typer==0.9.0
+httpx==0.24.1
+pylint==2.17.5
+submitit==1.4.2
+pudb==2022.1.3
+
+# These do/should match dependency versions
+# This is so that the viewer can run without any other deps outside of this file
+Pillow==10.0.*
+fastapi==0.101.1
+pydantic==1.10.*
+requests==2.31.*
+uvicorn==0.23.2
+python-multipart==0.0.6
+ruff==0.1.2
+websockets==12.0
+redis[hiredis]==5.0.1
+psutil==5.9.7
+
+# For inference
+albumentations==1.3.1
+einops==0.7.0
+pytorch_lightning==2.1.2
+transformers==4.36.2
+xformers==0.0.23
+torchvision==0.16.*
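Per the comment above, the viewer backend is meant to run with no dependencies beyond this file, so these pins can presumably be installed on their own with `pip install -r chameleon/viewer/backend/requirements.txt` (path as added in this commit; a dedicated virtual environment is assumed).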
chameleon/viewer/backend/utils.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Chameleon License found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import types
+
+from rich.logging import RichHandler
+
+
+def configure_rich_logging():
+    FORMAT = "%(message)s"
+    logging.basicConfig(
+        level=logging.INFO,
+        handlers=[RichHandler(rich_tracebacks=True)],
+        format=FORMAT,
+        force=True,
+    )
+
+
+configure_rich_logging()
+
+
+def get_logger(module: types.ModuleType) -> logging.Logger:
+    """This forces logging.basicConfig to be called first."""
+    logger = logging.getLogger(module)
+    return logger
chameleon/viewer/frontend/README.md
ADDED
@@ -0,0 +1,11 @@
+# Install
+
+```
+npm install
+```
+
+
+# Run local
+```
+npm run dev
+```
chameleon/viewer/frontend/index.html
ADDED
@@ -0,0 +1,17 @@
+<!-- Copyright (c) Meta Platforms, Inc. and affiliates. -->
+
+<!-- This source code is licensed under the Chameleon License found in the -->
+<!-- LICENSE file in the root directory of this source tree. -->
+<!doctype html>
+<html lang="en">
+  <head>
+    <meta charset="UTF-8" />
+    <link rel="icon" type="image/svg+xml" href="/vite.svg" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+    <title>Chameleon Viewer</title>
+  </head>
+  <body>
+    <div id="root"></div>
+    <script type="module" src="/src/main.tsx"></script>
+  </body>
+</html>
chameleon/viewer/frontend/package-lock.json
ADDED
The diff for this file is too large to render.
See raw diff
chameleon/viewer/frontend/package.json
ADDED
@@ -0,0 +1,62 @@
+{
+  "name": "chameleon-frontend",
+  "private": true,
+  "version": "0.0.0",
+  "type": "module",
+  "scripts": {
+    "dev": "vite --host 0.0.0.0 --port 7654",
+    "staging": "vite --mode staging --host 0.0.0.0",
+    "datadev": "vite --mode datadev --host 0.0.0.0",
+    "check-build": "tsc && vite build",
+    "build": "vite build",
+    "lint": "eslint . --ext ts,tsx --report-unused-disable-directives --max-warnings 0",
+    "preview": "vite preview",
+    "check-format": "prettier --check src",
+    "format": "prettier --write src",
+    "test": "vitest"
+  },
+  "dependencies": {
+    "@carbon/icons-react": "^11.25.0",
+    "@lexical/react": "^0.12.2",
+    "axios": "^1.4.0",
+    "lexical": "^0.12.2",
+    "prettier": "^3.0.3",
+    "react": "^18.2.0",
+    "react-cookie": "^6.1.1",
+    "react-daisyui": "^4.1.0",
+    "react-dnd": "^16.0.1",
+    "react-dnd-html5-backend": "^16.0.1",
+    "react-dom": "^18.2.0",
+    "react-dropzone": "^14.2.3",
+    "react-hotkeys-hook": "^4.4.1",
+    "react-markdown": "^9.0.1",
+    "react-router-dom": "^6.15.0",
+    "react-use-websocket": "^4.5.0",
+    "react18-json-view": "^0.2.4",
+    "remark-gfm": "^4.0.0",
+    "unique-username-generator": "^1.2.0",
+    "ws": "^8.14.2",
+    "zod": "^3.22.2",
+    "zustand": "^4.4.1"
+  },
+  "devDependencies": {
+    "@tailwindcss/typography": "^0.5.9",
+    "@types/react": "^18.2.15",
+    "@types/react-dom": "^18.2.7",
+    "@types/ws": "^8.5.9",
+    "@typescript-eslint/eslint-plugin": "^6.0.0",
+    "@typescript-eslint/parser": "^6.0.0",
+    "@vitejs/plugin-react": "^4.0.3",
+    "autoprefixer": "^10.4.15",
+    "daisyui": "^3.9.2",
+    "eslint": "^8.45.0",
+    "eslint-plugin-react-hooks": "^4.6.0",
+    "eslint-plugin-react-refresh": "^0.4.3",
+    "postcss": "^8.4.28",
+    "prettier": "^3.0.3",
+    "tailwindcss": "^3.3.3",
+    "typescript": "^5.0.2",
+    "vite": "^4.4.5",
+    "vitest": "^0.34.6"
+  }
+}
chameleon/viewer/frontend/postcss.config.cjs
ADDED
@@ -0,0 +1,13 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ *
+ * This source code is licensed under the Chameleon License found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+module.exports = {
+  plugins: {
+    tailwindcss: {},
+    autoprefixer: {},
+  },
+};
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_DisplayVF_W_Wght.woff2
ADDED
Binary file (36.5 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff
ADDED
Binary file (28.9 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Bd.woff2
ADDED
Binary file (23.6 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff
ADDED
Binary file (28.6 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_Md.woff2
ADDED
Binary file (23.4 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff
ADDED
Binary file (28.7 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Display_W_SBd.woff2
ADDED
Binary file (23.6 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_TextVF_W_Wght.woff2
ADDED
Binary file (36.3 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff
ADDED
Binary file (28.8 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Bd.woff2
ADDED
Binary file (23.5 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff
ADDED
Binary file (28.7 kB)
chameleon/viewer/frontend/public/fonts/optimistic/Optimistic_Text_W_Md.woff2
ADDED
Binary file (23.5 kB)