Spaces:
Paused
Paused
calculating
committed on
Commit
•
824afbf
1
Parent(s):
35d94d0
committing...
Browse files- app.py +245 -0
- ioblocks.py +333 -0
- model.py +443 -0
- requirements.txt +14 -0
- tokenizer.py +581 -0
- transformer.py +382 -0
- utils/__init__.py +3 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/blocks.cpython-310.pyc +0 -0
- utils/__pycache__/dist.cpython-310.pyc +0 -0
- utils/__pycache__/interp.cpython-310.pyc +0 -0
- utils/blocks.py +92 -0
- utils/dist.py +99 -0
- utils/interp.py +84 -0
app.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch as T
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torchaudio
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from utils import load_ckpt, print_colored
|
8 |
+
from tokenizer import make_tokenizer
|
9 |
+
from model import get_hertz_dev_config
|
10 |
+
from typing import Tuple
|
11 |
+
import numpy as np
|
12 |
+
import os
|
13 |
+
|
14 |
+
# Global variables for model and tokenizer
# Both are populated lazily by init_model() and reused across requests.
global_generator = None  # the generation model (None until init_model runs)
global_tokenizer = None  # the audio tokenizer (None until init_model runs)
default_audio_path = "testingtesting.wav" # Your default audio file
|
18 |
+
|
19 |
+
def init_model(use_pure_audio_ablation: bool = False) -> Tuple[nn.Module, object]:
    """Initialize the model and tokenizer, caching them in module globals.

    Args:
        use_pure_audio_ablation: forwarded to get_hertz_dev_config to select
            the pure-audio ablation variant of the model.

    Returns:
        A (generator, tokenizer) tuple. Subsequent calls return the cached
        pair without re-initializing.
    """
    global global_generator, global_tokenizer

    # Fast path: already initialized.
    if global_generator is not None and global_tokenizer is not None:
        return global_generator, global_tokenizer

    device = 'cuda' if T.cuda.is_available() else 'cpu'
    # Pin to GPU 0 when CUDA is available. (The original used an
    # expression-statement ternary whose value was discarded; a plain
    # `if` states the side effect explicitly.)
    if device == 'cuda':
        T.cuda.set_device(0)

    print_colored("Initializing model and tokenizer...", "blue")
    global_tokenizer = make_tokenizer(device)
    model_config = get_hertz_dev_config(is_split=False, use_pure_audio_ablation=use_pure_audio_ablation)

    global_generator = model_config()
    # Inference-only setup: eval mode, bfloat16 weights, target device.
    global_generator = global_generator.eval().to(T.bfloat16).to(device)
    print_colored("Model initialization complete!", "green")

    return global_generator, global_tokenizer
|
38 |
+
|
39 |
+
def process_audio(audio_path: str, sr: int) -> T.Tensor:
    """Load and preprocess an audio file for the model.

    Args:
        audio_path: path to the audio file on disk.
        sr: nominal sample-rate hint from the caller. NOTE: the actual rate
            is read from the file itself (torchaudio.load rebinds this local);
            the parameter is kept for backward compatibility with callers.

    Returns:
        Float tensor of shape (1, 1, num_samples): mono, 16 kHz, truncated
        to at most 5 minutes.
    """
    audio_tensor, sr = torchaudio.load(audio_path)

    # Downmix any multi-channel audio to mono. (The original only handled
    # the exactly-stereo case; `> 1` generalizes to N channels.)
    if audio_tensor.shape[0] > 1:
        audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)

    # Resample to the model's expected 16 kHz rate if needed.
    if sr != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
        audio_tensor = resampler(audio_tensor)

    # Cap the prompt length to keep memory bounded.
    max_samples = 16000 * 60 * 5  # 5 minutes
    if audio_tensor.shape[1] > max_samples:
        audio_tensor = audio_tensor[:, :max_samples]

    # Add a batch dimension: (1, channels, samples).
    return audio_tensor.unsqueeze(0)
|
56 |
+
|
57 |
+
def generate_completion(
    audio_file,
    prompt_len_seconds: float = 3.0,
    num_completions: int = 5,
    generation_seconds: float = 20.0,
    token_temp: float = 0.8,
    categorical_temp: float = 0.5,
    gaussian_temp: float = 0.1,
    progress=gr.Progress(track_tqdm=True)
) -> list:
    """Generate audio completions from the input audio.

    Args:
        audio_file: filepath to the prompt audio (as produced by gr.Audio
            with type="filepath").
        prompt_len_seconds: seconds of the input used as the prompt.
        num_completions: number of independent completions to sample.
        generation_seconds: length of each generated continuation.
        token_temp: temperature for token sampling.
        categorical_temp: temperature for the mixture-component choice.
        gaussian_temp: temperature for the Gaussian noise of the sample.
        progress: Gradio progress tracker (injected default).

    Returns:
        List of (sample_rate, numpy_audio) tuples suitable for gr.Audio
        outputs of type="numpy".
    """
    device = 'cuda' if T.cuda.is_available() else 'cpu'

    # Use existing model and tokenizer (populated by init_model at startup;
    # this function assumes init_model() has already run).
    generator, audio_tokenizer = global_generator, global_tokenizer

    progress(0, desc="Processing input audio...")
    # Process input audio
    prompt_audio = process_audio(audio_file, sr=16000)
    # 8 latent frames per second of audio — presumably the tokenizer's frame
    # rate (consistent with the 2000-samples-per-frame trim below at 16 kHz);
    # TODO confirm against the tokenizer config.
    prompt_len = int(prompt_len_seconds * 8)

    progress(0.2, desc="Encoding prompt...")
    # Encode prompt once; all completions reuse the same encoding.
    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))

    completions = []
    for i in range(num_completions):
        progress((i + 1) / num_completions, desc=f"Generating completion {i+1}/{num_completions}")

        # Generate completion
        encoded_prompt = encoded_prompt_audio[:, :prompt_len]
        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            completed_audio_batch = generator.completion(
                encoded_prompt,
                temps=(token_temp, (categorical_temp, gaussian_temp)),
                use_cache=True,
                gen_len=int(generation_seconds * 8)
            )

        # Decode latents back to a waveform.
        decoded_completion = audio_tokenizer.data_from_latent(completed_audio_batch.bfloat16())

        # Process audio for output
        audio_tensor = decoded_completion.cpu().squeeze()
        if audio_tensor.ndim == 1:
            audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.float()

        # Peak-normalize if the decoder overshoots [-1, 1].
        if audio_tensor.abs().max() > 1:
            audio_tensor = audio_tensor / audio_tensor.abs().max()

        # Trim to include only the generated portion
        # (2000 samples per latent frame at 16 kHz; keeps the last second
        # of the prompt for context — TODO confirm the frame size).
        output_audio = audio_tensor[:, max(prompt_len*2000 - 16000, 0):]
        completions.append((16000, output_audio.numpy().T))

    progress(1.0, desc="Generation complete!")
    return completions
|
114 |
+
|
115 |
+
def create_interface():
    """Build and return the Gradio Blocks app.

    Initializes the model once at startup, then wires up the input audio
    component, generation-parameter sliders, and ten pre-allocated output
    audio slots whose visibility tracks the requested number of completions.
    """
    # Initialize model at startup
    init_model()

    with gr.Blocks(title="Audio Completion Generator") as app:
        gr.Markdown("""
        # Audio Completion Generator
        Upload an audio file (or use the default) and generate AI completions based on the prompt.
        """)

        with gr.Row():
            with gr.Column():
                # Load the default audio if it exists
                default_value = default_audio_path if os.path.exists(default_audio_path) else None

                audio_input = gr.Audio(
                    label="Input Audio",
                    type="filepath",
                    sources=["microphone", "upload"],
                    value=default_value
                )

                with gr.Row():
                    prompt_len = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=3,
                        step=0.5,
                        label="Prompt Length (seconds)"
                    )
                    default_num_completions = 5
                    num_completions = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=default_num_completions,
                        step=1,
                        label="Number of Completions"
                    )
                    gen_length = gr.Slider(
                        minimum=5,
                        maximum=60,
                        value=20,
                        step=5,
                        label="Generation Length (seconds)"
                    )

                with gr.Row():
                    # The three sampling temperatures forwarded to
                    # generate_completion.
                    token_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                        label="Token Temperature"
                    )
                    cat_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.1,
                        label="Categorical Temperature"
                    )
                    gauss_temp = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.1,
                        step=0.1,
                        label="Gaussian Temperature"
                    )

                generate_btn = gr.Button("Generate Completions")
                status_text = gr.Markdown("Ready")

            with gr.Column():
                # Pre-allocate the maximum number of output slots; only the
                # first `num_completions` are made visible.
                output_audios = []
                for i in range(10): # Create 10 audio components
                    output_audios.append(gr.Audio(
                        label=f"Generated Completion {i+1}",
                        type="numpy",
                        visible=False
                    ))

        def update_visibility(num):
            # Show the first `num` output slots, hide the rest.
            return [gr.update(visible=(i < num)) for i in range(10)]

        def generate_with_status(*args):
            # NOTE(review): assigning status_text.value here does not update
            # the rendered UI — component values only change via event
            # outputs; consider returning the status as an output instead.
            status_text.value = "Processing input audio..."
            completions = generate_completion(*args)
            status_text.value = "Generation complete!"

            # Prepare outputs for all audio components; unused slots get None.
            outputs = []
            for i in range(10):
                if i < len(completions):
                    outputs.append(completions[i])
                else:
                    outputs.append(None)
            return outputs

        # Set initial visibility on load
        app.load(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )

        # Update visibility when slider changes
        num_completions.change(
            fn=update_visibility,
            inputs=[num_completions],
            outputs=output_audios
        )

        generate_btn.click(
            fn=generate_with_status,
            inputs=[
                audio_input,
                prompt_len,
                num_completions,
                gen_length,
                token_temp,
                cat_temp,
                gauss_temp
            ],
            outputs=output_audios
        )

    return app
|
242 |
+
|
243 |
+
if __name__ == "__main__":
    # Build the Gradio app and serve it with a public share link.
    demo = create_interface()
    demo.launch(share=True)
|
ioblocks.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations

from contextlib import nullcontext
from functools import partial
import math
from math import ceil
from typing import List, Tuple

import torch as T
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor, int32
from torch.amp import autocast

from einops import rearrange, pack, unpack


from utils import si_module, exists, default, maybe
|
18 |
+
|
19 |
+
|
20 |
+
@si_module
class GaussianMixtureIOLayer(nn.Module):
    """IO layer that embeds sampled latents on the way in and parameterizes
    a diagonal Gaussian mixture over latents on the way out.

    Config:
        latent_dim: dimensionality of the latent vectors (Z).
        dim: model hidden size (D).
        num_components: number of mixture components (K).
    """

    class Config:
        latent_dim: int
        dim: int
        num_components: int

    def __init__(self, c: Config):
        super().__init__()
        self.latent_dim = c.latent_dim
        self.num_components = c.num_components
        self.input_projection = nn.Linear(c.latent_dim, c.dim)

        # Output heads for the mixture parameters: means, scales, logits.
        self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
        self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
        self.fc_weight = nn.Linear(c.dim, c.num_components)

    def _square_plus(self, x):
        # Smooth positive reparameterization: (x + sqrt(x^2 + 4)) / 2.
        return (x + T.sqrt(T.square(x) + 4)) / 2

    def input(self, sampled_latents: T.Tensor) -> T.Tensor:
        """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
        hidden = self.input_projection(sampled_latents)
        return hidden

    def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
        """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
        batch_size, seq_len, _ = h.shape

        locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
        # Clamp keeps scales strictly positive for the log in loss().
        scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
        weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)

        return (locs, scales, weights)

    def loss(self, data, dataHat):
        """Per-position negative log-likelihood of `data` under the mixture.

        data: (B, L, Z); dataHat: (locs, scales, weights) from output().
        """
        locs, scales, weights = dataHat
        # Diagonal-Gaussian log density summed over the latent dimension.
        # The 2*pi term is now a plain Python float: the original built a
        # fresh CPU tensor (T.log(T.tensor(2 * T.pi))) on every call.
        log_probs = -0.5 * T.sum(
            (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2)
            + 2 * T.log(scales)
            + math.log(2 * math.pi),
            dim=-1
        )
        log_weights = F.log_softmax(weights, dim=-1)
        return -T.logsumexp(log_weights + log_probs, dim=-1)

    def temp_sample(self, orig_pdist, temp):
        """Sample latents from the mixture with optional temperature control.

        temp may be:
          * None — no tempering (both temperatures behave as 1.0);
          * a (categorical_temp, gaussian_temp) tuple;
          * a single float applied to both the component choice and the
            Gaussian noise.
        The three duplicated sampling branches of the original are
        consolidated into one path; the sampling order (Gaussian noise,
        then gumbel softmax) is preserved.
        """
        locs, scales, weights = orig_pdist
        if temp is None:
            categorical_temp = gaussian_temp = 1.0
        elif isinstance(temp, tuple):
            assert len(temp) == 2
            categorical_temp, gaussian_temp = temp
        else:
            categorical_temp = gaussian_temp = temp

        component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
        # Hard gumbel-softmax yields a one-hot component selector with a
        # differentiable surrogate gradient.
        mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
        sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
        return sampled
|
84 |
+
|
85 |
+
|
86 |
+
class GPTOutput(nn.Module):
    """Final projection head mapping hidden states to vocabulary logits."""

    def __init__(self, dim, vocab_size):
        super().__init__()
        # Bias-free linear head over the vocabulary.
        self.output = nn.Linear(dim, vocab_size, bias=False)

    def forward(self, x):
        logits = self.output(x)
        return logits
|
93 |
+
|
94 |
+
|
95 |
+
# helper functions

def pack_one(t, pattern):
    """Pack a single tensor via einops.pack; returns (packed, shapes)."""
    packed, shapes = pack([t], pattern)
    return packed, shapes

def unpack_one(t, ps, pattern):
    """Inverse of pack_one: unpack and return the sole element."""
    (only,) = unpack(t, ps, pattern)
    return only

def first(l):
    """Return the first element of a sequence."""
    return l[0]

def round_up_multiple(num, mult):
    """Round `num` up to the nearest multiple of `mult`."""
    return mult * ceil(num / mult)
|
108 |
+
|
109 |
+
def get_code_utilization(codes, codebook_size, get_global=False):
    """Estimate the fraction of the codebook in active use.

    Counts distinct codes and divides by min(sample count, codebook_size).
    When `get_global` is set and torch.distributed is initialized, codes
    are all-gathered across ranks first so the estimate is global.
    """
    world_size = dist.get_world_size() if (get_global and dist.is_initialized()) else 1

    if world_size > 1:
        buckets = [T.zeros_like(codes) for _ in range(world_size)]
        dist.all_gather(buckets, codes)
        all_codes = T.cat(buckets, dim=0)
    else:
        all_codes = codes

    distinct = len(T.unique(all_codes))
    return distinct / min(all_codes.numel(), codebook_size)
|
124 |
+
|
125 |
+
# tensor helpers

def round_ste(z: Tensor) -> Tensor:
    """Round to the nearest integer with a straight-through estimator.

    Forward pass returns round(z); the backward pass acts as identity
    because the (round(z) - z) correction is detached from the graph.
    """
    return z + (z.round() - z).detach()
|
131 |
+
|
132 |
+
# main class
# lucidrains fsq
@si_module
class FSQ(nn.Module):
    """Finite Scalar Quantization, adapted from lucidrains' implementation.

    Each of the `len(levels)` channels is bounded and rounded onto a small
    grid of `levels[i]` values; the cartesian product of the grids forms an
    implicit codebook addressed by mixed-radix integer indices.
    """

    @property
    def needs_float32_params(self):
        return True

    class Config:
        levels: List[int]
        dim: int | None = None
        num_codebooks: int = 1
        keep_num_codebooks_dim: bool | None = None
        scale: float | None = None
        allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
        channel_first: bool = False
        projection_has_bias: bool = True
        return_indices: bool = True
        force_quantization_f32: bool = True
        use_rms: bool = False

    def __init__(self, c: Config):
        super().__init__()
        _levels = T.tensor(c.levels, dtype=int32)
        self.register_buffer("_levels", _levels, persistent = False)

        # _basis[i] = prod(levels[:i]) — mixed-radix place values used to
        # flatten per-level indices into a single codebook index.
        _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
        self.register_buffer("_basis", _basis, persistent = False)

        self.scale = c.scale

        codebook_dim = len(c.levels)
        self.codebook_dim = codebook_dim

        effective_codebook_dim = codebook_dim * c.num_codebooks
        self.num_codebooks = c.num_codebooks

        # Resolve dtype strings to torch dtypes once, up front.
        # BUGFIX: the original later clobbered this resolved list with the
        # raw string tuple (`self.allowed_dtypes = c.allowed_dtypes`), so
        # the `orig_dtype not in self.allowed_dtypes` check in forward()
        # was always true and allowed inputs (e.g. float64) were silently
        # downcast. The clobbering reassignment is removed.
        self.allowed_dtypes = []
        for dtype_str in c.allowed_dtypes:
            if hasattr(T, dtype_str):
                self.allowed_dtypes.append(getattr(T, dtype_str))
            else:
                raise ValueError(f"Invalid dtype string: {dtype_str}")

        self.effective_codebook_dim = effective_codebook_dim

        keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
        assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
        self.keep_num_codebooks_dim = keep_num_codebooks_dim

        self.dim = default(c.dim, len(_levels) * c.num_codebooks)

        self.channel_first = c.channel_first

        # Project in/out only when the model dim differs from the raw
        # quantizer dim; otherwise pass through untouched.
        has_projections = self.dim != effective_codebook_dim
        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()

        self.has_projections = has_projections

        self.return_indices = c.return_indices
        if c.return_indices:
            self.codebook_size = self._levels.prod().item()
            implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
            self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

        self.force_quantization_f32 = c.force_quantization_f32

        self.latent_loss = None

    def latent_metric(self, codes, get_global=False):
        """Code-utilization diagnostics for logging."""
        return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}

    def repr_from_latent(self, latent):
        """Recover continuous codes from stored integer indices."""
        return self.indices_to_codes(latent)

    def bound(self, z, eps: float = 1e-3):
        """ Bound `z`, an array of shape (..., d). """
        half_l = (self._levels - 1) * (1 + eps) / 2
        # Even level counts need a half-step offset so the rounded grid is
        # centered; the atanh shift keeps the tanh output aligned with it.
        offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z):
        """ Quantizes z, returns quantized zhat, same shape as z. """
        quantized = round_ste(self.bound(z))
        half_width = self._levels // 2 # Renormalize to [-1, 1].
        return quantized / half_width

    def _scale_and_shift(self, zhat_normalized):
        # [-1, 1] grid -> [0, levels-1] integer grid.
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat):
        # [0, levels-1] integer grid -> [-1, 1] grid.
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def _indices_to_codes(self, indices):
        level_indices = self.indices_to_level_indices(indices)
        codes = self._scale_and_shift_inverse(level_indices)
        return codes

    def codes_to_indices(self, zhat):
        """ Converts a `code` to an index in the codebook. """
        assert zhat.shape[-1] == self.codebook_dim
        zhat = self._scale_and_shift(zhat)
        # Mixed-radix positional encoding via the cumulative-product basis.
        return (zhat * self._basis).sum(dim=-1).to(int32)

    def indices_to_level_indices(self, indices):
        """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
        indices = rearrange(indices, '... -> ... 1')
        codes_non_centered = (indices // self._basis) % self._levels
        return codes_non_centered

    def indices_to_codes(self, indices):
        """ Inverse of `codes_to_indices`. """
        assert exists(indices)

        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))

        codes = self._indices_to_codes(indices)

        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, '... c d -> ... (c d)')

        codes = self.project_out(codes)

        if is_img_or_video or self.channel_first:
            codes = rearrange(codes, 'b ... d -> b d ...')

        return codes

    def forward(self, z, return_codes=False):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension
        c - number of codebook dim
        """

        is_img_or_video = z.ndim >= 4
        need_move_channel_last = is_img_or_video or self.channel_first

        # standardize image or video into (batch, seq, dimension)
        if need_move_channel_last:
            z = rearrange(z, 'b d ... -> b ... d')
            z, ps = pack_one(z, 'b * d')

        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'

        z = self.project_in(z)

        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # whether to force quantization step to be full precision or not
        force_f32 = self.force_quantization_f32
        quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext

        with quantization_context():
            orig_dtype = z.dtype

            # With allowed_dtypes now holding real torch dtypes (see
            # __init__), this correctly skips the cast for inputs that are
            # already an allowed float type.
            if force_f32 and orig_dtype not in self.allowed_dtypes:
                z = z.float()

            codes = self.quantize(z)

            # returning indices could be optional
            indices = None

            if self.return_indices:
                indices = self.codes_to_indices(codes)

            codes = rearrange(codes, 'b n c d -> b n (c d)')

            codes = codes.type(orig_dtype)

        # project out
        if return_codes:
            return codes, indices

        out = self.project_out(codes)

        # reconstitute image or video dimensions
        if need_move_channel_last:
            out = unpack_one(out, ps, 'b * d')
            out = rearrange(out, 'b ... d -> b d ...')

            indices = maybe(unpack_one)(indices, ps, 'b * c')

        if not self.keep_num_codebooks_dim and self.return_indices:
            indices = maybe(rearrange)(indices, '... 1 -> ...')

        # return quantized output and indices
        return out, indices
|
model.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
|
3 |
+
import torch as T
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from ioblocks import GaussianMixtureIOLayer, FSQ
|
8 |
+
|
9 |
+
from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
|
10 |
+
from tokenizer import make_tokenizer
|
11 |
+
|
12 |
+
|
13 |
+
from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
|
14 |
+
from utils import load_ckpt
|
15 |
+
|
16 |
+
|
17 |
+
@si_module
class LatentQuantizer(nn.Module):
    """Frozen FSQ-based quantizer applied to model latents.

    Built either from explicit sub-configs or restored from a pretrained
    checkpoint; in the latter case the state dict is loaded after the
    sub-modules are constructed.
    """

    class Config:
        compressor_config: Optional[FSQ.Config] = None

        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        # Annotation corrected from `int` to `Optional[int]`: None means
        # "no input projection" (nn.Identity is used below).
        input_dim: Optional[int] = None

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        # Load the checkpoint first (if any); otherwise a compressor config
        # is mandatory.
        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.compressor_config), f'hmm {c}'

        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

        if exists(c.from_pretrained):
            self.load_state_dict(checkpoint)

    @T.no_grad()
    def forward(self, x, return_latent=False, known_latent=None):
        """
        x: (B, S, D)

        known_latent: if given, decode these codebook indices back to
            continuous codes and return them (x is ignored).
        return_latent: when True, also return the integer tokens produced
            by the compressor.
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)

        x = self.input(x)
        x = self.ffnn(x)
        x, tokens = self.compressor(x)

        if return_latent:
            return x, tokens
        return x
|
58 |
+
|
59 |
+
|
60 |
+
@si_module
|
61 |
+
class TransformerVAE(nn.Module):
|
62 |
+
class Config:
|
63 |
+
io_config: Optional[GaussianMixtureIOLayer.Config] = None
|
64 |
+
stack_config: Optional[Stack.Config] = None
|
65 |
+
quantizer_config: Optional[LatentQuantizer.Config] = None
|
66 |
+
|
67 |
+
plex_layer: int = None
|
68 |
+
plex_roll: int = 1
|
69 |
+
split: bool = True
|
70 |
+
|
71 |
+
from_pretrained: Optional[Tuple[str, str]] = None
|
72 |
+
|
73 |
+
def __init__(self, c: Config):
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
if exists(c.from_pretrained):
|
77 |
+
checkpoint = load_ckpt(*c.from_pretrained)
|
78 |
+
else:
|
79 |
+
assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
|
80 |
+
|
81 |
+
self.io = c.io_config()
|
82 |
+
self.stack = c.stack_config()
|
83 |
+
|
84 |
+
self.plex_layer = c.stack_config.layers//2
|
85 |
+
self.plex_roll = c.plex_roll
|
86 |
+
self.plex_dim = c.quantizer_config.dim
|
87 |
+
|
88 |
+
assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
|
89 |
+
self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
|
90 |
+
self.out_norm = Norm(c.stack_config.dim)
|
91 |
+
|
92 |
+
if c.split:
|
93 |
+
self.io2 = c.io_config()
|
94 |
+
self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
|
95 |
+
|
96 |
+
self.io2.fc_loc = None
|
97 |
+
self.io2.fc_scale = None
|
98 |
+
self.io2.fc_weight = None
|
99 |
+
|
100 |
+
kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
|
101 |
+
head_dim = c.stack_config.dim // c.stack_config.n_head
|
102 |
+
self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
|
103 |
+
cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
|
104 |
+
self.cache_shape = cache_shape
|
105 |
+
self.cache = [None] * self.cache_num_layers
|
106 |
+
|
107 |
+
if exists(c.from_pretrained):
|
108 |
+
result = self.load_state_dict(checkpoint, strict=False)
|
109 |
+
print0_colored(result, 'yellow')
|
110 |
+
|
111 |
+
self.quantizer = c.quantizer_config().eval()
|
112 |
+
self.quantizer.requires_grad = False
|
113 |
+
|
114 |
+
@T.no_grad()
|
115 |
+
def quantize(self, x):
|
116 |
+
if self.c.split:
|
117 |
+
x1, x2 = x.chunk(2, dim=-1)
|
118 |
+
with T.autocast(device_type='cuda', dtype=T.bfloat16):
|
119 |
+
quantized1 = self.quantizer(x1)
|
120 |
+
quantized2 = self.quantizer(x2)
|
121 |
+
return quantized1, quantized2
|
122 |
+
else:
|
123 |
+
with T.autocast(device_type='cuda', dtype=T.bfloat16):
|
124 |
+
return self.quantizer(x)
|
125 |
+
|
126 |
+
@T.no_grad()
|
127 |
+
def untokenize(self, token_data):
|
128 |
+
return self.quantizer(None, known_latent=token_data)
|
129 |
+
|
130 |
+
def init_cache(self, bsize, device, dtype, length:int=None):
|
131 |
+
cache_shape = self.cache_shape.copy()
|
132 |
+
cache_shape[1] = length or cache_shape[1]
|
133 |
+
self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
|
134 |
+
|
135 |
+
def deinit_cache(self):
|
136 |
+
self.cache = [None] * self.cache_num_layers
|
137 |
+
|
138 |
+
@T.no_grad()
def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
    """Resynthesizer pass: embed latents, run the stack, optionally sample.

    data: input latents; in split mode both channels concatenated on the last dim.
    next_tokens: when given, the sampled token(s) for the next step are decoded
        and overwrite the final frame of the (rolled) quantized plex stream.
    temps: when given, sample and return only the next-frame latents via
        `temp_sample`; when absent, return the raw output distribution params.
    """
    if self.c.split:
        x1, x2 = data.chunk(2, dim=-1)
        # Shared trunk input: both channels are summed into one stream below
        # the plex layer.
        x = self.io.input(x1) + self.io2.input(x2)
    else:
        x = self.io.input(data)

    cache_idx = 0
    for l, layer in enumerate(self.stack.layers):
        if l == self.plex_layer:
            # Inject the quantized ("plex") signal, rolled left by plex_roll
            # so position t sees the quantization of a future frame.
            if self.c.split:
                plex1, plex2 = self.quantize(data)
                plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    plex1[:, -1:] = self.untokenize(next_tokens[0])
                    plex2[:, -1:] = self.untokenize(next_tokens[1])
                # The stream forks into per-channel branches from here on.
                x1 = x + self.plex_projection(plex1)
                x2 = x + self.plex_projection2(plex2)
            else:
                plex = self.quantize(data)
                plex = T.roll(plex, -self.c.plex_roll, dims=1)
                if exists(next_tokens):
                    plex[:, -1:] = self.untokenize(next_tokens)
                x = x + self.plex_projection(plex)

        if l < self.plex_layer:
            x = layer(x, kv=self.cache[l])
        else:
            if self.c.split:
                # Above the plex layer each channel gets its own cache slot;
                # slots plex_layer.. are interleaved (ch1, ch2, ch1, ch2, ...).
                x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
                x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                cache_idx += 1
            else:
                x = layer(x, kv=self.cache[l])

    with T.autocast(device_type='cuda', dtype=T.bfloat16):
        if self.c.split:
            x1, x2 = self.out_norm(x1), self.out_norm(x2)
            # io2's output heads were deleted at construction (fc_* = None),
            # so io's output head serves both channels.
            out1, out2 = self.io.output(x1), self.io.output(x2)
        else:
            x = self.out_norm(x)
            out = self.io.output(x)

    if isnt(temps):
        # No sampling requested: return distribution parameters for all frames.
        if self.c.split:
            return out1, out2
        else:
            return out
    else:
        # Sample only the final frame's latents.
        if self.c.split:
            next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
            next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
            next_data = T.cat([next_data1, next_data2], dim=-1)
            return next_data
        else:
            next_data = self.io.temp_sample(out, temps)[:, -1:, :]
            return next_data
|
198 |
+
|
199 |
+
@si_module
class HertzDevModel(nn.Module):
    """Top-level hertz-dev audio language model.

    Predicts next-step audio latents: (optionally split two-channel) linear
    input projections, a shared stack of PerfBlock transformer layers with an
    optional KV cache, GPT-style output head(s), plus a frozen resynthesizer
    and audio tokenizer for converting to/from raw waveforms.
    """
    class Config:
        dim: int                 # transformer width
        vocab_size: int          # token vocabulary of the output head(s)
        stack_config: Optional[Stack.Config] = None
        latent_size: int = 32    # per-channel audio-latent dimensionality

        split: bool = True       # two-channel (duplex) mode

        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None

        from_pretrained: Optional[Tuple[str, str]] = None  # (ckpt name, expected hash)

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.stack_config)), f'hmm {c}'

        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            # Separate input/output projections per channel; the trunk is shared.
            self.input2 = nn.Linear(c.latent_size, c.dim)

        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)

        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])

        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)

        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head

        if exists(c.from_pretrained):
            # strict=False: the checkpoint does not cover the frozen helpers below.
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.resynthesizer = c.resynthesizer_config().eval()
        # FIX: original did `self.resynthesizer.requires_grad = False`, which
        # only sets a plain Python attribute on the Module and freezes no
        # parameters; requires_grad_(False) is the API that does.
        self.resynthesizer.requires_grad_(False)

        self.audio_tokenizer = make_tokenizer(device='cpu')
        self.audio_cache = None         # rolling raw-audio history (streaming mode)
        self.audio_latent_cache = None  # rolling latent history (streaming mode)
        self.use_audio_cache = False

    @T.no_grad()
    def tokenize(self, audio_data):
        """Encode raw audio (B, C, T) into latents (B, S, D[*2 for stereo]).

        In streaming mode, keeps up to 6 s of 16 kHz context and returns only
        the latents for the *new* audio. The //2000 trim assumes 2000 samples
        per latent frame — verify against the tokenizer's total stride.
        """
        orig_audio_shape = audio_data.shape
        if exists(self.audio_cache):
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -(6*16_000):]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -(6*16_000):]

        if audio_data.shape[1] == 2:
            # Stereo: encode each channel separately, concat on the latent dim.
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
        else:
            return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]

    @T.no_grad()
    def untokenize(self, token_data):
        """Decode latents (B, S, D) back to audio, with optional streaming
        latent history (6*8 = 48 frames ≈ 6 s at 8 frames/s)."""
        if exists(self.audio_latent_cache):
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -(6*8):]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -(6*8):]

        if token_data.shape[-1] == 2*self.c.latent_size:
            # FIX: the original sliced dim 1 (`token_data[:, :latent_size]`),
            # which is the *time* axis of (B, S, 2D) latents and would
            # shape-mismatch in the decoder; channels live on the last axis.
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[..., :self.c.latent_size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[..., self.c.latent_size:])
            return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
        else:
            return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]

    def init_cache(self, bsize, device, dtype, length: int = None):
        """Allocate the trunk and resynthesizer KV caches and enable the
        streaming audio/latent caches."""
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        """Release every cache and leave streaming mode."""
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def forward(self, data):
        """Latents (B, S, [2*]latent_size) -> logits; a (ch1, ch2) pair in split mode."""
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            # Both channels are summed into the shared trunk.
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)

        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])

        if self.c.split:
            return self.output(x), self.output2(x)
        else:
            return self.output(x)

    @T.no_grad()
    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Generate the next audio chunk (last 2000 samples) from input audio."""
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        # Keep only the model's (second) channel of the predicted latents.
        next_model_latent = next_latents[..., self.c.latent_size:]
        audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
        return audio_decoded

    @T.no_grad()
    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        """Sample next-step token(s) from the trunk logits, then decode them to
        latents via the resynthesizer.

        temps = (token softmax temperature, resynthesizer sampling temps).
        """
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_logits1 = logits1[:, -1]
            next_logits2 = logits2[:, -1]
            next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
            next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
        else:
            logits = self.forward(model_input)
            next_logits = logits[:, -1]
            next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])

        return next_input

    @T.no_grad()
    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """
        Autoregressively extend `data`. Only accepts latent-space data.
        By default generates as many frames as the prompt, capped at the
        trunk's trained context length.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)

        next_input = generated = data

        # Never generate past the trunk's configured sequence length.
        target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)

        for _ in tqdm0(range(data.shape[1], target_len)):
            # With a cache only the newest frame needs to be fed back in.
            model_input = next_input if use_cache else generated

            next_input = self.next_latent(model_input, temps)

            generated = T.cat([generated, next_input], dim=1)

        if use_cache:
            self.deinit_cache()
        return generated
|
372 |
+
|
373 |
+
|
374 |
+
|
375 |
+
def get_hertz_dev_config(is_split=True, use_pure_audio_ablation=False):
    """Build the HertzDevModel.Config for a pretrained checkpoint pair.

    is_split: duplex (two-channel) model vs. mono.
    use_pure_audio_ablation: for the mono model only, select the
        audio-only-ablation trunk checkpoint instead of the default one.
    Returns a config whose resynthesizer/trunk load `checkpoints[0]` /
    `checkpoints[1]` respectively (each entry is (name, expected md5 hash)).
    """
    if is_split:
        checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
    elif not use_pure_audio_ablation:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
    else:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_syrup_110000', '353c48f553f1706824c11f3bb6a049e9')]

    # Shared FSQ latent quantizer (frozen, loaded from its own checkpoint).
    quantizer_config = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=FSQ.Config(
            levels=[8, 8, 8, 8, 8],
            dim=2048,
            num_codebooks=1,
            keep_num_codebooks_dim=None,
            scale=None,
            allowed_dtypes=['float32', 'float64', 'bfloat16'],
            channel_first=False,
            projection_has_bias=True,
            return_indices=True,
            force_quantization_f32=True,
            use_rms=False
        ),
        dim=2048,
        ff_dim=8192,
        input_dim=32
    )

    # Resynthesizer: token -> latent VAE with a Gaussian-mixture IO layer.
    resynthesizer_config = TransformerVAE.Config(
        io_config=GaussianMixtureIOLayer.Config(
            latent_dim=32,
            dim=4096,
            num_components=8,
        ),
        stack_config=Stack.Config(
            layers=8,
            dim=4096,
            seq_len=8192,
            n_head=16,
            ff_dim=11008,
            kv_heads=16,
            eps=1e-5,
            theta=10_000
        ),
        quantizer_config=quantizer_config,
        plex_layer=None,  # None: the VAE picks its default plex-injection layer
        plex_roll=1,
        split=is_split,
        from_pretrained=checkpoints[0],
    )

    # Trunk: 32-layer language model over 32768 audio tokens.
    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=Stack.Config(
            layers=32,
            dim=4096,
            seq_len=2048,
            n_head=32,
            ff_dim=None,   # None: stack default feed-forward width
            kv_heads=None,  # None: no grouped-query attention (kv_heads = n_head)
            eps=1e-5,
            theta=10_000,
        ),
        quantizer_config=quantizer_config,
        resynthesizer_config=resynthesizer_config,
        split=is_split,
        from_pretrained=checkpoints[1],
    )
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.5.1
|
2 |
+
torchaudio==2.5.1
|
3 |
+
einops==0.8.0
|
4 |
+
tqdm==4.66.6
|
5 |
+
ipython==8.29.0
|
6 |
+
numpy==1.26.3
|
7 |
+
soundfile==0.12.1
|
8 |
+
websockets==13.1
|
9 |
+
requests==2.32.3
|
10 |
+
sounddevice==0.5.1
|
11 |
+
matplotlib==3.9.2
|
12 |
+
fastapi==0.115.4
|
13 |
+
uvicorn==0.32.0
|
14 |
+
gradio==5.5.0
|
tokenizer.py
ADDED
@@ -0,0 +1,581 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import Union, Tuple, Literal
|
4 |
+
|
5 |
+
import torch as T
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.nn.utils.parametrizations import weight_norm
|
8 |
+
|
9 |
+
from utils import load_ckpt
|
10 |
+
from utils.interp import print_colored
|
11 |
+
from utils import si_module, get_activation
|
12 |
+
|
13 |
+
|
14 |
+
|
15 |
+
# Adapted from https://github.com/facebookresearch/AudioDec
|
16 |
+
|
17 |
+
def Conv1d1x1(in_channels, out_channels, bias=True):
    """Pointwise (kernel-size-1) 1D convolution, i.e. a per-frame linear map."""
    return nn.Conv1d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        bias=bias,
    )
|
19 |
+
|
20 |
+
|
21 |
+
class NonCausalConv1d(nn.Module):
    """1D noncausal convolution w/ 2-sides padding."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=-1,
            dilation=1,
            groups=1,
            bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        # A negative padding requests symmetric "same"-style padding for the
        # given kernel size and dilation.
        effective_padding = (kernel_size - 1) // 2 * dilation if padding < 0 else padding
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=effective_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        return self.conv(x)
|
61 |
+
|
62 |
+
|
63 |
+
class NonCausalConvTranspose1d(nn.Module):
    """1D noncausal transpose convolution."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=-1,
            output_padding=-1,
            groups=1,
            bias=True,
    ):
        super().__init__()
        # Negative values request defaults; with the kernel_size == 2*stride
        # convention used in this file these yield T' == T * stride.
        if padding < 0:
            padding = (stride + 1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        return self.deconv(x)
|
102 |
+
|
103 |
+
|
104 |
+
class CausalConv1d(NonCausalConv1d):
    """1D causal convolution: zero left-padding only, so output frame t
    depends exclusively on input frames <= t."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            dilation=1,
            groups=1,
            bias=True
    ):
        super(CausalConv1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,  # padding is applied manually (left-only) in forward
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        # Receptive-field overhang that must be pre-filled on the left.
        self.pad_length = (kernel_size - 1) * dilation

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T//stride).
        """
        # F.pad with a (left, right) spec; zero-fill matches the original
        # nn.ConstantPad1d behavior but avoids allocating a Module per call.
        x = nn.functional.pad(x, (self.pad_length, 0), mode='constant', value=0.0)
        return self.conv(x)
|
131 |
+
|
132 |
+
|
133 |
+
class CausalConvTranspose1d(NonCausalConvTranspose1d):
    """Causal transpose convolution with a streaming pad buffer.

    `forward` is the offline path (replication-pads the left edge);
    `inference` is the streaming path that carries `pad_buffer` history
    across calls.
    """
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            bias=True,
            pad_buffer=None,
    ):
        super(CausalConvTranspose1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        # Number of past input frames needed so the deconv output is causal.
        self.pad_length = (math.ceil(kernel_size/stride) - 1)
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        # Registered buffer: moves with .to(device) and appears in state_dict.
        self.register_buffer("pad_buffer", pad_buffer)

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        # Offline: fake the missing history by replicating the first frame,
        # then trim `stride` frames off both ends of the deconv output.
        pad = nn.ReplicationPad1d((self.pad_length, 0))
        x = pad(x)
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def inference(self, x):
        # Streaming: prepend the real buffered history instead of replicating,
        # and remember the tail of this chunk for the next call (stateful).
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def reset_buffer(self):
        """Zero the streaming history (call at the start of a new stream)."""
        self.pad_buffer.zero_()
|
170 |
+
|
171 |
+
|
172 |
+
class NonCausalResUnit(nn.Module):
    """Residual unit: two ELU-activated convolutions plus a skip connection."""

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=7,
            dilation=1,
            bias=False,
    ):
        super().__init__()
        self.activation = nn.ELU()
        # Dilated main convolution followed by a pointwise mixing convolution.
        self.conv1 = NonCausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        residual = x
        h = self.conv1(self.activation(x))
        h = self.conv2(self.activation(h))
        return residual + h
|
197 |
+
|
198 |
+
|
199 |
+
class CausalResUnit(NonCausalResUnit):
    """Residual unit whose first convolution is causal (streaming-safe)."""
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=7,
            dilation=1,
            bias=False,
    ):
        super(CausalResUnit, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )
        # Replace the parent's noncausal conv1 with a causal one; conv2 is a
        # 1x1 conv (no temporal extent) and needs no change.
        self.conv1 = CausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )

    def inference(self, x):
        # NOTE(review): CausalConv1d as defined in this file exposes no
        # `inference` method, so this streaming path would raise
        # AttributeError if exercised — verify it is unused here or that the
        # buffered-conv variant was meant to be included.
        y = self.conv1.inference(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y
|
228 |
+
|
229 |
+
|
230 |
+
class ResNetBlock(nn.Module):
    """One encoder/decoder stage: dilated residual units plus a resampling unit."""
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 kernel_size=7,
                 dilations=(1, 3, 9),
                 bias=True,
                 mode='encoder',
                 ):
        super().__init__()
        assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"

        self.mode = mode
        self.stride = stride

        # Encoder downsamples with a strided causal conv; decoder upsamples
        # with a causal transpose conv.
        ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d

        # Residual units run at the wider (pre-resample / post-resample) end.
        res_channels = in_channels if mode == 'encoder' else out_channels

        res_units = [CausalResUnit(
            res_channels,
            res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        ) for dilation in dilations]

        if in_channels == out_channels:
            # Same-width stage: resample with pooling/upsampling; the conv
            # unit degenerates to an identity (or a 1x1 if widths differed).
            if mode == 'encoder':
                self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
            if mode == 'decoder':
                self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
            conv_unit = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=bias,
            ) if in_channels != out_channels else nn.Identity()
        else:
            # Width-changing stage: the strided (transpose) conv resamples.
            conv_unit = ConvUnit(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        # Encoder: residual units first, then downsample.
        # Decoder: upsample first, then residual units.
        if mode == 'encoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
            else:
                self.res_block = nn.Sequential(*res_units, conv_unit)
        elif mode == 'decoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
            else:
                self.res_block = nn.Sequential(conv_unit, *res_units)

    def forward(self, x):
        # Iterate explicitly (rather than calling the Sequential) so the
        # structure mirrors `inference` below.
        out = x
        for unit in self.res_block:
            out = unit(out)
        return out

    def inference(self, x):
        # Streaming path: requires every unit to expose `.inference`.
        # NOTE(review): nn.Identity / pooling units do not — verify this path
        # is only exercised for width-changing stages.
        for unit in self.res_block:
            x = unit.inference(x)
        return x
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
@si_module
class ResNetStack(nn.Module):
    """
    ResNet encoder or decoder stack. Channel ratios
    and strides take the default order of from
    data/io-layer, to the middle of the model.
    """
    class Config:
        input_channels: int = 1
        output_channels: int = 1
        encode_channels: int = 32           # base channel width
        decode_channel_multiplier: int = 1  # extra width for the decoder stack
        latent_dim: int = None              # enables latent_proj when set
        kernel_size: int = 7
        bias: bool = True
        channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
        strides: Tuple[int, ...] = (3, 4, 5, 5)
        mode: Literal['encoder', 'decoder'] = 'encoder'

    def __init__(self, c: Config):
        super().__init__()
        assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"

        self.mode = c.mode

        assert len(c.channel_ratios) == len(c.strides)
        # Prepend ratio 1 for the io-side stage; a decoder traverses the same
        # stages in reverse order.
        channel_ratios = (1,) + c.channel_ratios
        strides = c.strides
        self.middle_channels = c.encode_channels * channel_ratios[-1]
        if c.mode == 'decoder':
            channel_ratios = tuple(reversed(channel_ratios))
            strides = tuple(reversed(strides))

        # The decoder may run uniformly wider than the encoder.
        self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
        res_blocks = [ResNetBlock(
            c.encode_channels * channel_ratios[s_idx] * self.multiplier,
            c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
            stride,
            kernel_size=c.kernel_size,
            bias=c.bias,
            mode=c.mode,
        ) for s_idx, stride in enumerate(strides)]

        # Boundary conv between raw data and the stack.
        data_conv = CausalConv1d(
            in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
            out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
            kernel_size=c.kernel_size,
            stride=1,
            bias=False,
        )

        if c.mode == 'encoder':
            self.res_stack = nn.Sequential(data_conv, *res_blocks)
        elif c.mode == 'decoder':
            self.res_stack = nn.Sequential(*res_blocks, data_conv)

        if c.latent_dim is not None:
            # 1x1 projection between middle channels and the latent space
            # (direction depends on mode).
            self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
            if self.multiplier != 1:
                self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)

    def forward(self, x, return_feats=False):
        """Run the stack over (B, C, T); optionally return the list of
        intermediate feature maps instead of only the output."""
        if self.c.latent_dim is not None and self.mode == 'decoder':
            # Decoder enters from latent space: project up before the stack.
            x = self.latent_proj(x)
            if self.multiplier != 1:
                x = self.multiplier_proj(x)

        feats = []
        for block in self.res_stack:
            x = block(x)
            if return_feats:
                feats.append(x)
        if self.c.latent_dim is not None and self.mode == 'encoder':
            # Encoder exits into latent space: project down after the stack.
            x = self.latent_proj(x)
            if return_feats:
                feats.append(x)
        if return_feats:
            return feats
        return x

    def inference(self, x):
        # Streaming path; see the NOTE on ResNetBlock.inference.
        for block in self.res_stack:
            x = block.inference(x)
        return x

    def reset_buffer(self):
        """Zero every streaming pad buffer in the stack.

        NOTE(review): CausalConv1d as defined in this file keeps no buffer and
        has no `reset_buffer`; this would raise if one is encountered — verify
        against the buffered-conv variant.
        """
        def _reset_buffer(m):
            if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
                m.reset_buffer()
        self.apply(_reset_buffer)

    def reset_parameters(self):
        """Re-initialize all conv weights with N(0, 0.01)."""
        def _reset_parameters(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.01)

        self.apply(_reset_parameters)


    def apply_weight_norm(self):
        """Wrap every conv/transpose-conv with weight normalization."""
        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(
                m, nn.ConvTranspose1d
            ):
                nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)


    def remove_weight_norm(self):
        """Strip weight normalization wherever it was applied."""
        def _remove_weight_norm(m):
            try:
                print(m)
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)
|
420 |
+
|
421 |
+
|
422 |
+
|
423 |
+
@si_module
class GaussianZ(nn.Module):
    """Gaussian (VAE-style) bottleneck: project to (mu, logvar),
    reparameterize, and project the sampled z back up."""
    class Config:
        dim: int          # representation width on either side of the bottleneck
        latent_dim: int   # bottleneck width
        bias: bool = False
        use_weight_norm: bool = False

    def __init__(self, c: Config):
        super().__init__()

        # proj_in emits mu and logvar concatenated (hence latent_dim * 2).
        self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
        self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)

        if c.use_weight_norm:
            self.proj_in = weight_norm(self.proj_in)
            self.proj_out = weight_norm(self.proj_out)

    def reparam(self, mu, logvar):
        """Reparameterization trick: z = mu + eps * sigma, eps ~ N(0, I)."""
        std = T.exp(logvar / 2)
        eps = T.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu, logvar):
        """KL(q(z|x) || N(0, I)), summed over dims (1, 2) and averaged over batch."""
        return T.mean(-0.5 * T.sum(
            1 + logvar - mu.pow(2) - logvar.exp(),
            dim=(1, 2))
        )

    def repr_from_latent(self, latent: Union[dict, T.Tensor]):
        """Map a latent back to the representation.

        Accepts either a sampled z tensor, or a dict with 'mu'/'logvar'
        from which a fresh z is drawn.
        """
        if isinstance(latent, T.Tensor):
            z = latent
        else:
            z = self.reparam(latent['mu'], latent['logvar'])
        l = self.proj_out(z)
        return l

    def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
        """Returns (reconstruction, latent dict with mu/logvar/z/kl_divergence)."""
        mu, logvar = self.proj_in(x).chunk(2, dim=-1)
        kl_div = self.kl_divergence(mu, logvar)
        z = self.reparam(mu, logvar)
        xhat = self.proj_out(z)
        latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
        return xhat, latent
|
467 |
+
|
468 |
+
|
469 |
+
|
470 |
+
@si_module
class WaveCodec(nn.Module):
    """Waveform autoencoder: ResNet encoder -> compressor bottleneck -> ResNet decoder."""
    class Config:
        resnet_config: ResNetStack.Config = None
        sample_rate: int = 16_000
        use_weight_norm: bool = False

        compressor_config: dataclass = None  # e.g. GaussianZ.Config

        norm_stddev: float = 1.0  # input scale divided out before encoding, restored after decoding

    def __init__(self, c: Config):
        super().__init__()
        self.norm_stddev = c.norm_stddev
        self.encoder = c.resnet_config(mode='encoder')
        self.sample_rate = c.sample_rate

        # Product of all stage strides = audio samples per latent frame.
        self.total_stride = 1
        for stride in c.resnet_config.strides:
            self.total_stride *= stride
        self.tokens_per_second = self.sample_rate / self.total_stride

        # Bottleneck operates on the encoder's middle width.
        self.compressor = c.compressor_config(dim=self.encoder.middle_channels)

        self.decoder = c.resnet_config(mode='decoder')

        if c.use_weight_norm:
            self.encoder.apply_weight_norm()
            self.decoder.apply_weight_norm()
            self.encoder.reset_parameters()
            self.decoder.reset_parameters()

    def encode(self, data):
        # Normalize input amplitude before encoding.
        return self.encoder(data/self.norm_stddev)

    def decode(self, latent):
        # latent is (B, S, C); the decoder expects channel-first (B, C, S).
        return self.decoder(latent.transpose(1, 2))*self.norm_stddev

    @T.no_grad()
    def latent_from_data(self, data, get_parameters=False):
        """Audio -> sampled latent z (or the full mu/logvar/z dict)."""
        x = self.encode(data)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        return latent['z'] if not get_parameters else {
            'mu': latent['mu'],
            'logvar': latent['logvar'],
            'z': latent['z'],
        }

    @T.no_grad()
    def data_from_latent(self, latent):
        """Latent -> audio (inverse of latent_from_data, up to sampling)."""
        l = self.compressor.repr_from_latent(latent)
        x = self.decode(l)
        return x

    def process(self, x):
        # Generic pipeline alias for encoding.
        return self.latent_from_data(x)

    def unprocess(self, latent):
        # Generic pipeline alias for decoding.
        return self.data_from_latent(latent)

    def forward(self, audio_input):
        """Full autoencoding pass; returns (reconstruction, latent dict)."""
        x = self.encode(audio_input)

        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)

        xhat = self.decode(l)
        return xhat, latent
|
539 |
+
|
540 |
+
|
541 |
+
|
542 |
+
def make_tokenizer(device='cuda'):
    """Build the WaveCodec audio tokenizer and load its pretrained weights.

    device: target device string; bfloat16 conversion only happens on 'cuda'.
    Returns the tokenizer in eval mode with gradients disabled.
    """
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,
        ),
        use_weight_norm=True,

        compressor_config=GaussianZ.Config(
            dim=None,
            latent_dim=32,

            bias=True,
            use_weight_norm=True
        ),

        norm_stddev=0.05,
    )
    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")

    tokenizer = generator_config()

    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")

    tokenizer = tokenizer.eval()
    # Only convert to bfloat16 if using CUDA
    if device == 'cuda':
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)
    # BUG FIX: `tokenizer.requires_grad_ = False` only shadowed the bound method
    # with an attribute and left every parameter with requires_grad=True.
    # requires_grad_ is a method and must be called.
    tokenizer.requires_grad_(False)
    return tokenizer
+
|
transformer.py
ADDED
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Tuple, MutableMapping
|
2 |
+
from typing import Union
|
3 |
+
import math
|
4 |
+
from contextlib import nullcontext
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch as T
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import Tensor
|
11 |
+
from torch.nn.attention import SDPBackend
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
|
15 |
+
from utils import si_module, default, exists, load_ckpt
|
16 |
+
|
17 |
+
# Sentinel written into unused KV-cache slots so fill length can be recovered.
CACHE_FILL_VALUE = -1

def get_cache_len(cache: Optional[Tensor]) -> int:
    """Return the number of filled positions in a KV cache.

    cache: (batch, seq_len, 2, kv_heads, head_dim); unfilled positions hold
    CACHE_FILL_VALUE. Returns 0 when there is no cache.
    """
    if cache is None:
        return 0
    # A position counts as filled if any of its elements differs from the sentinel.
    nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
    length = nonzeros.sum(dim=-1).int()
    # All batch elements must be filled to the same length.
    assert T.all(length == length[0])
    # Return a plain int to match the annotation (previously returned a 0-d tensor,
    # which worked by accident in slicing but not as a true int).
    return int(length[0])
+
def rotate_half(x):
    """Rotate the last dimension: (x1, x2) -> (-x2, x1), as used by RoPE."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
|
36 |
+
def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
    """Apply rotary position embedding to x (B, S, H, Dh).

    cos/sin: precomputed tables of shape (1, max_seq, 1, Dh); `offset` selects
    the starting absolute position (nonzero when continuing from a KV cache).
    """
    assert (
        cos.shape[1] >= offset + x.shape[1]
    ), f"Offset and/or input sequence is too large,\
    \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"

    # Slice the tables down to this sequence's absolute positions.
    cos_out = cos[:, offset : offset + x.shape[1], :, :]
    sin_out = sin[:, offset : offset + x.shape[1], :, :]

    return (x * cos_out) + (rotate_half(x) * sin_out)
48 |
+
# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
class ShapeRotator:
    """Rotary position embedding helper with per-device cos/sin caches.

    Plain object (not an nn.Module) shared by all attention layers of a stack.
    """
    def __init__(
        self,
        dim: int,
        end: int,
        theta: float = 10_000,
    ):
        super().__init__()
        self.dim = dim
        self.ratio = theta
        # cached_freqs[device_index][alpha] -> stacked (cos, sin) table.
        self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
        self.max_seq_len_cached: MutableMapping[int, int] = {}
        self.ntk_scaling = False
        self.max_seq_len = end

    def compute_freqs_cis(self, device, max_seq_len=None):
        """Populate the cos/sin cache for `device` once; returns the cache key (always 1)."""
        alpha = 1
        dev_idx = device.index
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if dev_idx not in self.cached_freqs:
            self.cached_freqs[dev_idx] = {}
        if dev_idx not in self.max_seq_len_cached:
            self.max_seq_len_cached[dev_idx] = 0

        # Already populated for this device: nothing to recompute.
        if self.max_seq_len_cached[dev_idx] > 0:
            return 1
        max_seq_len = max(max_seq_len, self.max_seq_len)

        if (
            1 in self.cached_freqs[dev_idx]
            and max_seq_len <= self.max_seq_len_cached[dev_idx]
        ):
            return 1

        ratio = self.ratio
        dim = self.dim

        # Standard RoPE inverse frequencies: theta^(-2i/dim).
        freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))

        # Outer product of positions and frequencies -> per-position angles.
        t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
        freqs = torch.einsum("i,j->ij", t, freqs)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)

        # Shape (1, max_seq_len, 1, dim) broadcasts over batch and heads.
        cos_to_cache = emb.cos()[None, :, None, :]
        sin_to_cache = emb.sin()[None, :, None, :]

        self.max_seq_len_cached[dev_idx] = max_seq_len

        # Stack cos/sin along a trailing axis so rotate() can index [..., 0] / [..., 1].
        self.cached_freqs[dev_idx][alpha] = torch.stack(
            [
                cos_to_cache,
                sin_to_cache,
            ],
            dim=-1,
        )

        return alpha

    def rotate(
        self,
        q: Tensor,
        k: Tensor,
        offset: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args
        ----
        q : torch.Tensor
            Embedded query tensor, expected size is B x S x H x Eh
        k : torch.Tensor
            Embedded key tensor, expected size is B x S x H x Eh
        offset : int
            Absolute starting position (length of any cached prefix).
        """
        assert len(q.size()) == 4
        assert len(k.size()) == 4

        seq_len = self.max_seq_len
        alpha = self.compute_freqs_cis(q.device, seq_len)
        freqs = self.cached_freqs[q.device.index][alpha]

        freqs = freqs.float() # 1 L D/2 2 2
        # Rotate in float32 for precision, then cast back to the input dtype.
        q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
        k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)

        return q_out.view_as(q), k_out.view_as(k)
136 |
+
class Linear(nn.Linear):
    """nn.Linear with bias permanently disabled (project-wide transformer convention)."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, bias=False)
+
class Norm(nn.Module):
    """Layer normalization with a learnable scale and no bias."""

    def __init__(self,
            dim: int,
            eps: float = 1e-5,) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(T.ones((dim,)))

    def forward(self, input: Tensor) -> Tensor:
        # Normalize over the last `dim` features using the learned scale only.
        normalized_shape = (self.weight.shape[0],)
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=None, eps=self.eps)
|
152 |
+
class FFNN(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate) * up)."""
    def __init__(self,
            dim: int,
            expand_dim: int = None,):
        super().__init__()
        # Default hidden width ~ (8/3)*dim, rounded up to a multiple of 256.
        expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
        self.dim = dim
        self.expand_dim = expand_dim

        # Fused projection producing the gate and up halves in one matmul.
        self.gateup_proj = Linear(dim, 2*expand_dim)
        self.down_proj = Linear(expand_dim, dim)

    def forward(self, x):
        gate, up = self.gateup_proj(x).chunk(2, dim=-1)
        return self.down_proj(up * F.silu(gate))
168 |
+
class GQA(nn.Module):
    """Grouped-query attention with QK-norm, rotary embeddings and optional KV cache."""
    def __init__(self,
            dim: int,
            n_head: int,
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.n_heads = n_head
        # kv_heads < n_head gives grouped-query attention; default is full MHA.
        self.kv_heads = default(kv_heads, n_head)
        self.head_dim = dim // n_head
        self.causal = causal

        # Fused projection: q uses n_head heads, k and v use kv_heads each.
        self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))

        # QK-norm applied to the flattened q/k projections before the head split.
        self.norm_q = Norm(self.head_dim*n_head, eps=eps)
        self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)

        self.attn_out = Linear(dim, dim)

        self.shape_rotator = shape_rotator

    def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Expand kv heads to match query heads, then run fused SDPA.
        k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
        with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
            x = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                # During cached decoding (q shorter than k) the causal mask must be off.
                is_causal=False if (q.size(1) != k.size(1)) else self.causal,
            )
        x = x.transpose(1, 2).contiguous()
        return x

    def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
        cache_len = get_cache_len(kv_cache)
        # Rotate at the absolute position, offset by what is already cached.
        q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
        if exists(kv_cache):
            # Prepend cached k/v, then write the extended run back in place.
            k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
            v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
            kv_cache[:, :k.size(1), 0] = k
            kv_cache[:, :v.size(1), 1] = v
        x = self._sdpa(q, k, v)
        return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))

    def _project(self, x):
        # BUG FIX: the fused projection has width head_dim*(n_head + 2*kv_heads),
        # so chunk(3) (three EQUAL parts) was only correct when kv_heads == n_head.
        # Split by the actual widths instead; identical behavior in the equal case.
        q_width = self.head_dim * self.n_heads
        kv_width = self.head_dim * self.kv_heads
        full_q, full_k, full_v = self.proj_qkv(x).split([q_width, kv_width, kv_width], dim=-1)
        normed_full_q = self.norm_q(full_q).to(full_q.dtype)
        normed_full_k = self.norm_k(full_k).to(full_k.dtype)

        q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
        k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
        v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
        return q, k, v

    def forward(self,
            x: Tensor,
            kv: Optional[Tensor] = None,):
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        q, k, v = self._project(x)
        return self._attend(q, k, v, kv_cache=kv)
|
236 |
+
class PreNormAttn(nn.Module):
    """Residual pre-norm wrapper around GQA: x + attn(norm(x))."""
    def __init__(self,
            dim: int,
            n_head: int,
            shape_rotator: ShapeRotator,
            kv_heads: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,):
        super().__init__()
        self.attn_norm = Norm(dim, eps=eps)
        self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        return x + self.attn(self.attn_norm(x), kv)
+
class PreNormFFNN(nn.Module):
    """Residual pre-norm wrapper around FFNN: x + ffnn(norm(x))."""
    def __init__(self,
            dim: int,
            ff_dim: int,
            eps: float = 1e-5,):
        super().__init__()
        self.ffnn_norm = Norm(dim, eps=eps)
        self.ffnn = FFNN(dim, ff_dim)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.ffnn(self.ffnn_norm(x))
+
class Block(nn.Module):
    """One pre-norm transformer layer: attention residual then SwiGLU-FFNN residual."""
    def __init__(self,
            dim: int,
            layer_id: int = 0,
            n_head: int = 16,
            kv_heads: Optional[int] = None,
            ff_dim: Optional[int] = None,
            eps: float = 1e-5,
            causal: bool = True,
            shape_rotator: ShapeRotator = None):
        super().__init__()
        self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
        self.dim = dim
        self.layer_id = layer_id
        self.head_dim = dim // n_head
        self.expand_dim = self.ffnn.ffnn.expand_dim

        self.reset_parameters()

    def reset_parameters(self):
        # Truncated-normal init: std = 1/sqrt(fan_in), clipped to +-3 std.
        std = 1.0 / math.sqrt(self.dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
        nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)

        # down_proj fans in from the expanded FFNN width, hence the separate std.
        xstd = 1.0 / math.sqrt(self.expand_dim)
        nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        x: (B, S, D)
        kv: (B, S, H, D)
        """
        h = self.attn(x, kv)
        out = self.ffnn(h)
        return out
+
|
307 |
+
class GPTOutput(nn.Module):
    """Final norm followed by the vocabulary projection head."""
    def __init__(self, dim, vocab_size):
        super().__init__()
        self.dim = dim
        self.norm = Norm(dim)
        self.output = Linear(dim, vocab_size)

        self.reset_parameters()

    def reset_parameters(self):
        # std = 1/dim (sqrt(dim**2) == dim): deliberately small output-head init.
        std = 1.0 / math.sqrt(self.dim**2)
        nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x):
        return self.output(self.norm(x))
323 |
+
@si_module
class Stack(nn.Module):
    """Decoder-only transformer stack with an optional per-layer KV cache."""
    class Config:
        layers: int
        dim: int
        seq_len: int
        n_head: int = 32
        ff_dim: int = None
        kv_heads: int = None
        eps: float = 1e-5
        theta: Union[int, float] = 10_000
        causal: bool = True

        # Checkpoint identifier passed to load_ckpt; loaded at the end of __init__.
        from_pretrained: Optional[Tuple[str, int]] = None

    def __init__(self, c: Config):
        super().__init__()

        from_pretrained = c.from_pretrained
        if exists(from_pretrained):
            # Download/load early so a fetch failure aborts before building layers.
            checkpoint = load_ckpt(c.from_pretrained)

        # One rotary-embedding helper shared by every layer.
        self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)

        self.layers = nn.ModuleList([
            Block(
                dim=c.dim,
                layer_id=l,
                n_head=c.n_head,
                kv_heads=c.kv_heads,
                ff_dim=c.ff_dim,
                eps=c.eps,
                causal=c.causal,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.layers)
        ])

        kv_heads = c.kv_heads or c.n_head
        head_dim = c.dim // c.n_head
        # Per-layer cache layout: (seq_len, 2 [k,v], kv_heads, head_dim).
        cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * c.layers

        if exists(from_pretrained):
            self.load_state_dict(checkpoint)

    def init_cache(self, bsize, device, dtype, length:int=None):
        """Allocate the KV cache, pre-filled with CACHE_FILL_VALUE sentinels."""
        if self.cache_shape is None:
            return
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        # transpose(0, 1) puts the layer axis first so self.cache[l] views one layer.
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        """Release the cache (back to per-layer None placeholders)."""
        self.cache = [None] * len(self.cache)

    def forward(self, x: Tensor) -> Tensor:
        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])
        return x
|
utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .blocks import *
|
2 |
+
from .dist import *
|
3 |
+
from .interp import *
|
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (211 Bytes). View file
|
|
utils/__pycache__/blocks.cpython-310.pyc
ADDED
Binary file (3.73 kB). View file
|
|
utils/__pycache__/dist.cpython-310.pyc
ADDED
Binary file (3.65 kB). View file
|
|
utils/__pycache__/interp.cpython-310.pyc
ADDED
Binary file (3.82 kB). View file
|
|
utils/blocks.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import TypeVar, Generic, Type, Optional
|
3 |
+
from functools import wraps
|
4 |
+
import time
|
5 |
+
import random
|
6 |
+
|
7 |
+
import torch as T
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
# @TODO: remove si_module from codebase
# we use this in our research codebase to make modules from callable configs
si_module_TpV = TypeVar('si_module_TpV')
def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
    """Class decorator: turn cls.Config into a callable dataclass that builds cls.

    After decoration, `cls.Config(...)` is a dataclass instance; CALLING that
    instance constructs the module (keyword args override config fields).
    Also attaches `.device`/`.dtype` properties backed by an empty buffer and
    stashes the config on the instance as `self.c`.
    """
    if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
        class Config:
            pass
        cls.Config = Config

    cls.Config = dataclass(cls.Config)

    class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
        def __call__(self, *args, **kwargs) -> si_module_TpV:
            if len(kwargs) > 0:
                # Rebuild a config with keyword overrides applied, then construct.
                config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
                config_dict.update(kwargs)
                new_config = type(self)(**config_dict)
                return cls(new_config)
            else:
                return cls(self, *args)

    # Make the wrapper introspect like the original nested Config class.
    ConfigWrapper.__module__ = cls.__module__
    ConfigWrapper.__name__ = f"{cls.__name__}Config"
    ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"

    cls.Config = ConfigWrapper

    original_init = cls.__init__
    def new_init(self, *args, **kwargs):
        # Stash the config instance (found positionally or by keyword) as self.c.
        self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
        original_init(self, *args, **kwargs)
        # Empty non-persistent buffer whose placement tracks the module's device/dtype.
        self.register_buffer('_device_tracker', T.Tensor(), persistent=False)

    cls.__init__ = new_init

    @property
    def device(self):
        return self._device_tracker.device

    @property
    def dtype(self):
        return self._device_tracker.dtype

    cls.device = device
    cls.dtype = dtype

    return cls
|
58 |
+
|
59 |
+
def get_activation(nonlinear_activation, nonlinear_activation_params={}):
    """Instantiate a torch.nn activation class by name with the given kwargs."""
    activation_cls = getattr(nn, nonlinear_activation, None)
    if activation_cls is None:
        raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
    return activation_cls(**nonlinear_activation_params)
66 |
+
def exists(v):
    """Return True when v is not None."""
    return not (v is None)
|
69 |
+
def isnt(v):
    """Return True when v is None (complement of exists)."""
    return v is None
|
72 |
+
def truthyexists(v):
    """Return True when v is neither None nor the literal False."""
    return v is not None and v is not False
|
75 |
+
def truthyattr(obj, attr):
    """Return True when obj has `attr` and its value is neither None nor False."""
    if not hasattr(obj, attr):
        return False
    value = getattr(obj, attr)
    return value is not None and value is not False
78 |
+
defaultT = TypeVar('defaultT')

def default(*args: Optional[defaultT]) -> Optional[defaultT]:
    """Return the first argument that is not None, or None if all are."""
    return next((candidate for candidate in args if candidate is not None), None)
86 |
+
def maybe(fn):
    """Decorator: call fn only when its first argument is not None, else pass it through."""
    @wraps(fn)
    def inner(x, *args, **kwargs):
        return fn(x, *args, **kwargs) if x is not None else x
    return inner
|
utils/dist.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch as T
|
3 |
+
import re
|
4 |
+
from tqdm import tqdm
|
5 |
+
from datetime import timedelta
|
6 |
+
|
7 |
+
import requests
|
8 |
+
import hashlib
|
9 |
+
|
10 |
+
from io import BytesIO
|
11 |
+
|
12 |
+
def rank0():
    """Return True on the global rank-0 process (or when RANK is unset)."""
    return os.environ.get('RANK') in (None, '0')
19 |
+
def local0():
    """Return True on the node-local rank-0 process (or when LOCAL_RANK is unset)."""
    return os.environ.get('LOCAL_RANK') in (None, '0')
+
class tqdm0(tqdm):
    """tqdm progress bar that renders only on rank 0 and refreshes ~20 times total."""
    def __init__(self, *args, **kwargs):
        total = kwargs.get('total', None)
        # If no explicit total, try to take len() of the wrapped iterable.
        if total is None and len(args) > 0:
            try:
                total = len(args[0])
            except TypeError:
                pass
        if total is not None:
            # Throttle refreshes to ~5% increments to keep logs small.
            kwargs['miniters'] = max(1, total // 20)
        super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
37 |
+
def print0(*args, **kwargs):
    """print() only on the rank-0 process; no-op elsewhere."""
    if rank0():
        print(*args, **kwargs)
41 |
+
# Remembers which messages have already been emitted in this process.
_PRINTED_IDS = set()

def printonce(*args, id=None, **kwargs):
    """Print a message at most once per process, keyed by `id` (defaults to the text)."""
    key = ' '.join(map(str, args)) if id is None else id
    if key in _PRINTED_IDS:
        return
    print(*args, **kwargs)
    _PRINTED_IDS.add(key)
51 |
+
def print0once(*args, **kwargs):
    """printonce(), but only on the rank-0 process."""
    if rank0():
        printonce(*args, **kwargs)
55 |
+
def init_dist():
    """Initialize torch.distributed (NCCL) from RANK/LOCAL_RANK/WORLD_SIZE env vars.

    Falls back to a single-process configuration when the env vars are absent
    or initialization fails. Returns (rank, local_rank, world_size).
    """
    if T.distributed.is_initialized():
        print0('Distributed already initialized')
        rank = T.distributed.get_rank()
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = T.distributed.get_world_size()
    else:
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            device = f'cuda:{local_rank}'
            # Pin this process to its local GPU before creating the process group.
            T.cuda.set_device(device)
            T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
            print(f'Rank {rank} of {world_size}.')
        except Exception as e:
            # Missing env vars (or no CUDA) -> treat as a single-process run.
            print0once(f'Not initializing distributed env: {e}')
            rank = 0
            local_rank = 0
            world_size = 1
    return rank, local_rank, world_size
77 |
+
def load_ckpt(load_from_location, expected_hash=None):
    """Download (on node-local rank 0 only) and load a hertz-dev checkpoint to CPU.

    load_from_location: checkpoint name (without .pt) under https://ckpt.si.inc/hertz-dev/.
    expected_hash: optional md5 hex digest; on mismatch the file is deleted and
        re-downloaded (recursively).
    Returns whatever object is stored in the checkpoint (weights_only=False —
    only use with this trusted source).
    """
    if local0():
        os.makedirs('ckpt', exist_ok=True)
        url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt"
        save_path = f"ckpt/{load_from_location}.pt"
        if not os.path.exists(save_path):
            response = requests.get(url, stream=True)
            # BUG FIX: without this, an HTTP error body (404/500 page) was
            # silently written to disk and later handed to torch.load.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    pbar.update(len(chunk))
        if expected_hash is not None:
            with open(save_path, 'rb') as f:
                file_hash = hashlib.md5(f.read()).hexdigest()
            if file_hash != expected_hash:
                print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
                os.remove(save_path)
                return load_ckpt(load_from_location, expected_hash)
    if T.distributed.is_initialized():
        T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading
    loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu')
    return loaded
utils/interp.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch as T
|
2 |
+
import os
|
3 |
+
|
4 |
+
def rank0():
    """Return True on the global rank-0 process (or when RANK is unset)."""
    current = os.environ.get('RANK')
    return current is None or current == '0'
11 |
+
def print_colored(message, color='reset', bold=False, **kwargs):
    """Print `message` wrapped in ANSI color codes; optionally bold.

    Unknown colors fall back to 'reset'. Extra kwargs pass through to print().
    """
    color_dict = {
        'bold': '\033[1m',
        'green': '\033[92m',
        'yellow': '\033[93m',
        'red': '\033[91m',
        'blue': '\033[94m',
        'grey': '\033[90m',
        'white': '\033[97m',
        'reset': '\033[0m'
    }

    ansi_code = color_dict.get(color.lower(), color_dict['reset'])
    leader = color_dict['bold'] if bold else ''
    text = f"{leader}{ansi_code}{message}{color_dict['reset']}"
    print(text, **kwargs)
27 |
+
def print0_colored(*args, **kwargs):
    """print_colored() only on the rank-0 process."""
    if rank0():
        print_colored(*args, **kwargs)
31 |
+
def param_count(module):
    """Return a report of trainable parameter counts for `module` and its direct children."""
    def _trainable(m):
        # Count only parameters that participate in training.
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    lines = [f'Total model parameters: {_trainable(module):,}', '---------------------------']
    lines.extend(
        f'{name} parameters: {_trainable(child):,}'
        for name, child in module.named_children()
    )
    return '\n'.join(lines)
44 |
+
def model_size_estimation(module):
    """Return a report of parameter+buffer memory (MB) for `module` and its direct children."""
    def _nbytes(m):
        # Bytes occupied by parameters plus registered buffers.
        param_bytes = sum(p.nelement() * p.element_size() for p in m.parameters())
        buffer_bytes = sum(b.nelement() * b.element_size() for b in m.buffers())
        return param_bytes + buffer_bytes

    lines = [f'Total model size: {_nbytes(module) / 1024**2:.2f} MB', '---------------------------']
    for name, child in module.named_children():
        lines.append(f'{name} size: {_nbytes(child) / 1024**2:.2f} MB')
    return '\n'.join(lines)
59 |
+
def layer_param_distribution(module):
    """Return a breakdown of trainable parameters by layer class, largest first."""
    def _trainable(m):
        return sum(p.numel() for p in m.parameters() if p.requires_grad)

    totals_by_type = {}
    for _, submodule in module.named_modules():
        # recurse=False attributes each parameter to exactly one owning module.
        own = sum(p.numel() for p in submodule.parameters(recurse=False) if p.requires_grad)
        if own > 0:
            kind = submodule.__class__.__name__
            totals_by_type[kind] = totals_by_type.get(kind, 0) + own

    total_params = _trainable(module)
    lines = [f'Total trainable parameters: {total_params:,}', '---------------------------']
    for kind, count in sorted(totals_by_type.items(), key=lambda item: item[1], reverse=True):
        percentage = (count / total_params) * 100
        lines.append(f'{kind}: {count:,} ({percentage:.2f}%)')
    return '\n'.join(lines)
|