anantgupta129 commited on
Commit
c5a6a24
1 Parent(s): 7d8f12e

init spaces

Browse files
__pycache__/run.cpython-310.pyc ADDED
Binary file (6.19 kB). View file
 
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import run
4
+
5
+ title = "Lit GPT: Pythia 160M "
6
+
7
+ with gr.Blocks(title=title) as interface:
8
+ with gr.Row():
9
+ prompt = gr.Textbox(label="Input Text")
10
+
11
+ temperature = gr.Slider(
12
+ 0,
13
+ 1,
14
+ value=0.8,
15
+ label="Temperature",
16
+ info="Set the creativity level: Higher values produce more varied results, lower values generate more predictable text.",
17
+ )
18
+ top_k = gr.Slider(
19
+ 200,
20
+ 300,
21
+ value=200,
22
+ label="Top K",
23
+ info="Control the randomness: Limits the AI to consider only the top K most likely next words.",
24
+ )
25
+ max_new_tokens = gr.Slider(
26
+ 10,
27
+ 500,
28
+ value=500,
29
+ label="Max Tokens",
30
+ info="top most preferable tokens to consider in the sampling process",
31
+ )
32
+
33
+ inputs = [prompt, max_new_tokens, top_k, temperature]
34
+
35
+ with gr.Column():
36
+ outputs = gr.Textbox(label="Generated")
37
+ button = gr.Button("Generate")
38
+ button.click(run.generate_from_prompt, inputs=inputs, outputs=outputs)
39
+
40
+ # with gr.Row():
41
+ # gr.Examples(examples=examples, inputs=inputs, outputs=outputs, fn=generate_dialogue, cache_examples=True,)
42
+
43
+
44
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ lightning @ git+https://github.com/Lightning-AI/lightning@6cbe9ceb560d798892bdae9186291acf9bf5d2e3
3
+ tokenizers
4
+ sentencepiece
5
+ bitsandbytes==0.41.0
run.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+ from typing import Any, Literal, Optional
5
+
6
+ import lightning as L
7
+ import torch
8
+ from lightning.fabric.plugins import BitsandbytesPrecision
9
+ from lightning.fabric.strategies import FSDPStrategy
10
+
11
+ from tsai_gpt.model import GPT, Block, Config
12
+ from tsai_gpt.tokenizer import Tokenizer
13
+ from tsai_gpt.utils import (get_default_supported_precision, gptq_quantization,
14
+ load_checkpoint)
15
+
16
+ L.seed_everything(1234)
17
+
18
+
19
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
20
+ if torch._dynamo.is_compiling():
21
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
22
+ distribution = torch.empty_like(probs).exponential_(1)
23
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
24
+ return torch.multinomial(probs, num_samples=1)
25
+
26
+
27
+ def sample(
28
+ logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None
29
+ ) -> torch.Tensor:
30
+ logits = logits[0, -1]
31
+ # optionally crop the logits to only the top k options
32
+ if top_k is not None:
33
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
34
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
35
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
36
+ # optionally scale the logits and sample from a probability distribution
37
+ if temperature > 0.0:
38
+ probs = torch.nn.functional.softmax(logits / temperature, dim=-1)
39
+ return multinomial_num_samples_1(probs)
40
+ return torch.argmax(logits, dim=-1, keepdim=True)
41
+
42
+
43
+ def next_token(
44
+ model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any
45
+ ) -> torch.Tensor:
46
+ logits = model(x, input_pos)
47
+ next = sample(logits, **kwargs)
48
+ return next.type_as(x)
49
+
50
+
51
+ @torch.inference_mode()
52
+ def generate(
53
+ model: GPT,
54
+ prompt: torch.Tensor,
55
+ max_returned_tokens: int,
56
+ *,
57
+ temperature: float = 1.0,
58
+ top_k: Optional[int] = None,
59
+ eos_id: Optional[int] = None,
60
+ ) -> torch.Tensor:
61
+ """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
62
+
63
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
64
+
65
+ Args:
66
+ model: The model to use.
67
+ prompt: Tensor of shape (T) with indices of the prompt sequence.
68
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
69
+ temperature: Scales the predicted logits by 1 / temperature.
70
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
71
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered.
72
+ """
73
+ T = prompt.size(0)
74
+ assert max_returned_tokens > T
75
+ if model.max_seq_length < max_returned_tokens - 1:
76
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
77
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
78
+ # not support it to avoid negatively impacting the overall speed
79
+ raise NotImplementedError(
80
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
81
+ )
82
+
83
+ device = prompt.device
84
+ tokens = [prompt]
85
+ input_pos = torch.tensor([T], device=device)
86
+ token = next_token(
87
+ model,
88
+ torch.arange(0, T, device=device),
89
+ prompt.view(1, -1),
90
+ temperature=temperature,
91
+ top_k=top_k,
92
+ ).clone()
93
+ tokens.append(token)
94
+ for _ in range(2, max_returned_tokens - T + 1):
95
+ token = next_token(
96
+ model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k
97
+ ).clone()
98
+ tokens.append(token)
99
+ if token == eos_id:
100
+ break
101
+ input_pos = input_pos.add_(1)
102
+ return torch.cat(tokens)
103
+
104
+
105
+ """
106
+ quantize (Optional[Literal[&quot;bnb.nf4&quot;, &quot;bnb.nf4, optional): quantization method to use. Defaults to None.
107
+ - "bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq": 4-bit quantization bitsandbytes
108
+ - "bnb.int8": 8-bit quantization bitsandbytes
109
+ - "gptq.int4": 4-bit quantization GPTQ
110
+ for more details see: https://github.com/facebookresearch/bitsandbytes, https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
111
+ strategy (str, optional): Fabric strategy setting. Defaults to "auto".
112
+ devices (int, optional): number of devices to be used. Defaults to 1.
113
+ precision (Optional[str], optional): fabic precision settings. Defaults to None.
114
+ """
115
+
116
+ chptk_path: str = "saved_model/last-iter-015000-ckpt.pth"
117
+ tokenizer_path: str = "tokenizer_Llama-2-7b-chat-hf"
118
+ quantize: Optional[
119
+ Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8", "gptq.int4"]
120
+ ] = None
121
+ strategy: str = "auto"
122
+ devices: int = 1
123
+ precision: Optional[str] = None
124
+
125
+ precision = precision or get_default_supported_precision(training=False)
126
+ plugins = None
127
+ if quantize is not None:
128
+ if devices > 1:
129
+ raise NotImplemented("Multi-GPU quantization is not supported yet.")
130
+ if quantize.startswith("bnb."):
131
+ if "mixed" in precision:
132
+ raise ValueError("Quantization and mixed precision is not supported.")
133
+ dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[
134
+ precision
135
+ ]
136
+ plugins = BitsandbytesPrecision(quantize[4:], dtype)
137
+ precision = None
138
+
139
+ if strategy == "fsdp":
140
+ strategy = FSDPStrategy(auto_wrap_policy={Block}, cpu_offload=False)
141
+
142
+ fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, plugins=plugins)
143
+ fabric.launch()
144
+
145
+ tokenizer = Tokenizer(Path("tokenizer_Llama-2-7b-chat-hf"))
146
+ config = Config.from_name("pythia-160m")
147
+
148
+ fabric.print(f"Loading model from {chptk_path}", file=sys.stderr)
149
+ t0 = time.perf_counter()
150
+ with fabric.init_module(empty_init=True), gptq_quantization(quantize == "gptq.int4"):
151
+ model = GPT(config)
152
+ fabric.print(
153
+ f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
154
+ )
155
+ with fabric.init_tensor():
156
+ # enable the kv cache
157
+ model.set_kv_cache(batch_size=1)
158
+
159
+ model.eval()
160
+ model = fabric.setup_module(model)
161
+
162
+ t0 = time.perf_counter()
163
+ load_checkpoint(fabric, model, chptk_path)
164
+ fabric.print(
165
+ f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", file=sys.stderr
166
+ )
167
+
168
+
169
+ def generate_from_prompt(
170
+ prompt: str = "",
171
+ max_new_tokens: int = 500,
172
+ top_k: int = 200,
173
+ temperature: float = 0.8,
174
+ ):
175
+ """Generate text from a prompt using pre-trained model
176
+
177
+ Args:
178
+ prompt (str, optional): Prompt string to be used for generating samples. Defaults to "".
179
+ num_samples (int, optional): Number of samples to be generated. Defaults to 1.
180
+ max_new_tokens (int, optional): number of generation steps to take. Defaults to 500.
181
+ top_k (int, optional): top most preferable tokens to consider in the sampling process. Defaults to 200.
182
+ temperature (float, optional): Control randomness for sampelling process. Defaults to 0.8.
183
+ """
184
+ encoded = tokenizer.encode(prompt, device=fabric.device)
185
+ prompt_length = encoded.size(0)
186
+ max_returned_tokens = prompt_length + max_new_tokens
187
+ with fabric.init_tensor():
188
+ # set the max_seq_length to limit the memory usage to what we need
189
+ model.max_seq_length = max_returned_tokens
190
+
191
+ num_samples: int = 1
192
+ for i in range(num_samples):
193
+ t0 = time.perf_counter()
194
+ y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k)
195
+ t = time.perf_counter() - t0
196
+ # for block in model.transformer.h:
197
+ # block.attn.kv_cache.reset_parameters()
198
+ pred = tokenizer.decode(y)
199
+ fabric.print(pred)
200
+ tokens_generated = y.size(0) - prompt_length
201
+ fabric.print(
202
+ f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
203
+ file=sys.stderr,
204
+ )
205
+ if fabric.device.type == "cuda":
206
+ fabric.print(
207
+ f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr
208
+ )
209
+
210
+ return pred
saved_model/last-iter-015000-ckpt.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3a07dec62fbdcda721f26f4581e736995ee0c195f46ce2f9587a20e969a996c
3
+ size 1948052114
tokenizer_Llama-2-7b-chat-hf/generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "do_sample": true,
4
+ "eos_token_id": 2,
5
+ "max_length": 4096,
6
+ "pad_token_id": 0,
7
+ "temperature": 0.6,
8
+ "top_p": 0.9,
9
+ "transformers_version": "4.32.0.dev0"
10
+ }
tokenizer_Llama-2-7b-chat-hf/lit_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"name": "Llama-2-7b-chat-hf", "hf_config": {"org": "meta-llama", "name": "Llama-2-7b-chat-hf"}, "block_size": 4096, "vocab_size": 32000, "padding_multiple": 64, "padded_vocab_size": 32000, "n_layer": 32, "n_head": 32, "n_embd": 4096, "rotary_percentage": 1.0, "parallel_residual": false, "bias": false, "lm_head_bias": false, "n_query_groups": 32, "shared_attention_norm": false, "_norm_class": "RMSNorm", "norm_eps": 1e-05, "_mlp_class": "LLaMAMLP", "gelu_approximate": "none", "intermediate_size": 11008, "rope_condense_ratio": 1, "rope_base": 10000}
tokenizer_Llama-2-7b-chat-hf/pytorch_model.bin.index.json ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 13476839424
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "pytorch_model-00002-of-00002.bin",
7
+ "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
8
+ "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
9
+ "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
10
+ "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
11
+ "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
12
+ "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
13
+ "model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
14
+ "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
15
+ "model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
16
+ "model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
17
+ "model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
18
+ "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
19
+ "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
20
+ "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
21
+ "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
22
+ "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
23
+ "model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
24
+ "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
25
+ "model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
26
+ "model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
27
+ "model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
28
+ "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
29
+ "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
30
+ "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
31
+ "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
32
+ "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
33
+ "model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
34
+ "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
35
+ "model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
36
+ "model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
37
+ "model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
38
+ "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
39
+ "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
40
+ "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
41
+ "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
42
+ "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
43
+ "model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
44
+ "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
45
+ "model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
46
+ "model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
47
+ "model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
48
+ "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
49
+ "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
50
+ "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
51
+ "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
52
+ "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
53
+ "model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
54
+ "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
55
+ "model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
56
+ "model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
57
+ "model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
58
+ "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
59
+ "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
60
+ "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
61
+ "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
62
+ "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
63
+ "model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
64
+ "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
65
+ "model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
66
+ "model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
67
+ "model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
68
+ "model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
69
+ "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
70
+ "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
71
+ "model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
72
+ "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
73
+ "model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
74
+ "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
75
+ "model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
76
+ "model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
77
+ "model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
78
+ "model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
79
+ "model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
80
+ "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
81
+ "model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
82
+ "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
83
+ "model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
84
+ "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
85
+ "model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
86
+ "model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
87
+ "model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
88
+ "model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
89
+ "model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
90
+ "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
91
+ "model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
92
+ "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
93
+ "model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
94
+ "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
95
+ "model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
96
+ "model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
97
+ "model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
98
+ "model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
99
+ "model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
100
+ "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
101
+ "model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
102
+ "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
103
+ "model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
104
+ "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
105
+ "model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
106
+ "model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
107
+ "model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
108
+ "model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
109
+ "model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
110
+ "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
111
+ "model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
112
+ "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
113
+ "model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
114
+ "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
115
+ "model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
116
+ "model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
117
+ "model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
118
+ "model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
119
+ "model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
120
+ "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
121
+ "model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
122
+ "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
123
+ "model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
124
+ "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
125
+ "model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
126
+ "model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
127
+ "model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
128
+ "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
129
+ "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
130
+ "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
131
+ "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
132
+ "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
133
+ "model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
134
+ "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
135
+ "model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
136
+ "model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
137
+ "model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
138
+ "model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
139
+ "model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
140
+ "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
141
+ "model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
142
+ "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
143
+ "model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
144
+ "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
145
+ "model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
146
+ "model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
147
+ "model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
148
+ "model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
149
+ "model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
150
+ "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
151
+ "model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
152
+ "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
153
+ "model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
154
+ "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
155
+ "model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
156
+ "model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
157
+ "model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
158
+ "model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
159
+ "model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
160
+ "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
161
+ "model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
162
+ "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
163
+ "model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
164
+ "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
165
+ "model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
166
+ "model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
167
+ "model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
168
+ "model.layers.23.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
169
+ "model.layers.23.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
170
+ "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
171
+ "model.layers.23.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
172
+ "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
173
+ "model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
174
+ "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
175
+ "model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
176
+ "model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
177
+ "model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
178
+ "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
179
+ "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
180
+ "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
181
+ "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
182
+ "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
183
+ "model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
184
+ "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
185
+ "model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
186
+ "model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
187
+ "model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
188
+ "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
189
+ "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
190
+ "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
191
+ "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
192
+ "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
193
+ "model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
194
+ "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
195
+ "model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
196
+ "model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
197
+ "model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
198
+ "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
199
+ "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
200
+ "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
201
+ "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
202
+ "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
203
+ "model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
204
+ "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
205
+ "model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
206
+ "model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
207
+ "model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
208
+ "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
209
+ "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
210
+ "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
211
+ "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
212
+ "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
213
+ "model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
214
+ "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
215
+ "model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
216
+ "model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
217
+ "model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
218
+ "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
219
+ "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
220
+ "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
221
+ "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
222
+ "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
223
+ "model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
224
+ "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
225
+ "model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
226
+ "model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
227
+ "model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
228
+ "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
229
+ "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
230
+ "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
231
+ "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
232
+ "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
233
+ "model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
234
+ "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
235
+ "model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
236
+ "model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
237
+ "model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
238
+ "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
239
+ "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
240
+ "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
241
+ "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
242
+ "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
243
+ "model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
244
+ "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
245
+ "model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
246
+ "model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
247
+ "model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
248
+ "model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
249
+ "model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
250
+ "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
251
+ "model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
252
+ "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
253
+ "model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
254
+ "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
255
+ "model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
256
+ "model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
257
+ "model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
258
+ "model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
259
+ "model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
260
+ "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
261
+ "model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
262
+ "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
263
+ "model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
264
+ "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
265
+ "model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
266
+ "model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
267
+ "model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
268
+ "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
269
+ "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
270
+ "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
271
+ "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
272
+ "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
273
+ "model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
274
+ "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
275
+ "model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
276
+ "model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
277
+ "model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
278
+ "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
279
+ "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
280
+ "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
281
+ "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
282
+ "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
283
+ "model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
284
+ "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
285
+ "model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
286
+ "model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
287
+ "model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
288
+ "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
289
+ "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
290
+ "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
291
+ "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
292
+ "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
293
+ "model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
294
+ "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
295
+ "model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
296
+ "model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
297
+ "model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
298
+ "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
299
+ "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
300
+ "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
301
+ "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
302
+ "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
303
+ "model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
304
+ "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
305
+ "model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
306
+ "model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
307
+ "model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
308
+ "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
309
+ "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
310
+ "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
311
+ "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
312
+ "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
313
+ "model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
314
+ "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
315
+ "model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
316
+ "model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
317
+ "model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
318
+ "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
319
+ "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
320
+ "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
321
+ "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
322
+ "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
323
+ "model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
324
+ "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
325
+ "model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
326
+ "model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
327
+ "model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
328
+ "model.norm.weight": "pytorch_model-00002-of-00002.bin"
329
+ }
330
+ }
tokenizer_Llama-2-7b-chat-hf/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_Llama-2-7b-chat-hf/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_Llama-2-7b-chat-hf/tokenizer_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "chat_template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
13
+ "clean_up_tokenization_spaces": false,
14
+ "eos_token": {
15
+ "__type": "AddedToken",
16
+ "content": "</s>",
17
+ "lstrip": false,
18
+ "normalized": false,
19
+ "rstrip": false,
20
+ "single_word": false
21
+ },
22
+ "legacy": false,
23
+ "model_max_length": 1000000000000000019884624838656,
24
+ "pad_token": null,
25
+ "padding_side": "right",
26
+ "sp_model_kwargs": {},
27
+ "tokenizer_class": "LlamaTokenizer",
28
+ "unk_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<unk>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ }
36
+ }
tsai_gpt/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lightning_utilities.core.imports import RequirementCache
2
+
3
+ from tsai_gpt.config import Config
4
+ from tsai_gpt.model import GPT
5
+ from tsai_gpt.tokenizer import Tokenizer
6
+
7
+ _LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.1.0.dev0")
8
+ if not bool(_LIGHTNING_AVAILABLE):
9
+ raise ImportError(
10
+ "Lit-GPT requires lightning==2.1. Please run:\n"
11
+ f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}"
12
+ )
13
+
14
+
15
+ __all__ = ["GPT", "Config", "Tokenizer"]
tsai_gpt/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (620 Bytes). View file
 
tsai_gpt/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (606 Bytes). View file
 
tsai_gpt/__pycache__/config.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
tsai_gpt/__pycache__/config.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
tsai_gpt/__pycache__/model.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
tsai_gpt/__pycache__/model.cpython-39.pyc ADDED
Binary file (11.5 kB). View file
 
tsai_gpt/__pycache__/packed_dataset.cpython-39.pyc ADDED
Binary file (7.62 kB). View file
 
tsai_gpt/__pycache__/speed_monitor.cpython-39.pyc ADDED
Binary file (15.6 kB). View file
 
tsai_gpt/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (3.57 kB). View file
 
tsai_gpt/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (3.54 kB). View file
 
tsai_gpt/__pycache__/utils.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
tsai_gpt/__pycache__/utils.cpython-39.pyc ADDED
Binary file (11.8 kB). View file
 
tsai_gpt/config.py ADDED
@@ -0,0 +1,1192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass, field
4
+ from pathlib import Path
5
+ from typing import Any, Literal, Optional, Type, Union
6
+
7
+ import torch
8
+ from typing_extensions import Self
9
+
10
+ from tsai_gpt.utils import find_multiple
11
+
12
+
13
+ @dataclass
14
+ class Config:
15
+ name: str = ""
16
+ hf_config: dict = field(default_factory=dict)
17
+ block_size: int = 4096
18
+ vocab_size: int = 50254
19
+ padding_multiple: int = 512
20
+ padded_vocab_size: Optional[int] = None
21
+ n_layer: int = 16
22
+ n_head: int = 32
23
+ n_embd: int = 4096
24
+ rotary_percentage: float = 0.25
25
+ parallel_residual: bool = True
26
+ bias: bool = True
27
+ lm_head_bias: bool = False
28
+ # to use multi-head attention (MHA), set this to `n_head` (default)
29
+ # to use multi-query attention (MQA), set this to 1
30
+ # to use grouped-query attention (GQA), set this to a value in between
31
+ # Example with `n_head=4`
32
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
33
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
34
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
35
+ # │ │ │ │ │ │ │
36
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
37
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
38
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
39
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
40
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
41
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
42
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
43
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
44
+ # MHA GQA MQA
45
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
46
+ #
47
+ # credit https://arxiv.org/pdf/2305.13245.pdf
48
+ n_query_groups: Optional[int] = None
49
+ shared_attention_norm: bool = False
50
+ _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
51
+ norm_eps: float = 1e-5
52
+ _mlp_class: Literal["GptNeoxMLP", "LLaMAMLP"] = "GptNeoxMLP"
53
+ gelu_approximate: str = "none"
54
+ intermediate_size: Optional[int] = None
55
+ rope_condense_ratio: int = 1
56
+ rope_base: int = 10000
57
+
58
+ def __post_init__(self):
59
+ if not self.name:
60
+ self.name = self.hf_config.get("name", self.name)
61
+
62
+ assert self.n_embd % self.n_head == 0
63
+ self.head_size = self.n_embd // self.n_head
64
+
65
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
66
+ if self.padded_vocab_size is None:
67
+ self.padded_vocab_size = find_multiple(self.vocab_size, self.padding_multiple)
68
+ else:
69
+ # vocab size shouldn't be larger than padded vocab size
70
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
71
+
72
+ # compute the number of query groups
73
+ if self.n_query_groups is not None:
74
+ assert self.n_head % self.n_query_groups == 0
75
+ else:
76
+ self.n_query_groups = self.n_head
77
+
78
+ # compute the intermediate size for MLP if not set
79
+ if self.intermediate_size is None:
80
+ if self._mlp_class == "LLaMAMLP":
81
+ raise ValueError("The config needs to set the `intermediate_size`")
82
+ self.intermediate_size = 4 * self.n_embd
83
+
84
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
85
+
86
+ @classmethod
87
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
88
+ if name not in name_to_config:
89
+ # search through all `config['hf_config']['name']`
90
+ conf_dict = next(config for config in configs if name == config["hf_config"]["name"])
91
+ else:
92
+ conf_dict = name_to_config[name]
93
+
94
+ conf_dict = conf_dict.copy()
95
+ if "condense_ratio" in kwargs: # legacy name
96
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
97
+ conf_dict.update(kwargs)
98
+ return cls(**conf_dict)
99
+
100
+ @classmethod
101
+ def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self:
102
+ with open(path, encoding="utf-8") as fp:
103
+ json_kwargs = json.load(fp)
104
+ if "condense_ratio" in json_kwargs: # legacy name
105
+ json_kwargs["rope_condense_ratio"] = json_kwargs.pop("condense_ratio")
106
+ if "condense_ratio" in kwargs: # legacy name
107
+ kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio")
108
+ if "org" in json_kwargs: # legacy name
109
+ json_kwargs["hf_config"] = {"name": json_kwargs["name"], "org": json_kwargs.pop("org")}
110
+ if "org" in kwargs: # legacy name
111
+ kwargs["hf_config"] = {
112
+ "name": kwargs.get("name", json_kwargs["name"]),
113
+ "org": kwargs.pop("org"),
114
+ }
115
+ json_kwargs.update(kwargs)
116
+ return cls(**json_kwargs)
117
+
118
+ @property
119
+ def mlp_class(self) -> Type:
120
+ # `self._mlp_class` cannot be the type to keep the config json serializable
121
+ import tsai_gpt.model
122
+
123
+ return getattr(tsai_gpt.model, self._mlp_class)
124
+
125
+ @property
126
+ def norm_class(self) -> Type:
127
+ # `self._norm_class` cannot be the type to keep the config json serializable
128
+ if self._norm_class == "RMSNorm":
129
+ from tsai_gpt.rmsnorm import RMSNorm
130
+
131
+ return RMSNorm
132
+ return getattr(torch.nn, self._norm_class)
133
+
134
+
135
+ ########################
136
+ # Stability AI StableLM
137
+ ########################
138
+ configs = [
139
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json
140
+ dict(
141
+ name="stablelm-base-alpha-3b",
142
+ hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b"),
143
+ ),
144
+ # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json
145
+ dict(
146
+ name="stablelm-base-alpha-7b",
147
+ hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"),
148
+ n_head=48,
149
+ n_embd=6144,
150
+ padding_multiple=256,
151
+ ),
152
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json
153
+ dict(
154
+ name="stablelm-tuned-alpha-3b",
155
+ hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"),
156
+ n_head=32,
157
+ ),
158
+ # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json
159
+ dict(
160
+ name="stablelm-tuned-alpha-7b",
161
+ hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"),
162
+ n_head=48,
163
+ n_embd=6144,
164
+ padding_multiple=256,
165
+ ),
166
+ ]
167
+
168
+ ####################
169
+ # EleutherAI Pythia
170
+ ####################
171
+ pythia = [
172
+ # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json
173
+ dict(
174
+ name="pythia-70m",
175
+ hf_config=dict(org="EleutherAI", name="pythia-70m"),
176
+ block_size=2048,
177
+ n_layer=6,
178
+ n_embd=512,
179
+ n_head=8,
180
+ padding_multiple=128,
181
+ ),
182
+ # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json
183
+ dict(
184
+ name="pythia-160m",
185
+ hf_config=dict(org="EleutherAI", name="pythia-160m"),
186
+ block_size=2048,
187
+ n_layer=12,
188
+ n_embd=768,
189
+ n_head=12,
190
+ padding_multiple=128,
191
+ ),
192
+ # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json
193
+ dict(
194
+ name="pythia-410m",
195
+ hf_config=dict(org="EleutherAI", name="pythia-410m"),
196
+ block_size=2048,
197
+ n_layer=24,
198
+ n_embd=1024,
199
+ n_head=16,
200
+ padding_multiple=128,
201
+ ),
202
+ # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json
203
+ dict(
204
+ name="pythia-1b",
205
+ hf_config=dict(org="EleutherAI", name="pythia-1b"),
206
+ block_size=2048,
207
+ n_embd=2048,
208
+ n_head=8,
209
+ padding_multiple=128,
210
+ ),
211
+ # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json
212
+ dict(
213
+ name="pythia-1.4b",
214
+ hf_config=dict(org="EleutherAI", name="pythia-1.4b"),
215
+ block_size=2048,
216
+ n_layer=24,
217
+ n_embd=2048,
218
+ n_head=16,
219
+ padding_multiple=128,
220
+ ),
221
+ # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json
222
+ dict(
223
+ name="pythia-2.8b",
224
+ hf_config=dict(org="EleutherAI", name="pythia-2.8b"),
225
+ block_size=2048,
226
+ n_layer=32,
227
+ n_embd=2560,
228
+ padding_multiple=128,
229
+ ),
230
+ # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json
231
+ dict(
232
+ name="pythia-6.9b",
233
+ hf_config=dict(org="EleutherAI", name="pythia-6.9b"),
234
+ block_size=2048,
235
+ n_layer=32,
236
+ padding_multiple=256,
237
+ ),
238
+ # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json
239
+ dict(
240
+ name="pythia-12b",
241
+ hf_config=dict(org="EleutherAI", name="pythia-12b"),
242
+ block_size=2048,
243
+ n_layer=36,
244
+ n_embd=5120,
245
+ n_head=40,
246
+ ),
247
+ ]
248
+ configs.extend(pythia)
249
+ for c in pythia:
250
+ copy = c.copy()
251
+ copy["name"] = f"{c['name']}-deduped"
252
+ copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped"
253
+ configs.append(copy)
254
+
255
+
256
+ ####################################
257
+ # togethercomputer RedPajama INCITE
258
+ ####################################
259
+ redpajama_incite = [
260
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json
261
+ dict(
262
+ name="RedPajama-INCITE-{}-3B-v1",
263
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1"),
264
+ block_size=2048,
265
+ n_layer=32,
266
+ n_embd=2560,
267
+ padding_multiple=256,
268
+ rotary_percentage=1.0,
269
+ parallel_residual=False,
270
+ ),
271
+ # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json
272
+ dict(
273
+ name="RedPajama-INCITE-7B-{}",
274
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"),
275
+ block_size=2048,
276
+ n_layer=32,
277
+ padding_multiple=256,
278
+ rotary_percentage=1.0,
279
+ parallel_residual=False,
280
+ ),
281
+ # this redirects to the checkpoint above. kept for those who had the old weights already downloaded
282
+ dict(
283
+ name="RedPajama-INCITE-{}-7B-v0.1",
284
+ hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1"),
285
+ block_size=2048,
286
+ n_layer=32,
287
+ padding_multiple=256,
288
+ rotary_percentage=1.0,
289
+ parallel_residual=False,
290
+ ),
291
+ ]
292
+ for c in redpajama_incite:
293
+ for kind in ("Base", "Chat", "Instruct"):
294
+ copy = c.copy()
295
+ copy["name"] = c["name"].format(kind)
296
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
297
+ configs.append(copy)
298
+
299
+
300
+ #################
301
+ # TII UAE Falcon
302
+ #################
303
+ falcon = [
304
+ # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
305
+ dict(
306
+ name="falcon-7b{}",
307
+ hf_config=dict(org="tiiuae", name="falcon-7b{}"),
308
+ block_size=2048,
309
+ vocab_size=65024,
310
+ padded_vocab_size=65024,
311
+ n_layer=32,
312
+ n_head=71,
313
+ n_embd=4544,
314
+ rotary_percentage=1.0,
315
+ n_query_groups=1,
316
+ bias=False,
317
+ # this is not in the config, but in the original model implementation, only for this config
318
+ shared_attention_norm=True,
319
+ ),
320
+ # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json
321
+ dict(
322
+ name="falcon-40b{}",
323
+ hf_config=dict(org="tiiuae", name="falcon-40b{}"),
324
+ block_size=2048,
325
+ vocab_size=65024,
326
+ padded_vocab_size=65024,
327
+ n_layer=60,
328
+ n_head=128,
329
+ n_embd=8192,
330
+ rotary_percentage=1.0,
331
+ n_query_groups=8,
332
+ bias=False,
333
+ ),
334
+ ]
335
+ for c in falcon:
336
+ for kind in ("", "-instruct"):
337
+ copy = c.copy()
338
+ copy["name"] = c["name"].format(kind)
339
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
340
+ configs.append(copy)
341
+
342
+ # https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json
343
+ falcon180b = dict(
344
+ name="falcon-180B{}",
345
+ hf_config=dict(org="tiiuae", name="falcon-180B{}"),
346
+ block_size=2048,
347
+ vocab_size=65024,
348
+ padded_vocab_size=65024,
349
+ n_layer=80,
350
+ n_head=232,
351
+ n_embd=14848,
352
+ rotary_percentage=1.0,
353
+ n_query_groups=8,
354
+ bias=False,
355
+ )
356
+
357
+ for kind in ("", "-chat"):
358
+ copy = falcon180b.copy()
359
+ copy["name"] = falcon180b["name"].format(kind)
360
+ copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
361
+ configs.append(copy)
362
+
363
+
364
+ #############################
365
+ # OpenLM Research Open LLaMA
366
+ #############################
367
+ open_LLaMA = [
368
+ # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json
369
+ dict(
370
+ name="open_llama_3b",
371
+ hf_config=dict(org="openlm-research", name="open_llama_3b"),
372
+ block_size=2048,
373
+ vocab_size=32000,
374
+ padding_multiple=64,
375
+ n_layer=26,
376
+ n_embd=3200,
377
+ rotary_percentage=1.0,
378
+ parallel_residual=False,
379
+ bias=False,
380
+ _norm_class="RMSNorm",
381
+ norm_eps=1e-6,
382
+ _mlp_class="LLaMAMLP",
383
+ intermediate_size=8640,
384
+ ),
385
+ # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json
386
+ dict(
387
+ name="open_llama_7b",
388
+ hf_config=dict(org="openlm-research", name="open_llama_7b"),
389
+ block_size=2048,
390
+ vocab_size=32000,
391
+ padding_multiple=64,
392
+ n_layer=32,
393
+ rotary_percentage=1.0,
394
+ parallel_residual=False,
395
+ bias=False,
396
+ _norm_class="RMSNorm",
397
+ norm_eps=1e-6,
398
+ _mlp_class="LLaMAMLP",
399
+ intermediate_size=11008,
400
+ ),
401
+ # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json
402
+ dict(
403
+ name="open_llama_13b",
404
+ hf_config=dict(org="openlm-research", name="open_llama_13b"),
405
+ block_size=2048,
406
+ vocab_size=32000,
407
+ padding_multiple=64,
408
+ n_layer=40,
409
+ n_head=40,
410
+ n_embd=5120,
411
+ rotary_percentage=1.0,
412
+ parallel_residual=False,
413
+ bias=False,
414
+ _norm_class="RMSNorm",
415
+ norm_eps=1e-6,
416
+ _mlp_class="LLaMAMLP",
417
+ intermediate_size=13824,
418
+ ),
419
+ ]
420
+ configs.extend(open_LLaMA)
421
+
422
+
423
+ ###############
424
+ # LMSYS Vicuna
425
+ ###############
426
+ vicuna = [
427
+ # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json
428
+ dict(
429
+ name="vicuna-7b-v1.3",
430
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"),
431
+ block_size=2048,
432
+ vocab_size=32000,
433
+ padding_multiple=64,
434
+ n_layer=32,
435
+ rotary_percentage=1.0,
436
+ parallel_residual=False,
437
+ bias=False,
438
+ _norm_class="RMSNorm",
439
+ norm_eps=1e-6,
440
+ _mlp_class="LLaMAMLP",
441
+ intermediate_size=11008,
442
+ ),
443
+ # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json
444
+ dict(
445
+ name="vicuna-13b-v1.3",
446
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"),
447
+ block_size=2048,
448
+ vocab_size=32000,
449
+ padding_multiple=64,
450
+ n_layer=40,
451
+ n_head=40,
452
+ n_embd=5120,
453
+ rotary_percentage=1.0,
454
+ parallel_residual=False,
455
+ bias=False,
456
+ _norm_class="RMSNorm",
457
+ norm_eps=1e-6,
458
+ _mlp_class="LLaMAMLP",
459
+ intermediate_size=13824,
460
+ ),
461
+ # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json
462
+ dict(
463
+ name="vicuna-33b-v1.3",
464
+ hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"),
465
+ block_size=2048,
466
+ vocab_size=32000,
467
+ padding_multiple=64,
468
+ n_layer=60,
469
+ n_head=52,
470
+ n_embd=6656,
471
+ rotary_percentage=1.0,
472
+ parallel_residual=False,
473
+ bias=False,
474
+ _norm_class="RMSNorm",
475
+ norm_eps=1e-6,
476
+ _mlp_class="LLaMAMLP",
477
+ intermediate_size=17920,
478
+ ),
479
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json
480
+ dict(
481
+ name="vicuna-7b-v1.5",
482
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"),
483
+ vocab_size=32000,
484
+ padding_multiple=64,
485
+ n_layer=32,
486
+ rotary_percentage=1.0,
487
+ parallel_residual=False,
488
+ bias=False,
489
+ _norm_class="RMSNorm",
490
+ _mlp_class="LLaMAMLP",
491
+ intermediate_size=11008,
492
+ ),
493
+ # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json
494
+ dict(
495
+ name="vicuna-7b-v1.5-16k",
496
+ hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"),
497
+ block_size=16384,
498
+ vocab_size=32000,
499
+ padding_multiple=64,
500
+ n_layer=32,
501
+ rotary_percentage=1.0,
502
+ parallel_residual=False,
503
+ bias=False,
504
+ _norm_class="RMSNorm",
505
+ _mlp_class="LLaMAMLP",
506
+ intermediate_size=11008,
507
+ rope_condense_ratio=4,
508
+ ),
509
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json
510
+ dict(
511
+ name="vicuna-13b-v1.5",
512
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"),
513
+ vocab_size=32000,
514
+ padding_multiple=64,
515
+ n_layer=40,
516
+ n_head=40,
517
+ n_embd=5120,
518
+ rotary_percentage=1.0,
519
+ parallel_residual=False,
520
+ bias=False,
521
+ _norm_class="RMSNorm",
522
+ _mlp_class="LLaMAMLP",
523
+ intermediate_size=13824,
524
+ ),
525
+ # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
526
+ dict(
527
+ name="vicuna-13b-v1.5-16k",
528
+ hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"),
529
+ block_size=16384,
530
+ vocab_size=32000,
531
+ padding_multiple=64,
532
+ n_layer=40,
533
+ n_head=40,
534
+ n_embd=5120,
535
+ rotary_percentage=1.0,
536
+ parallel_residual=False,
537
+ bias=False,
538
+ _norm_class="RMSNorm",
539
+ _mlp_class="LLaMAMLP",
540
+ intermediate_size=13824,
541
+ rope_condense_ratio=4,
542
+ ),
543
+ ]
544
+ configs.extend(vicuna)
545
+
546
+
547
+ #################
548
+ # LMSYS LongChat
549
+ #################
550
+ long_chat = [
551
+ # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json
552
+ dict(
553
+ name="longchat-7b-16k",
554
+ hf_config=dict(org="lmsys", name="longchat-7b-16k"),
555
+ block_size=16384,
556
+ vocab_size=32000,
557
+ padding_multiple=64,
558
+ n_layer=32,
559
+ rotary_percentage=1.0,
560
+ parallel_residual=False,
561
+ bias=False,
562
+ _norm_class="RMSNorm",
563
+ norm_eps=1e-6,
564
+ _mlp_class="LLaMAMLP",
565
+ intermediate_size=11008,
566
+ rope_condense_ratio=8,
567
+ ),
568
+ # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json
569
+ dict(
570
+ name="longchat-13b-16k",
571
+ hf_config=dict(org="lmsys", name="longchat-13b-16k"),
572
+ block_size=16384,
573
+ vocab_size=32000,
574
+ padding_multiple=64,
575
+ n_layer=40,
576
+ n_head=40,
577
+ n_embd=5120,
578
+ rotary_percentage=1.0,
579
+ parallel_residual=False,
580
+ bias=False,
581
+ _norm_class="RMSNorm",
582
+ norm_eps=1e-6,
583
+ _mlp_class="LLaMAMLP",
584
+ intermediate_size=13824,
585
+ rope_condense_ratio=8,
586
+ ),
587
+ ]
588
+ configs.extend(long_chat)
589
+
590
+
591
+ ######################
592
+ # NousResearch Hermes
593
+ ######################
594
+ nous_research = [
595
+ # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json
596
+ dict(
597
+ name="Nous-Hermes-llama-2-7b",
598
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"),
599
+ padded_vocab_size=32000,
600
+ n_layer=32,
601
+ rotary_percentage=1.0,
602
+ parallel_residual=False,
603
+ bias=False,
604
+ _norm_class="RMSNorm",
605
+ norm_eps=1e-05,
606
+ _mlp_class="LLaMAMLP",
607
+ intermediate_size=11008,
608
+ ),
609
+ # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json
610
+ dict(
611
+ name="Nous-Hermes-13b",
612
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"),
613
+ block_size=2048,
614
+ vocab_size=32000,
615
+ padded_vocab_size=32001,
616
+ n_layer=40,
617
+ n_head=40,
618
+ n_embd=5120,
619
+ rotary_percentage=1.0,
620
+ parallel_residual=False,
621
+ bias=False,
622
+ _norm_class="RMSNorm",
623
+ norm_eps=1e-6,
624
+ _mlp_class="LLaMAMLP",
625
+ intermediate_size=13824,
626
+ ),
627
+ # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b
628
+ dict(
629
+ name="Nous-Hermes-Llama2-13b",
630
+ hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"),
631
+ vocab_size=32000,
632
+ padded_vocab_size=32032,
633
+ n_layer=40,
634
+ n_head=40,
635
+ n_embd=5120,
636
+ rotary_percentage=1.0,
637
+ parallel_residual=False,
638
+ bias=False,
639
+ _norm_class="RMSNorm",
640
+ norm_eps=1e-05,
641
+ _mlp_class="LLaMAMLP",
642
+ intermediate_size=13824,
643
+ ),
644
+ ]
645
+ configs.extend(nous_research)
646
+
647
+
648
+ ###############
649
+ # Meta LLaMA 2
650
+ ###############
651
+ llama_2 = [
652
+ # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json
653
+ dict(
654
+ name="Llama-2-7b{}-hf",
655
+ hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"),
656
+ vocab_size=32000,
657
+ padding_multiple=64,
658
+ n_layer=32,
659
+ rotary_percentage=1.0,
660
+ parallel_residual=False,
661
+ bias=False,
662
+ _norm_class="RMSNorm",
663
+ _mlp_class="LLaMAMLP",
664
+ intermediate_size=11008,
665
+ ),
666
+ # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json
667
+ dict(
668
+ name="Llama-2-13b{}-hf",
669
+ hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"),
670
+ vocab_size=32000,
671
+ padding_multiple=64,
672
+ n_layer=40,
673
+ n_head=40,
674
+ n_embd=5120,
675
+ rotary_percentage=1.0,
676
+ parallel_residual=False,
677
+ bias=False,
678
+ _norm_class="RMSNorm",
679
+ _mlp_class="LLaMAMLP",
680
+ intermediate_size=13824,
681
+ ),
682
+ # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json
683
+ dict(
684
+ name="Llama-2-70b{}-hf",
685
+ hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"),
686
+ vocab_size=32000,
687
+ padding_multiple=64,
688
+ n_layer=80,
689
+ n_head=64,
690
+ n_embd=8192,
691
+ n_query_groups=8,
692
+ rotary_percentage=1.0,
693
+ parallel_residual=False,
694
+ bias=False,
695
+ _norm_class="RMSNorm",
696
+ _mlp_class="LLaMAMLP",
697
+ intermediate_size=28672,
698
+ ),
699
+ ]
700
+ for c in llama_2:
701
+ for kind in ("", "-chat"):
702
+ copy = c.copy()
703
+ copy["name"] = c["name"].format(kind)
704
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
705
+ configs.append(copy)
706
+
707
+
708
+ ##########################
709
+ # Stability AI FreeWilly2
710
+ ##########################
711
+ freewilly_2 = [
712
+ # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json
713
+ dict(
714
+ name="FreeWilly2",
715
+ hf_config=dict(org="stabilityai", name="FreeWilly2"),
716
+ vocab_size=32000,
717
+ padding_multiple=64,
718
+ n_layer=80,
719
+ n_head=64,
720
+ n_embd=8192,
721
+ n_query_groups=8,
722
+ rotary_percentage=1.0,
723
+ parallel_residual=False,
724
+ bias=False,
725
+ _norm_class="RMSNorm",
726
+ _mlp_class="LLaMAMLP",
727
+ intermediate_size=28672,
728
+ )
729
+ ]
730
+ configs.extend(freewilly_2)
731
+
732
+
733
+ ##################
734
+ # Meta Code Llama
735
+ ##################
736
+ code_llama = [
737
+ # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json
738
+ dict(
739
+ name="CodeLlama-7b-hf",
740
+ hf_config=dict(org="codellama", name="CodeLlama-7b-hf"),
741
+ block_size=16384,
742
+ vocab_size=32016,
743
+ padding_multiple=16,
744
+ n_layer=32,
745
+ rotary_percentage=1.0,
746
+ parallel_residual=False,
747
+ bias=False,
748
+ _norm_class="RMSNorm",
749
+ norm_eps=1e-05,
750
+ _mlp_class="LLaMAMLP",
751
+ intermediate_size=11008,
752
+ rope_base=1000000,
753
+ ),
754
+ # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json
755
+ dict(
756
+ name="CodeLlama-13b-hf",
757
+ hf_config=dict(org="codellama", name="CodeLlama-13b-hf"),
758
+ block_size=16384,
759
+ vocab_size=32016,
760
+ padding_multiple=16,
761
+ n_layer=40,
762
+ n_head=40,
763
+ n_embd=5120,
764
+ rotary_percentage=1.0,
765
+ parallel_residual=False,
766
+ bias=False,
767
+ _norm_class="RMSNorm",
768
+ norm_eps=1e-05,
769
+ _mlp_class="LLaMAMLP",
770
+ intermediate_size=13824,
771
+ rope_base=1000000,
772
+ ),
773
+ # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json
774
+ dict(
775
+ name="CodeLlama-34b-hf",
776
+ hf_config=dict(org="codellama", name="CodeLlama-34b-hf"),
777
+ block_size=16384,
778
+ vocab_size=32000,
779
+ padding_multiple=64,
780
+ n_layer=48,
781
+ n_head=64,
782
+ n_embd=8192,
783
+ n_query_groups=8,
784
+ rotary_percentage=1.0,
785
+ parallel_residual=False,
786
+ bias=False,
787
+ _norm_class="RMSNorm",
788
+ norm_eps=1e-05,
789
+ _mlp_class="LLaMAMLP",
790
+ intermediate_size=22016,
791
+ rope_base=1000000,
792
+ ),
793
+ # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json
794
+ dict(
795
+ name="CodeLlama-7b-Python-hf",
796
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"),
797
+ block_size=16384,
798
+ vocab_size=32000,
799
+ padding_multiple=64,
800
+ n_layer=32,
801
+ rotary_percentage=1.0,
802
+ parallel_residual=False,
803
+ bias=False,
804
+ _norm_class="RMSNorm",
805
+ norm_eps=1e-05,
806
+ _mlp_class="LLaMAMLP",
807
+ intermediate_size=11008,
808
+ rope_base=1000000,
809
+ ),
810
+ # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json
811
+ dict(
812
+ name="CodeLlama-13b-Python-hf",
813
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"),
814
+ block_size=16384,
815
+ vocab_size=32000,
816
+ padding_multiple=64,
817
+ n_layer=40,
818
+ n_head=40,
819
+ n_embd=5120,
820
+ rotary_percentage=1.0,
821
+ parallel_residual=False,
822
+ bias=False,
823
+ _norm_class="RMSNorm",
824
+ norm_eps=1e-05,
825
+ _mlp_class="LLaMAMLP",
826
+ intermediate_size=13824,
827
+ rope_base=1000000,
828
+ ),
829
+ # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json
830
+ dict(
831
+ name="CodeLlama-34b-Python-hf",
832
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"),
833
+ block_size=16384,
834
+ vocab_size=32000,
835
+ padding_multiple=64,
836
+ n_layer=48,
837
+ n_head=64,
838
+ n_embd=8192,
839
+ n_query_groups=8,
840
+ rotary_percentage=1.0,
841
+ parallel_residual=False,
842
+ bias=False,
843
+ _norm_class="RMSNorm",
844
+ norm_eps=1e-05,
845
+ _mlp_class="LLaMAMLP",
846
+ intermediate_size=22016,
847
+ rope_base=1000000,
848
+ ),
849
+ # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/tree/main/config.json
850
+ dict(
851
+ name="CodeLlama-7b-Instruct-hf",
852
+ hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"),
853
+ block_size=16384,
854
+ vocab_size=32016,
855
+ padding_multiple=16,
856
+ n_layer=32,
857
+ rotary_percentage=1.0,
858
+ parallel_residual=False,
859
+ bias=False,
860
+ _norm_class="RMSNorm",
861
+ norm_eps=1e-05,
862
+ _mlp_class="LLaMAMLP",
863
+ intermediate_size=11008,
864
+ rope_base=1000000,
865
+ ),
866
+ # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json
867
+ dict(
868
+ name="CodeLlama-13b-Instruct-hf",
869
+ hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"),
870
+ block_size=2048,
871
+ vocab_size=32016,
872
+ padding_multiple=16,
873
+ n_layer=40,
874
+ n_head=40,
875
+ n_embd=5120,
876
+ rotary_percentage=1.0,
877
+ parallel_residual=False,
878
+ bias=False,
879
+ _norm_class="RMSNorm",
880
+ norm_eps=1e-05,
881
+ _mlp_class="LLaMAMLP",
882
+ intermediate_size=13824,
883
+ rope_base=1000000,
884
+ ),
885
+ # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json
886
+ dict(
887
+ name="CodeLlama-34b-Instruct-hf",
888
+ hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"),
889
+ block_size=16384,
890
+ vocab_size=32000,
891
+ padding_multiple=64,
892
+ n_layer=48,
893
+ n_head=64,
894
+ n_embd=8192,
895
+ n_query_groups=8,
896
+ rotary_percentage=1.0,
897
+ parallel_residual=False,
898
+ bias=False,
899
+ _norm_class="RMSNorm",
900
+ norm_eps=1e-05,
901
+ _mlp_class="LLaMAMLP",
902
+ intermediate_size=22016,
903
+ rope_base=1000000,
904
+ ),
905
+ ]
906
+ configs.extend(code_llama)
907
+
908
+
909
+ ########################
910
+ # garage-bAInd Platypus
911
+ ########################
912
+ platypus = [
913
+ # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json
914
+ dict(
915
+ name="Platypus-30B",
916
+ hf_config=dict(org="garage-bAInd", name="Platypus-30B"),
917
+ block_size=2048,
918
+ padded_vocab_size=32000,
919
+ n_layer=60,
920
+ n_head=52,
921
+ n_embd=6656,
922
+ rotary_percentage=1.0,
923
+ parallel_residual=False,
924
+ bias=False,
925
+ _norm_class="RMSNorm",
926
+ norm_eps=1e-06,
927
+ _mlp_class="LLaMAMLP",
928
+ intermediate_size=17920,
929
+ ),
930
+ # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json
931
+ dict(
932
+ name="Platypus2-7B",
933
+ hf_config=dict(org="garage-bAInd", name="Platypus2-7B"),
934
+ padded_vocab_size=32000,
935
+ n_layer=32,
936
+ rotary_percentage=1.0,
937
+ parallel_residual=False,
938
+ bias=False,
939
+ _norm_class="RMSNorm",
940
+ norm_eps=1e-05,
941
+ _mlp_class="LLaMAMLP",
942
+ intermediate_size=11008,
943
+ ),
944
+ # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json
945
+ dict(
946
+ name="Platypus2-13B",
947
+ hf_config=dict(org="garage-bAInd", name="Platypus2-13B"),
948
+ padded_vocab_size=32000,
949
+ n_layer=40,
950
+ n_head=40,
951
+ n_embd=5120,
952
+ rotary_percentage=1.0,
953
+ parallel_residual=False,
954
+ bias=False,
955
+ _norm_class="RMSNorm",
956
+ norm_eps=1e-05,
957
+ _mlp_class="LLaMAMLP",
958
+ intermediate_size=13824,
959
+ ),
960
+ # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json
961
+ dict(
962
+ name="Platypus2-70B",
963
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B"),
964
+ padded_vocab_size=32000,
965
+ n_layer=80,
966
+ n_head=64,
967
+ n_embd=8192,
968
+ rotary_percentage=1.0,
969
+ parallel_residual=False,
970
+ bias=False,
971
+ _norm_class="RMSNorm",
972
+ _mlp_class="LLaMAMLP",
973
+ intermediate_size=28672,
974
+ ),
975
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json
976
+ dict(
977
+ name="Camel-Platypus2-13B",
978
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"),
979
+ padded_vocab_size=32000,
980
+ n_layer=40,
981
+ n_head=40,
982
+ n_embd=5120,
983
+ rotary_percentage=1.0,
984
+ parallel_residual=False,
985
+ bias=False,
986
+ _norm_class="RMSNorm",
987
+ _mlp_class="LLaMAMLP",
988
+ intermediate_size=13824,
989
+ ),
990
+ # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json
991
+ dict(
992
+ name="Camel-Platypus2-70B",
993
+ hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"),
994
+ padded_vocab_size=32000,
995
+ n_layer=80,
996
+ n_head=64,
997
+ n_embd=8192,
998
+ n_query_groups=8,
999
+ rotary_percentage=1.0,
1000
+ parallel_residual=False,
1001
+ bias=False,
1002
+ _norm_class="RMSNorm",
1003
+ _mlp_class="LLaMAMLP",
1004
+ intermediate_size=28672,
1005
+ ),
1006
+ # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json
1007
+ dict(
1008
+ name="Stable-Platypus2-13B",
1009
+ hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"),
1010
+ padded_vocab_size=32000,
1011
+ n_layer=40,
1012
+ n_head=40,
1013
+ n_embd=5120,
1014
+ rotary_percentage=1.0,
1015
+ parallel_residual=False,
1016
+ bias=False,
1017
+ _norm_class="RMSNorm",
1018
+ _mlp_class="LLaMAMLP",
1019
+ intermediate_size=13824,
1020
+ ),
1021
+ # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json
1022
+ dict(
1023
+ name="Platypus2-70B-instruct",
1024
+ hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"),
1025
+ padded_vocab_size=32000,
1026
+ n_layer=80,
1027
+ n_head=64,
1028
+ n_embd=8192,
1029
+ n_query_groups=8,
1030
+ rotary_percentage=1.0,
1031
+ parallel_residual=False,
1032
+ bias=False,
1033
+ _norm_class="RMSNorm",
1034
+ _mlp_class="LLaMAMLP",
1035
+ intermediate_size=28672,
1036
+ ),
1037
+ ]
1038
+ configs.extend(platypus)
1039
+
1040
+
1041
+ ##########################
1042
+ # Stability AI StableCode
1043
+ ##########################
1044
+ stablecode = [
1045
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json
1046
+ dict(
1047
+ name="stablecode-completion-alpha-3b",
1048
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b"),
1049
+ block_size=16384,
1050
+ vocab_size=49152,
1051
+ n_layer=32,
1052
+ n_embd=2560,
1053
+ ),
1054
+ # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json
1055
+ dict(
1056
+ name="stablecode-completion-alpha-3b-4k",
1057
+ hf_config=dict(org="stabilityai", name="stablecode-completion-alpha-3b-4k"),
1058
+ vocab_size=49152,
1059
+ n_layer=32,
1060
+ n_embd=2560,
1061
+ ),
1062
+ # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json
1063
+ dict(
1064
+ name="stablecode-instruct-alpha-3b",
1065
+ hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"),
1066
+ vocab_size=49152,
1067
+ n_layer=32,
1068
+ n_embd=2560,
1069
+ ),
1070
+ ]
1071
+ configs.extend(stablecode)
1072
+
1073
+
1074
+ ##################################
1075
+ # togethercomputer LLaMA-2-7B-32K
1076
+ ##################################
1077
+ together_llama2_32k = [
1078
+ # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json
1079
+ dict(
1080
+ name="LLaMA-2-7B-32K",
1081
+ hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"),
1082
+ vocab_size=32000,
1083
+ padding_multiple=64,
1084
+ n_layer=32,
1085
+ rotary_percentage=1.0,
1086
+ parallel_residual=False,
1087
+ bias=False,
1088
+ _norm_class="RMSNorm",
1089
+ _mlp_class="LLaMAMLP",
1090
+ intermediate_size=11008,
1091
+ rope_condense_ratio=8,
1092
+ )
1093
+ ]
1094
+ configs.extend(together_llama2_32k)
1095
+
1096
+
1097
+ ################
1098
+ # Microsoft Phi
1099
+ ################
1100
+ phi = [
1101
+ # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json
1102
+ dict(
1103
+ name="phi-1_5",
1104
+ hf_config=dict(org="microsoft", name="phi-1_5"),
1105
+ vocab_size=50257,
1106
+ padded_vocab_size=51200,
1107
+ block_size=2048,
1108
+ n_embd=2048,
1109
+ n_layer=24,
1110
+ rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64
1111
+ shared_attention_norm=True,
1112
+ lm_head_bias=True,
1113
+ gelu_approximate="tanh",
1114
+ )
1115
+ ]
1116
+ configs.extend(phi)
1117
+
1118
+
1119
+ #############
1120
+ # Mistral AI
1121
+ #############
1122
+ mistral = [
1123
+ # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
1124
+ dict(
1125
+ name="Mistral-7B-{}v0.1",
1126
+ hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"),
1127
+ padded_vocab_size=32000,
1128
+ block_size=4096, # should be 32768 but sliding window attention is not implemented
1129
+ n_layer=32,
1130
+ n_query_groups=8,
1131
+ rotary_percentage=1.0,
1132
+ parallel_residual=False,
1133
+ bias=False,
1134
+ _norm_class="RMSNorm",
1135
+ norm_eps=1e-05,
1136
+ _mlp_class="LLaMAMLP",
1137
+ intermediate_size=14336,
1138
+ )
1139
+ ]
1140
+ for c in mistral:
1141
+ for kind in ("", "Instruct-"):
1142
+ copy = c.copy()
1143
+ copy["name"] = c["name"].format(kind)
1144
+ copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
1145
+ configs.append(copy)
1146
+
1147
+
1148
+ ############
1149
+ # TinyLlama
1150
+ ############
1151
+ tiny_llama = [
1152
+ dict(
1153
+ name="tiny-llama-1.1b",
1154
+ hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
1155
+ block_size=2048,
1156
+ vocab_size=32000,
1157
+ padding_multiple=64,
1158
+ n_layer=22,
1159
+ n_head=32,
1160
+ n_embd=2048,
1161
+ rotary_percentage=1.0,
1162
+ parallel_residual=False,
1163
+ bias=False,
1164
+ _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm
1165
+ norm_eps=1e-5,
1166
+ _mlp_class="LLaMAMLP",
1167
+ intermediate_size=5632,
1168
+ n_query_groups=4,
1169
+ ),
1170
+ dict(
1171
+ name="tiny-llama-new",
1172
+ hf_config=dict(org="PY007", name="TinyLlama-1.1B-intermediate-step-480k-1T"),
1173
+ block_size=768,
1174
+ vocab_size=32000,
1175
+ padding_multiple=64,
1176
+ n_layer=18,
1177
+ n_head=32,
1178
+ n_embd=1024,
1179
+ rotary_percentage=1.0,
1180
+ parallel_residual=False,
1181
+ bias=False,
1182
+ _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm
1183
+ norm_eps=1e-5,
1184
+ _mlp_class="LLaMAMLP",
1185
+ intermediate_size=5632,
1186
+ n_query_groups=4,
1187
+ ),
1188
+ ]
1189
+ configs.extend(tiny_llama)
1190
+
1191
+
1192
+ name_to_config = {config["name"]: config for config in configs}
tsai_gpt/model.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Full definition of a GPT NeoX Language Model, all of it in this single file.
2
+
3
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
4
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
5
+ """
6
+ import math
7
+ from typing import Any, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing_extensions import Self
12
+
13
+ from tsai_gpt.config import Config
14
+
15
+
16
+ class GPT(nn.Module):
17
+ def __init__(self, config: Config) -> None:
18
+ super().__init__()
19
+ assert config.padded_vocab_size is not None
20
+ self.config = config
21
+
22
+ self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias)
23
+ self.transformer = nn.ModuleDict(
24
+ dict(
25
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
26
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
27
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
28
+ )
29
+ )
30
+ self.max_seq_length = self.config.block_size
31
+ self.mask_cache: Optional[torch.Tensor] = None
32
+
33
+ @property
34
+ def max_seq_length(self) -> int:
35
+ return self._max_seq_length
36
+
37
+ @max_seq_length.setter
38
+ def max_seq_length(self, value: int) -> None:
39
+ """
40
+ When doing inference, the sequences used might be shorter than the model's context length.
41
+ This allows setting a smaller number to avoid allocating unused memory
42
+ """
43
+ if value > self.config.block_size:
44
+ raise ValueError(
45
+ f"Cannot attend to {value}, block size is only {self.config.block_size}"
46
+ )
47
+ self._max_seq_length = value
48
+ if not hasattr(self, "cos"):
49
+ # first call
50
+ cos, sin = self.rope_cache()
51
+ self.register_buffer("cos", cos, persistent=False)
52
+ self.register_buffer("sin", sin, persistent=False)
53
+ elif value != self.cos.size(0):
54
+ # override
55
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
56
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
57
+ # if the kv cache is expected
58
+
59
+ def reset_parameters(self) -> None:
60
+ # Trigger resetting the rope-cache
61
+ self.max_seq_length = self.config.block_size
62
+
63
+ def _init_weights(self, module: nn.Module) -> None:
64
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
65
+ if isinstance(module, nn.Linear):
66
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
67
+ if module.bias is not None:
68
+ torch.nn.init.zeros_(module.bias)
69
+ elif isinstance(module, nn.Embedding):
70
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
71
+
72
+ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor:
73
+ T = idx.size(1)
74
+ if self.max_seq_length < T:
75
+ raise ValueError(
76
+ f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
77
+ )
78
+
79
+ if input_pos is not None: # use the kv cache
80
+ cos = self.cos.index_select(0, input_pos)
81
+ sin = self.sin.index_select(0, input_pos)
82
+ if self.mask_cache is None:
83
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
84
+ mask = self.mask_cache.index_select(2, input_pos)
85
+ else:
86
+ cos = self.cos[:T]
87
+ sin = self.sin[:T]
88
+ mask = None
89
+
90
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
91
+ for block in self.transformer.h:
92
+ x = block(x, cos, sin, mask, input_pos)
93
+ x = self.transformer.ln_f(x)
94
+ return self.lm_head(x) # (b, t, vocab_size)
95
+
96
+ @classmethod
97
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
98
+ return cls(Config.from_name(name, **kwargs))
99
+
100
+ def rope_cache(
101
+ self, device: Optional[torch.device] = None
102
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
103
+ return build_rope_cache(
104
+ seq_len=self.max_seq_length,
105
+ n_elem=self.config.rope_n_elem,
106
+ device=device,
107
+ condense_ratio=self.config.rope_condense_ratio,
108
+ base=self.config.rope_base,
109
+ )
110
+
111
+ def set_kv_cache(
112
+ self,
113
+ batch_size: int,
114
+ rope_cache_length: Optional[int] = None,
115
+ device: Optional[torch.device] = None,
116
+ dtype: Optional[torch.dtype] = None,
117
+ ) -> None:
118
+ if rope_cache_length is None:
119
+ rope_cache_length = self.cos.size(-1)
120
+ max_seq_length = self.max_seq_length
121
+
122
+ # initialize the kv cache for all blocks
123
+ for block in self.transformer.h:
124
+ block.attn.kv_cache = block.attn.build_kv_cache(
125
+ batch_size, max_seq_length, rope_cache_length, device, dtype
126
+ )
127
+
128
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
129
+ # passing `attn_mask` to SDPA downgrades it to use the inefficient implementation. since we only need the mask
130
+ # for the kv-cache support (only during inference), we only create it in that situation
131
+ # this will be resolved by https://github.com/pytorch/pytorch/issues/96099
132
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
133
+ self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0)
134
+
135
+ def clear_kv_cache(self) -> None:
136
+ self.mask_cache = None
137
+ for block in self.transformer.h:
138
+ block.attn.kv_cache = None
139
+
140
+
141
+ class Block(nn.Module):
142
+ def __init__(self, config: Config) -> None:
143
+ super().__init__()
144
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
145
+ self.attn = CausalSelfAttention(config)
146
+ self.norm_2 = (
147
+ None
148
+ if config.shared_attention_norm
149
+ else config.norm_class(config.n_embd, eps=config.norm_eps)
150
+ )
151
+ self.mlp = config.mlp_class(config)
152
+
153
+ self.config = config
154
+
155
+ def forward(
156
+ self,
157
+ x: torch.Tensor,
158
+ cos: torch.Tensor,
159
+ sin: torch.Tensor,
160
+ mask: Optional[torch.Tensor] = None,
161
+ input_pos: Optional[torch.Tensor] = None,
162
+ ) -> torch.Tensor:
163
+ n_1 = self.norm_1(x)
164
+ h = self.attn(n_1, cos, sin, mask, input_pos)
165
+ if self.config.parallel_residual:
166
+ n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
167
+ x = self.mlp(n_2) + h + x
168
+ else:
169
+ if self.config.shared_attention_norm:
170
+ raise NotImplementedError(
171
+ "No checkpoint amongst the ones we support uses this configuration"
172
+ " (non-parallel residual and shared attention norm)."
173
+ )
174
+ x = h + x
175
+ x = self.mlp(self.norm_2(x)) + x
176
+ return x
177
+
178
+
179
+ class CausalSelfAttention(nn.Module):
180
+ def __init__(self, config: Config) -> None:
181
+ super().__init__()
182
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
183
+ # key, query, value projections for all heads, but in a batch
184
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
185
+ # output projection
186
+ self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
187
+ # disabled by default
188
+ self.kv_cache: Optional[KVCache] = None
189
+
190
+ self.config = config
191
+
192
+ def forward(
193
+ self,
194
+ x: torch.Tensor,
195
+ cos: torch.Tensor,
196
+ sin: torch.Tensor,
197
+ mask: Optional[torch.Tensor] = None,
198
+ input_pos: Optional[torch.Tensor] = None,
199
+ ) -> torch.Tensor:
200
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
201
+
202
+ qkv = self.attn(x)
203
+
204
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
205
+ q_per_kv = self.config.n_head // self.config.n_query_groups
206
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
207
+ qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
208
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
209
+
210
+ # split batched computation into three
211
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
212
+
213
+ # maybe repeat k and v if for the non multi-head attention cases
214
+ # training: flash attention requires it
215
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
216
+ if self.config.n_query_groups != self.config.n_head and (
217
+ input_pos is None or self.config.n_query_groups != 1
218
+ ):
219
+ k = k.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
220
+ v = v.expand(B, self.config.n_query_groups, q_per_kv, T, self.config.head_size)
221
+
222
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
223
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
224
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
225
+
226
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
227
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
228
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
229
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
230
+
231
+ if input_pos is not None:
232
+ if not isinstance(self.kv_cache, KVCache):
233
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
234
+ k, v = self.kv_cache(input_pos, k, v)
235
+
236
+ y = self.scaled_dot_product_attention(q, k, v, mask)
237
+
238
+ y = y.reshape(B, T, C) # re-assemble all head outputs side by side
239
+
240
+ # output projection
241
+ return self.proj(y)
242
+
243
+ def scaled_dot_product_attention(
244
+ self,
245
+ q: torch.Tensor,
246
+ k: torch.Tensor,
247
+ v: torch.Tensor,
248
+ mask: Optional[torch.Tensor] = None,
249
+ ) -> torch.Tensor:
250
+ scale = 1.0 / math.sqrt(self.config.head_size)
251
+ y = torch.nn.functional.scaled_dot_product_attention(
252
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
253
+ )
254
+ return y.transpose(1, 2)
255
+
256
+ def build_kv_cache(
257
+ self,
258
+ batch_size: int,
259
+ max_seq_length: int,
260
+ rope_cache_length: Optional[int] = None,
261
+ device: Optional[torch.device] = None,
262
+ dtype: Optional[torch.dtype] = None,
263
+ ) -> "KVCache":
264
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
265
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
266
+ if rope_cache_length is None:
267
+ if self.config.rotary_percentage != 1.0:
268
+ raise TypeError("Please pass the `rope_cache_length=gpt.cos.size(-1)` value")
269
+ k_shape = v_shape
270
+ else:
271
+ k_shape = (
272
+ batch_size,
273
+ heads,
274
+ max_seq_length,
275
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
276
+ )
277
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
278
+
279
+
280
+ class GptNeoxMLP(nn.Module):
281
+ def __init__(self, config: Config) -> None:
282
+ super().__init__()
283
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
284
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
285
+
286
+ self.config = config
287
+
288
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
289
+ x = self.fc(x)
290
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
291
+ return self.proj(x)
292
+
293
+
294
+ class LLaMAMLP(nn.Module):
295
+ def __init__(self, config: Config) -> None:
296
+ super().__init__()
297
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
298
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
299
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
300
+
301
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
302
+ x_fc_1 = self.fc_1(x)
303
+ x_fc_2 = self.fc_2(x)
304
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
305
+ return self.proj(x)
306
+
307
+
308
+ def build_rope_cache(
309
+ seq_len: int,
310
+ n_elem: int,
311
+ device: Optional[torch.device] = None,
312
+ base: int = 10000,
313
+ condense_ratio: int = 1,
314
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
315
+ """Enhanced Transformer with Rotary Position Embedding.
316
+
317
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
318
+ transformers/rope/__init__.py. MIT License:
319
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
320
+ """
321
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
322
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
323
+
324
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
325
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
326
+
327
+ # Calculate the product of position index and $\theta_i$
328
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
329
+
330
+ return torch.cos(idx_theta), torch.sin(idx_theta)
331
+
332
+
333
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
334
+ head_size = x.size(-1)
335
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
336
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
337
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
338
+ roped = (x * cos) + (rotated * sin)
339
+ return roped.type_as(x)
340
+
341
+
342
+ class KVCache(nn.Module):
343
+ def __init__(
344
+ self,
345
+ k_shape: Tuple[int, int, int, int],
346
+ v_shape: Tuple[int, int, int, int],
347
+ device: Optional[torch.device] = None,
348
+ dtype: Optional[torch.dtype] = None,
349
+ ) -> None:
350
+ super().__init__()
351
+ self.register_buffer(
352
+ "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
353
+ )
354
+ self.register_buffer(
355
+ "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
356
+ )
357
+
358
+ def forward(
359
+ self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
360
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
361
+ # move the buffer to the activation dtype for when AMP is used
362
+ self.k = self.k.to(k.dtype)
363
+ self.v = self.v.to(v.dtype)
364
+ # update the cache
365
+ k = self.k.index_copy_(2, input_pos, k)
366
+ v = self.v.index_copy_(2, input_pos, v)
367
+ return k, v
tsai_gpt/packed_dataset.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Very loosely inspired by indexed_dataset in Fairseq, Megatron
2
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
3
+
4
+
5
+ import os
6
+ import random
7
+ import struct
8
+
9
+ import numpy as np
10
+ import torch
11
+ from torch.utils.data import IterableDataset, get_worker_info
12
+
13
+ dtypes = {
14
+ 1: np.uint8,
15
+ 2: np.int8,
16
+ 3: np.int16,
17
+ 4: np.int32,
18
+ 5: np.int64,
19
+ 6: np.float32,
20
+ 7: np.float64,
21
+ 8: np.uint16,
22
+ }
23
+
24
+
25
+ def code(dtype):
26
+ for k in dtypes:
27
+ if dtypes[k] == dtype:
28
+ return k
29
+ raise ValueError(dtype)
30
+
31
+
32
+ HDR_MAGIC = b"LITPKDS"
33
+ HDR_SIZE = 24 # bytes
34
+
35
+
36
+ class PackedDataset(IterableDataset):
37
+ def __init__(
38
+ self,
39
+ filenames,
40
+ n_chunks,
41
+ block_size,
42
+ seed=12345,
43
+ shuffle=True,
44
+ wrap=False,
45
+ num_processes=1,
46
+ process_rank=0,
47
+ ):
48
+ self._filenames = filenames
49
+ self._n_chunks = n_chunks
50
+ self._block_size = block_size
51
+ self._seed = seed
52
+ self._shuffle = shuffle
53
+ self._wrap = wrap
54
+ self._num_processes = num_processes
55
+ self._process_rank = process_rank
56
+
57
+ def __iter__(self):
58
+ worker_info = get_worker_info()
59
+ num_workers = worker_info.num_workers if worker_info is not None else 1
60
+ worker_id = worker_info.id if worker_info is not None else 0
61
+ num_shards = num_workers * self._num_processes
62
+ shard_id = self._process_rank * num_workers + worker_id
63
+
64
+ max_num_files = len(self._filenames) // num_shards * num_shards
65
+ filenames = self._filenames[shard_id:max_num_files:num_shards]
66
+
67
+ return PackedDatasetIterator(
68
+ filenames=filenames,
69
+ n_chunks=self._n_chunks,
70
+ block_size=self._block_size,
71
+ seed=self._seed,
72
+ shuffle=self._shuffle,
73
+ wrap=self._wrap,
74
+ )
75
+
76
+
77
+ class PackedDatasetBuilder(object):
78
+ def __init__(self, outdir, prefix, chunk_size, sep_token, dtype="auto", vocab_size=None):
79
+ if dtype == "auto":
80
+ if vocab_size is None:
81
+ raise ValueError("vocab_size cannot be None when dtype='auto'")
82
+ if vocab_size is not None and vocab_size < 65500:
83
+ self._dtype = np.uint16
84
+ else:
85
+ self._dtype = np.int32
86
+ else:
87
+ self._dtype = dtype
88
+ self._counter = 0
89
+ self._chunk_size = chunk_size
90
+ self._outdir = outdir
91
+ self._prefix = prefix
92
+ self._sep_token = sep_token
93
+ self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
94
+ self._arr.fill(self._sep_token)
95
+ self._idx = 0
96
+ self._version = 1
97
+ self._filenames = []
98
+
99
+ def _write_chunk(self):
100
+ filename = f"{self._prefix}_{self._counter:010d}.bin"
101
+ filename = os.path.join(self._outdir, filename)
102
+
103
+ with open(filename, "wb") as f:
104
+ f.write(HDR_MAGIC)
105
+ f.write(struct.pack("<Q", self._version))
106
+ f.write(struct.pack("<B", code(self._dtype)))
107
+ f.write(struct.pack("<Q", self._chunk_size))
108
+ f.write(self._arr.tobytes(order="C"))
109
+
110
+ self._filenames.append(filename)
111
+ self._counter += 1
112
+ self._arr.fill(self._sep_token)
113
+ self._idx = 0
114
+
115
+ @property
116
+ def dtype(self):
117
+ return self._dtype
118
+
119
+ @property
120
+ def filenames(self):
121
+ return self._filenames.copy()
122
+
123
+ def add_array(self, arr):
124
+ while self._idx + arr.shape[0] > self._chunk_size:
125
+ part_len = self._chunk_size - self._idx
126
+ self._arr[self._idx : self._idx + part_len] = arr[:part_len]
127
+ self._write_chunk()
128
+ arr = arr[part_len:]
129
+
130
+ arr_len = arr.shape[0]
131
+ self._arr[self._idx : self._idx + arr_len] = arr
132
+ self._idx += arr_len
133
+
134
+ def write_reminder(self):
135
+ self._write_chunk()
136
+
137
+
138
+ class PackedDatasetIterator:
139
+ def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
140
+ self._seed = seed
141
+ self._shuffle = shuffle
142
+ self._rng = np.random.default_rng(seed) if shuffle else None
143
+ self._block_idxs = None
144
+
145
+ self._wrap = wrap
146
+
147
+ # TODO: instead of filenames, we could have a single text stream
148
+ # (or text file) with the sequence of all files to be
149
+ # fetched/loaded.
150
+ self._filenames = filenames
151
+ self._file_idx = 0
152
+
153
+ self._n_chunks = n_chunks
154
+
155
+ self._dtype = None
156
+ self._block_size = block_size
157
+ self._n_blocks = None
158
+
159
+ self._mmaps = []
160
+ self._buffers = []
161
+
162
+ self._block_idxs = []
163
+ self._curr_idx = 0
164
+
165
+ self._load_n_chunks()
166
+
167
+ def _read_header(self, path):
168
+ with open(path, "rb") as f:
169
+ magic = f.read(len(HDR_MAGIC))
170
+ assert magic == HDR_MAGIC, "File doesn't match expected format."
171
+ version = struct.unpack("<Q", f.read(8))
172
+ assert version == (1,)
173
+ (dtype_code,) = struct.unpack("<B", f.read(1))
174
+ dtype = dtypes[dtype_code]
175
+ (chunk_size,) = struct.unpack("<Q", f.read(8))
176
+ return dtype, chunk_size
177
+
178
+ def _close_mmaps(self):
179
+ for mmap in self._mmaps:
180
+ mmap._mmap.close()
181
+
182
+ def _load_n_chunks(self):
183
+ self._close_mmaps()
184
+ self._mmaps = []
185
+ self._buffers = []
186
+
187
+ if self._n_chunks > len(self._filenames[self._file_idx :]):
188
+ if not self._wrap:
189
+ raise StopIteration
190
+ self._file_idx = 0
191
+
192
+ for i in range(self._n_chunks):
193
+ filename = self._filenames[self._file_idx + i]
194
+ if self._dtype is None:
195
+ self._dtype, self._chunk_size = self._read_header(filename)
196
+ self._n_blocks = self._chunk_size // self._block_size
197
+ # TODO: check header matches with previous files
198
+ mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
199
+ self._mmaps.append(mmap)
200
+ self._buffers.append(memoryview(mmap))
201
+
202
+ self._file_idx += self._n_chunks
203
+ n_all_blocks = self._n_chunks * self._n_blocks
204
+
205
+ self._block_idxs = (
206
+ self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
207
+ )
208
+
209
+ self._curr_idx = 0
210
+
211
+ def __del__(self):
212
+ self._close_mmaps()
213
+ del self._mmaps
214
+ del self._buffers
215
+
216
+ def __iter__(self):
217
+ return self
218
+
219
+ def __next__(self):
220
+ if self._curr_idx >= len(self._block_idxs):
221
+ self._load_n_chunks()
222
+ # TODO: trigger fetching next next n_chunks if remote
223
+ block_idx = self._block_idxs[self._curr_idx]
224
+ chunk_id = block_idx // self._n_blocks
225
+ buffer = self._buffers[chunk_id]
226
+ elem_id = (block_idx % self._n_blocks) * self._block_size
227
+ offset = np.dtype(self._dtype).itemsize * elem_id
228
+ arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
229
+ self._curr_idx += 1
230
+ return torch.from_numpy(arr.astype(np.int64))
231
+
232
+
233
+ class CombinedDataset(IterableDataset):
234
+ def __init__(self, datasets, seed, weights=None):
235
+ self._seed = seed
236
+ self._datasets = datasets
237
+ self._weights = weights
238
+ n_datasets = len(datasets)
239
+ if weights is None:
240
+ self._weights = [1 / n_datasets] * n_datasets
241
+
242
+ def __iter__(self):
243
+ return CombinedDatasetIterator(self._datasets, self._seed, self._weights)
244
+
245
+
246
+ class CombinedDatasetIterator:
247
+ def __init__(self, datasets, seed, weights):
248
+ self._datasets = [iter(el) for el in datasets]
249
+ self._weights = weights
250
+ self._rng = random.Random(seed)
251
+
252
+ def __next__(self):
253
+ (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
254
+ return next(dataset)
tsai_gpt/rmsnorm.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class RMSNorm(torch.nn.Module):
5
+ """Root Mean Square Layer Normalization.
6
+
7
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
8
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
9
+ """
10
+
11
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
12
+ super().__init__()
13
+ self.weight = torch.nn.Parameter(torch.ones(size))
14
+ self.eps = eps
15
+ self.dim = dim
16
+
17
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
18
+ dtype = x.dtype
19
+ x = x.float()
20
+ # NOTE: the original RMSNorm paper implementation is not equivalent
21
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
22
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
23
+ return (self.weight * x_normed).to(dtype=dtype)
24
+
25
+ def reset_parameters(self) -> None:
26
+ torch.nn.init.ones_(self.weight)
tsai_gpt/speed_monitor.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from collections import deque
3
+ from contextlib import nullcontext
4
+ from typing import Any, Callable, Deque, Dict, Optional
5
+
6
+ import torch
7
+ from lightning import Callback, Fabric, LightningModule, Trainer
8
+ from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
9
+ from lightning.fabric.plugins import (BitsandbytesPrecision, DoublePrecision,
10
+ FSDPPrecision, HalfPrecision,
11
+ MixedPrecision, Precision,
12
+ TransformerEnginePrecision, XLAPrecision)
13
+ from lightning.fabric.utilities.rank_zero import \
14
+ rank_zero_only as fabric_rank_zero_only
15
+ from lightning.pytorch.plugins import (DoublePrecisionPlugin,
16
+ FSDPPrecisionPlugin,
17
+ HalfPrecisionPlugin,
18
+ MixedPrecisionPlugin,
19
+ XLAPrecisionPlugin)
20
+ from lightning.pytorch.utilities.rank_zero import \
21
+ rank_zero_only as trainer_rank_zero_only
22
+ from torch.utils.flop_counter import FlopCounterMode
23
+
24
+ from tsai_gpt import GPT
25
+ from tsai_gpt.utils import num_parameters
26
+
27
+ GPU_AVAILABLE_FLOPS = {
28
+ # source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
29
+ # nvidia publishes spec sheet with a 2x sparsity factor
30
+ "h100-sxm": {
31
+ torch.float64: 67e12,
32
+ torch.float32: 67e12,
33
+ torch.bfloat16: 1.979e15 / 2,
34
+ torch.float16: 1.979e15 / 2,
35
+ torch.int8: 3.958e15 / 2,
36
+ },
37
+ "h100-pcie": {
38
+ torch.float64: 51e12,
39
+ torch.float32: 51e12,
40
+ torch.bfloat16: 1.513e15 / 2,
41
+ torch.float16: 1.513e15 / 2,
42
+ torch.int8: 3.026e15 / 2,
43
+ },
44
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf
45
+ # sxm and pcie have same flop counts
46
+ "a100": {
47
+ torch.float64: 19.5e12,
48
+ torch.float32: 19.5e12,
49
+ torch.bfloat16: 312e12,
50
+ torch.float16: 312e12,
51
+ },
52
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a10/pdf/a10-datasheet.pdf
53
+ "a10g": {torch.float32: 31.2e12, torch.bfloat16: 125e12, torch.float16: 125e12},
54
+ # source: https://images.nvidia.com/content/technologies/volta/pdf/volta-v100-datasheet-update-us-1165301-r5.pdf
55
+ "v100-sxm": {torch.float64: 7.8e12, torch.float32: 15.7e12, torch.float16: 125e12},
56
+ "v100-pcie": {torch.float64: 7e12, torch.float32: 14e12, torch.float16: 112e12},
57
+ "v100s-pcie": {torch.float64: 8.2e12, torch.float32: 16.4e12, torch.float16: 130e12},
58
+ # source: https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf
59
+ # sxm and pcie have same flop counts
60
+ "t4": {torch.float32: 8.1e12, torch.float16: 65e12, torch.int8: 130e12},
61
+ # https://www.nvidia.com/content/dam/en-zz/Solutions/design-visualization/quadro-product-literature/quadro-rtx-5000-data-sheet-us-nvidia-704120-r4-web.pdf
62
+ "quadro rtx 5000": {torch.float32: 11.2e12, torch.float16: 89.2e12},
63
+ }
64
+
65
+ TPU_AVAILABLE_FLOPS = {
66
+ # flop count for each TPU generation is the same for all precisions
67
+ # since bfloat16 precision is always used for performing matrix operations
68
+ # for more info: https://cloud.google.com/tpu/docs/bfloat16#choosing_bfloat16
69
+ # source: https://arxiv.org/pdf/1907.10701.pdf
70
+ "v2": 45e12,
71
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3
72
+ "v3": 123e12,
73
+ # source: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v4
74
+ "v4": 275e12,
75
+ # source: https://cloud.google.com/tpu/docs/v5e-training
76
+ "v5litepod": 197e12,
77
+ }
78
+
79
+
80
+ def get_flops_available(device: torch.device, dtype: torch.dtype) -> Optional[float]:
81
+ if device.type == "cuda":
82
+ device_name = torch.cuda.get_device_name(device).lower()
83
+ if "h100" in device_name and "hbm3" in device_name:
84
+ device_name = "h100-sxm"
85
+ elif "h100" in device_name and ("pcie" in device_name or "hbm2e" in device_name):
86
+ device_name = "h100-pcie"
87
+ elif "a100" in device_name:
88
+ device_name = "a100"
89
+ elif "a10g" in device_name:
90
+ device_name = "a10g"
91
+ elif "v100-sxm" in device_name:
92
+ device_name = "v100-sxm"
93
+ elif "v100-pcie" in device_name:
94
+ device_name = "v100-pcie"
95
+ elif "t4" in device_name:
96
+ device_name = "t4"
97
+ elif "quadro rtx 5000" in device_name:
98
+ device_name = "quadro rtx 5000"
99
+ else:
100
+ device_name = None
101
+
102
+ if device_name is not None:
103
+ try:
104
+ return int(GPU_AVAILABLE_FLOPS[device_name][dtype])
105
+ except KeyError:
106
+ raise KeyError(
107
+ f"flop count not found for {device_name} with dtype: {dtype}; "
108
+ "MFU cannot be calculated and reported."
109
+ )
110
+ elif device.type == "xla":
111
+ if _XLA_GREATER_EQUAL_2_1:
112
+ from torch_xla._internal import tpu
113
+ else:
114
+ from torch_xla.experimental import tpu
115
+
116
+ device_name = tpu.get_tpu_env()["TYPE"].lower()
117
+ try:
118
+ return int(TPU_AVAILABLE_FLOPS[device_name])
119
+ except KeyError:
120
+ raise KeyError(
121
+ f"flop count not found for {device_name} with dtype: {dtype}; MFU cannot be calculated and reported."
122
+ )
123
+
124
+ return None
125
+
126
+
127
+ # Adapted from https://github.com/mosaicml/composer/blob/f2a2dc820cb75023b9eb7c46fdfd25273712abd0/composer/callbacks/speed_monitor.py
128
+
129
+
130
+ class SpeedMonitorBase:
131
+ """Logs the training throughput and utilization.
132
+
133
+ +-------------------------------------+-----------------------------------------------------------+
134
+ | Key | Logged data |
135
+ +=====================================+===========================================================+
136
+ | | Rolling average (over `window_size` most recent |
137
+ | `throughput/batches_per_sec` | batches) of the number of batches processed per second |
138
+ | | |
139
+ +-------------------------------------+-----------------------------------------------------------+
140
+ | | Rolling average (over `window_size` most recent |
141
+ | `throughput/samples_per_sec` | batches) of the number of samples processed per second |
142
+ | | |
143
+ +-------------------------------------+-----------------------------------------------------------+
144
+ | | Rolling average (over `window_size` most recent |
145
+ | `throughput/tokens_per_sec` | batches) of the number of tokens processed per second. |
146
+ | | This may include padding depending on dataset |
147
+ +-------------------------------------+-----------------------------------------------------------+
148
+ | | Estimates flops by `flops_per_batch * batches_per_sec` |
149
+ | `throughput/flops_per_sec` | |
150
+ | | |
151
+ +-------------------------------------+-----------------------------------------------------------+
152
+ | `throughput/device/batches_per_sec` | `throughput/batches_per_sec` divided by world size |
153
+ +-------------------------------------+-----------------------------------------------------------+
154
+ | `throughput/device/samples_per_sec` | `throughput/samples_per_sec` divided by world size |
155
+ +-------------------------------------+-----------------------------------------------------------+
156
+ | | `throughput/tokens_per_sec` divided by world size. This |
157
+ | `throughput/device/tokens_per_sec` | may include pad tokens depending on dataset |
158
+ | | |
159
+ +-------------------------------------+-----------------------------------------------------------+
160
+ | | `throughput/flops_per_sec` divided by world size. Only |
161
+ | `throughput/device/flops_per_sec` | logged when model has attribute `flops_per_batch` |
162
+ | | |
163
+ +-------------------------------------+-----------------------------------------------------------+
164
+ | | `throughput/device/flops_per_sec` divided by world size. |
165
+ | `throughput/device/mfu` | |
166
+ | | |
167
+ +-------------------------------------+-----------------------------------------------------------+
168
+ | `time/train` | Total elapsed training time |
169
+ +-------------------------------------+-----------------------------------------------------------+
170
+ | `time/val` | Total elapsed validation time |
171
+ +-------------------------------------+-----------------------------------------------------------+
172
+ | `time/total` | Total elapsed time (time/train + time/val) |
173
+ +-------------------------------------+-----------------------------------------------------------+
174
+
175
+ Notes:
176
+ - The implementation assumes that devices are homogeneous as it normalizes by the world size.
177
+ - Tokens/sec, flops/sec and MFU do not account for padding tokens if present. We suggest using samples/sec or
178
+ batches/sec to measure throughput under this circumstance.
179
+ - Be careful when comparing MFU numbers across projects, as this will highly depend on the ``flops_per_batch``.
180
+ There is no widespread, realistic, and reliable implementation to compute them.
181
+ We suggest using our ``measure_flops`` function, but many other works will use ``estimated_flops`` which
182
+ will almost always be an overestimate when compared to the true value.
183
+
184
+ Args:
185
+ window_size (int, optional): Number of batches to use for a rolling average of throughput.
186
+ Defaults to 100.
187
+ time_unit (str, optional): Time unit to use for `time` logging. Can be one of
188
+ 'seconds', 'minutes', 'hours', or 'days'. Defaults to 'hours'.
189
+ """
190
+
191
+ def __init__(
192
+ self,
193
+ flops_available: float,
194
+ log_dict: Callable[[Dict, int], None],
195
+ window_size: int = 100,
196
+ time_unit: str = "hours",
197
+ ):
198
+ self.flops_available = flops_available
199
+ self.log_dict = log_dict
200
+
201
+ # Track the batch num samples and wct to compute throughput over a window of batches
202
+ self.history_samples: Deque[int] = deque(maxlen=window_size + 1)
203
+ self.history_wct: Deque[float] = deque(maxlen=window_size + 1)
204
+ self.history_lengths: Deque[int] = deque(maxlen=window_size + 1)
205
+ self.history_flops: Deque[int] = deque(maxlen=window_size + 1)
206
+
207
+ self.divider = 1
208
+ if time_unit == "seconds":
209
+ self.divider = 1
210
+ elif time_unit == "minutes":
211
+ self.divider = 60
212
+ elif time_unit == "hours":
213
+ self.divider = 60 * 60
214
+ elif time_unit == "days":
215
+ self.divider = 60 * 60 * 24
216
+ else:
217
+ raise ValueError(
218
+ f'Invalid time_unit: {time_unit}. Must be one of "seconds", "minutes", "hours", or "days".'
219
+ )
220
+
221
+ # Keep track of time spent evaluating
222
+ self.total_eval_wct = 0.0
223
+ self.step = -1
224
+
225
+ def on_train_batch_end(
226
+ self,
227
+ samples: int, # total samples seen (per device)
228
+ train_elapsed: float, # total training time (seconds)
229
+ world_size: int,
230
+ flops_per_batch: Optional[int] = None, # (per device)
231
+ lengths: Optional[int] = None, # total length of the samples seen (per device)
232
+ ) -> None:
233
+ self.step += 1
234
+ step = self.step
235
+ metrics = {}
236
+
237
+ self.history_samples.append(samples)
238
+ if lengths is not None:
239
+ self.history_lengths.append(lengths)
240
+ # if lengths are passed, there should be as many values as samples
241
+ assert len(self.history_samples) == len(self.history_lengths)
242
+ self.history_wct.append(train_elapsed)
243
+ if len(self.history_wct) == self.history_wct.maxlen:
244
+ elapsed_batches = len(self.history_samples) - 1
245
+ elapsed_samples = self.history_samples[-1] - self.history_samples[0]
246
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
247
+ samples_per_sec = elapsed_samples * world_size / elapsed_wct
248
+ dev_samples_per_sec = elapsed_samples / elapsed_wct
249
+ metrics.update(
250
+ {
251
+ "throughput/batches_per_sec": elapsed_batches * world_size / elapsed_wct,
252
+ "throughput/samples_per_sec": samples_per_sec,
253
+ "throughput/device/batches_per_sec": elapsed_batches / elapsed_wct,
254
+ "throughput/device/samples_per_sec": dev_samples_per_sec,
255
+ }
256
+ )
257
+ if lengths is not None:
258
+ elapsed_lengths = int(self.history_lengths[-1]) - int(self.history_lengths[0])
259
+ avg_length = elapsed_lengths / elapsed_batches
260
+ metrics.update(
261
+ {
262
+ "throughput/tokens_per_sec": samples_per_sec * avg_length,
263
+ "throughput/device/tokens_per_sec": dev_samples_per_sec * avg_length,
264
+ }
265
+ )
266
+
267
+ if flops_per_batch is not None:
268
+ # sum of flops per batch across ranks
269
+ self.history_flops.append(flops_per_batch * world_size)
270
+ if len(self.history_flops) == self.history_flops.maxlen:
271
+ elapsed_flops = sum(self.history_flops) - self.history_flops[0]
272
+ elapsed_wct = self.history_wct[-1] - self.history_wct[0]
273
+ flops_per_sec = elapsed_flops / elapsed_wct
274
+ device_flops_per_sec = flops_per_sec / world_size
275
+ metrics.update(
276
+ {
277
+ "throughput/flops_per_sec": flops_per_sec,
278
+ "throughput/device/flops_per_sec": device_flops_per_sec,
279
+ }
280
+ )
281
+ if self.flops_available:
282
+ metrics["throughput/device/mfu"] = device_flops_per_sec / self.flops_available
283
+
284
+ metrics.update(
285
+ {
286
+ "time/train": train_elapsed / self.divider,
287
+ "time/val": self.total_eval_wct / self.divider,
288
+ "time/total": (train_elapsed + self.total_eval_wct) / self.divider,
289
+ "samples": samples,
290
+ }
291
+ )
292
+
293
+ self.log_dict(metrics, step)
294
+
295
+ def eval_end(self, eval_elapsed: float) -> None:
296
+ self.total_eval_wct += eval_elapsed # seconds
297
+
298
+
299
+ def plugin_to_compute_dtype(plugin: Precision) -> torch.dtype:
300
+ if isinstance(plugin, BitsandbytesPrecision):
301
+ return plugin.dtype
302
+ if isinstance(plugin, (HalfPrecision, MixedPrecision, HalfPrecisionPlugin)):
303
+ return plugin._desired_input_dtype
304
+ if isinstance(plugin, MixedPrecisionPlugin):
305
+ return torch.bfloat16 if plugin.precision == "bf16-mixed" else torch.half
306
+ if isinstance(plugin, (DoublePrecision, DoublePrecisionPlugin)):
307
+ return torch.double
308
+ if isinstance(plugin, (XLAPrecision, XLAPrecisionPlugin)):
309
+ return plugin._desired_dtype
310
+ if isinstance(plugin, TransformerEnginePrecision):
311
+ return torch.int8
312
+ if isinstance(plugin, (FSDPPrecision, FSDPPrecisionPlugin)):
313
+ return plugin.mixed_precision_config.reduce_dtype
314
+ if isinstance(plugin, Precision):
315
+ return torch.float32
316
+ raise NotImplementedError(plugin)
317
+
318
+
319
+ class SpeedMonitorFabric(SpeedMonitorBase):
320
+ def __init__(self, fabric: Fabric, *args: Any, **kwargs: Any) -> None:
321
+ dtype = plugin_to_compute_dtype(fabric.strategy.precision)
322
+ flops_available = get_flops_available(fabric.device, dtype)
323
+ super().__init__(flops_available, fabric.log_dict, *args, **kwargs)
324
+
325
+ @fabric_rank_zero_only
326
+ def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
327
+ super().on_train_batch_end(*args, **kwargs)
328
+
329
+
330
+ class SpeedMonitorCallback(Callback):
331
+ def __init__(self, length_fn: Callable[[Any], int], batch_size: int, **kwargs: Any) -> None:
332
+ super().__init__()
333
+ self.speed_monitor: Optional[SpeedMonitorBase] = None
334
+ self.speed_monitor_kwargs = kwargs
335
+ self.length_fn = length_fn
336
+ self.batch_size = batch_size
337
+ self.eval_t0: int = 0
338
+ self.train_t0: int = 0
339
+ self.total_lengths: int = 0
340
+
341
+ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
342
+ if self.speed_monitor is not None:
343
+ return # already setup
344
+ dtype = plugin_to_compute_dtype(trainer.precision_plugin)
345
+ flops_available = get_flops_available(trainer.strategy.root_device, dtype)
346
+ self.speed_monitor = SpeedMonitorBase(
347
+ flops_available, trainer.logger.log_metrics, **self.speed_monitor_kwargs
348
+ )
349
+
350
+ @trainer_rank_zero_only
351
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
352
+ if trainer.fit_loop._should_accumulate():
353
+ return
354
+
355
+ self.train_t0 = time.perf_counter()
356
+
357
+ @trainer_rank_zero_only
358
+ def on_train_batch_end(
359
+ self,
360
+ trainer: Trainer,
361
+ pl_module: LightningModule,
362
+ outputs: Any,
363
+ batch: Any,
364
+ batch_idx: int,
365
+ ) -> None:
366
+ self.total_lengths += self.length_fn(batch)
367
+ if trainer.fit_loop._should_accumulate():
368
+ return
369
+ train_elapsed = time.perf_counter() - self.train_t0
370
+ assert self.speed_monitor is not None
371
+ iter_num = trainer.fit_loop.total_batch_idx
372
+ assert (measured_flops := pl_module.measured_flops) is not None
373
+ self.speed_monitor.on_train_batch_end(
374
+ (iter_num + 1) * self.batch_size,
375
+ train_elapsed,
376
+ # this assumes that device FLOPs are the same and that all devices have the same batch size
377
+ trainer.world_size,
378
+ flops_per_batch=measured_flops,
379
+ lengths=self.total_lengths,
380
+ )
381
+
382
+ @trainer_rank_zero_only
383
+ def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
384
+ self.eval_t0 = time.perf_counter()
385
+
386
+ @trainer_rank_zero_only
387
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
388
+ eval_elapsed = time.perf_counter() - self.eval_t0
389
+ assert self.speed_monitor is not None
390
+ self.speed_monitor.eval_end(eval_elapsed)
391
+
392
+
393
+ def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
394
+ flops_per_token = (
395
+ 2 * n_params
396
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
397
+ # this assumes that all samples have a fixed length equal to the block size
398
+ # which is most likely false during finetuning
399
+ flops_per_seq = flops_per_token * max_seq_length
400
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
401
+ return flops_per_seq + attn_flops_per_seq
402
+
403
+
404
+ def estimate_flops(model: GPT) -> int:
405
+ """Measures estimated FLOPs for MFU.
406
+
407
+ Refs:
408
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
409
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
410
+ """
411
+ # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
412
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
413
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
414
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
415
+ n_trainable_params = num_parameters(model, requires_grad=True)
416
+ trainable_flops = flops_per_param(
417
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
418
+ )
419
+ # forward + backward + gradients (assumes no gradient accumulation)
420
+ ops_per_step = 3 if model.training else 1
421
+ n_frozen_params = num_parameters(model, requires_grad=False)
422
+ frozen_flops = flops_per_param(
423
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
424
+ )
425
+ # forward + backward
426
+ frozen_ops_per_step = 2 if model.training else 1
427
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
428
+
429
+
430
+ def measure_flops(model: GPT, x: torch.Tensor) -> int:
431
+ """Measures real FLOPs for HFU"""
432
+ flop_counter = FlopCounterMode(model, display=False)
433
+ ctx = nullcontext() if model.training else torch.no_grad()
434
+ with ctx, flop_counter:
435
+ y = model(x)
436
+ if model.training:
437
+ y.sum().backward()
438
+ return flop_counter.get_total_flops()
tsai_gpt/tokenizer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ import torch
6
+
7
+
8
+ class Tokenizer:
9
+ def __init__(self, checkpoint_dir: Path) -> None:
10
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
11
+ self.bos_id = None
12
+ self.eos_id = None
13
+
14
+ # some checkpoints have both files, `.model` takes precedence
15
+ if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
16
+ from sentencepiece import SentencePieceProcessor
17
+
18
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
19
+ self.backend = "sentencepiece"
20
+ self.bos_id = self.processor.bos_id()
21
+ self.eos_id = self.processor.eos_id()
22
+
23
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (special_tokens_path := checkpoint_dir / "tokenizer_config.json").is_file():
30
+ with open(special_tokens_path) as fp:
31
+ config = json.load(fp)
32
+ bos_token = config.get("bos_token")
33
+ self.bos_id = self.token_to_id(bos_token) if bos_token is not None else None
34
+ eos_token = config.get("eos_token")
35
+ self.eos_id = self.token_to_id(eos_token) if eos_token is not None else None
36
+ if (special_tokens_path := checkpoint_dir / "generation_config.json").is_file():
37
+ with open(special_tokens_path) as fp:
38
+ config = json.load(fp)
39
+ if self.bos_id is None:
40
+ self.bos_id = config.get("bos_token_id")
41
+ if self.eos_id is None:
42
+ self.eos_id = config.get("eos_token_id")
43
+ else:
44
+ raise NotImplementedError
45
+
46
+ @property
47
+ def vocab_size(self) -> int:
48
+ if self.backend == "huggingface":
49
+ return self.processor.get_vocab_size(with_added_tokens=False)
50
+ if self.backend == "sentencepiece":
51
+ return self.processor.vocab_size()
52
+ raise RuntimeError
53
+
54
+ def token_to_id(self, token: str) -> int:
55
+ if self.backend == "huggingface":
56
+ id_ = self.processor.token_to_id(token)
57
+ elif self.backend == "sentencepiece":
58
+ id_ = self.processor.piece_to_id(token)
59
+ else:
60
+ raise RuntimeError
61
+ if id_ is None:
62
+ raise ValueError(f"token {token!r} not found in the collection.")
63
+ return id_
64
+
65
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
66
+ if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
67
+ return False
68
+ with open(tokenizer_config_path) as fp:
69
+ config = json.load(fp)
70
+ if any(config.get(check, False) for check in ("add_bos_token", "add_prefix_space")):
71
+ return True
72
+ # for examples that also use the Llama tokenizer, but do not have or set add_bos_token to True.
73
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
74
+ return (
75
+ config.get("add_bos_token") is None
76
+ and config.get("tokenizer_class") == "LlamaTokenizer"
77
+ )
78
+
79
+ def encode(
80
+ self,
81
+ string: str,
82
+ device: Optional[torch.device] = None,
83
+ bos: Optional[bool] = None,
84
+ eos: bool = False,
85
+ max_length: int = -1,
86
+ ) -> torch.Tensor:
87
+ if self.backend == "huggingface":
88
+ tokens = self.processor.encode(string).ids
89
+ elif self.backend == "sentencepiece":
90
+ tokens = self.processor.encode(string)
91
+ else:
92
+ raise RuntimeError
93
+ if bos or (bos is None and self.use_bos):
94
+ bos_id = self.bos_id
95
+ if bos_id is None:
96
+ raise NotImplementedError("This tokenizer does not have a defined a bos token")
97
+ tokens = [bos_id] + tokens
98
+ if eos:
99
+ tokens = tokens + [self.eos_id]
100
+ if max_length > 0:
101
+ tokens = tokens[:max_length]
102
+ return torch.tensor(tokens, dtype=torch.int, device=device)
103
+
104
+ def decode(self, tensor: torch.Tensor) -> str:
105
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
106
+ return self.processor.decode(tokens)
tsai_gpt/utils.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for training and inference."""
2
+ import math
3
+ import pickle
4
+ import sys
5
+ from contextlib import nullcontext
6
+ from io import BytesIO
7
+ from pathlib import Path
8
+ from typing import (TYPE_CHECKING, ContextManager, Dict, List, Mapping,
9
+ Optional, TypeVar, Union)
10
+
11
+ import lightning as L
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.utils._device
15
+ from lightning.fabric.strategies import FSDPStrategy
16
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
17
+ from torch.serialization import normalize_storage_type
18
+
19
+ if TYPE_CHECKING:
20
+ from model import GPT
21
+
22
+
23
+ def find_multiple(n: int, k: int) -> int:
24
+ assert k > 0
25
+ if n % k == 0:
26
+ return n
27
+ return n + k - (n % k)
28
+
29
+
30
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
31
+ total = 0
32
+ for p in module.parameters():
33
+ if requires_grad is None or p.requires_grad == requires_grad:
34
+ if hasattr(p, "quant_state"):
35
+ # bitsandbytes 4bit layer support
36
+ total += math.prod(p.quant_state[1])
37
+ else:
38
+ total += p.numel()
39
+ return total
40
+
41
+
42
+ def gptq_quantization(enabled: bool = False) -> ContextManager:
43
+ if not enabled:
44
+ return nullcontext()
45
+
46
+ from lightning.fabric.plugins.precision.utils import \
47
+ _ClassReplacementContextManager
48
+ from quantize.gptq import ColBlockQuantizedLinear
49
+
50
+ class QuantizedLinear(ColBlockQuantizedLinear):
51
+ def __init__(self, *args, **kwargs):
52
+ super().__init__(*args, bits=4, tile_cols=-1, **kwargs)
53
+
54
+ return _ClassReplacementContextManager({"torch.nn.Linear": QuantizedLinear})
55
+
56
+
57
+ def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None:
58
+ files = {
59
+ "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(),
60
+ "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(),
61
+ "tokenizer.json OR tokenizer.model": (checkpoint_dir / "tokenizer.json").is_file()
62
+ or (checkpoint_dir / "tokenizer.model").is_file(),
63
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
64
+ }
65
+ if checkpoint_dir.is_dir():
66
+ if all(files.values()):
67
+ # we're good
68
+ return
69
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
70
+ else:
71
+ problem = " is not a checkpoint directory"
72
+
73
+ # list locally available checkpoints
74
+ available = list(Path("checkpoints").glob("*/*"))
75
+ if available:
76
+ options = "\n --checkpoint_dir ".join([""] + [repr(str(p.resolve())) for p in available])
77
+ extra = f"\nYou have downloaded locally:{options}\n"
78
+ else:
79
+ extra = ""
80
+
81
+ error_message = (
82
+ f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
83
+ "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n"
84
+ f"{extra}\nSee all download options by running:\n python scripts/download.py"
85
+ )
86
+ print(error_message, file=sys.stderr)
87
+ raise SystemExit(1)
88
+
89
+
90
+ class SavingProxyForStorage:
91
+ def __init__(self, obj, saver, protocol_version=5):
92
+ self.protocol_version = protocol_version
93
+ self.saver = saver
94
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
95
+ raise TypeError(f"expected storage, not {type(obj)}")
96
+
97
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
98
+ if isinstance(obj, torch.storage.TypedStorage):
99
+ # PT upstream wants to deprecate this eventually...
100
+ storage = obj._untyped_storage
101
+ storage_type_str = obj._pickle_storage_type()
102
+ storage_type = getattr(torch, storage_type_str)
103
+ storage_numel = obj._size()
104
+ else:
105
+ storage = obj
106
+ storage_type = normalize_storage_type(type(obj))
107
+ storage_numel = storage.nbytes()
108
+
109
+ storage_key = saver._write_storage_and_return_key(storage)
110
+ location = torch.serialization.location_tag(storage)
111
+
112
+ self.storage_info = ("storage", storage_type, storage_key, location, storage_numel)
113
+
114
+ def __reduce_ex__(self, protocol_version):
115
+ assert False, "this should be handled with out of band"
116
+
117
+
118
+ class SavingProxyForTensor:
119
+ def __init__(self, tensor, saver, protocol_version=5):
120
+ self.protocol_version = protocol_version
121
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
122
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
123
+ # for Tensors with Python attributes
124
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
125
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
126
+ storage_proxy = SavingProxyForStorage(
127
+ storage, saver, protocol_version=protocol_version
128
+ )
129
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
130
+ else:
131
+ (storage, *other_reduce_args) = reduce_args
132
+ assert isinstance(storage, torch.storage.TypedStorage), "Please check for updates"
133
+ storage_proxy = SavingProxyForStorage(
134
+ storage, saver, protocol_version=protocol_version
135
+ )
136
+ self.reduce_args = (storage_proxy, *other_reduce_args)
137
+
138
+ def __reduce_ex__(self, protocol_version):
139
+ if protocol_version != self.protocol_version:
140
+ raise RuntimeError(
141
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
142
+ )
143
+ return self.reduce_ret_fn, self.reduce_args
144
+
145
+
146
+ class IncrementalPyTorchPickler(pickle.Pickler):
147
+ def __init__(self, saver, *args, **kwargs):
148
+ super().__init__(*args, **kwargs)
149
+ self.storage_dtypes = {}
150
+ self.saver = saver
151
+ self.id_map = {}
152
+
153
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
154
+ def persistent_id(self, obj):
155
+ # FIXME: the docs say that persistent_id should only return a string
156
+ # but torch store returns tuples. This works only in the binary protocol
157
+ # see
158
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
159
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
160
+ if isinstance(obj, SavingProxyForStorage):
161
+ return obj.storage_info
162
+
163
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
164
+ if isinstance(obj, torch.storage.TypedStorage):
165
+ # TODO: Once we decide to break serialization FC, this case
166
+ # can be deleted
167
+ storage = obj._untyped_storage
168
+ storage_dtype = obj.dtype
169
+ storage_type_str = obj._pickle_storage_type()
170
+ storage_type = getattr(torch, storage_type_str)
171
+ storage_numel = obj._size()
172
+
173
+ else:
174
+ storage = obj
175
+ storage_dtype = torch.uint8
176
+ storage_type = normalize_storage_type(type(obj))
177
+ storage_numel = storage.nbytes()
178
+
179
+ # If storage is allocated, ensure that any other saved storages
180
+ # pointing to the same data all have the same dtype. If storage is
181
+ # not allocated, don't perform this check
182
+ if storage.data_ptr() != 0:
183
+ if storage.data_ptr() in self.storage_dtypes:
184
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
185
+ raise RuntimeError(
186
+ "Cannot save multiple tensors or storages that view the same data as different types"
187
+ )
188
+ else:
189
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
190
+
191
+ storage_key = self.id_map.get(storage._cdata)
192
+ if storage_key is None:
193
+ storage_key = self.saver._write_storage_and_return_key(storage)
194
+ self.id_map[storage._cdata] = storage_key
195
+ location = torch.serialization.location_tag(storage)
196
+
197
+ return ("storage", storage_type, storage_key, location, storage_numel)
198
+
199
+ return None
200
+
201
+
202
+ class incremental_save:
203
+ def __init__(self, name):
204
+ self.name = name
205
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
206
+ self.has_saved = False
207
+ self.next_key = 0
208
+
209
+ def __enter__(self):
210
+ return self
211
+
212
+ def store_early(self, tensor):
213
+ if isinstance(tensor, torch.Tensor):
214
+ return SavingProxyForTensor(tensor, self)
215
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
216
+
217
+ def save(self, obj):
218
+ if self.has_saved:
219
+ raise RuntimeError("have already saved")
220
+ # Write the pickle data for `obj`
221
+ data_buf = BytesIO()
222
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
223
+ pickler.dump(obj)
224
+ data_value = data_buf.getvalue()
225
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
226
+ self.has_saved = True
227
+
228
+ def _write_storage_and_return_key(self, storage):
229
+ if self.has_saved:
230
+ raise RuntimeError("have already saved")
231
+ key = self.next_key
232
+ self.next_key += 1
233
+ name = f"data/{key}"
234
+ if storage.device.type != "cpu":
235
+ storage = storage.cpu()
236
+ num_bytes = storage.nbytes()
237
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
238
+ return key
239
+
240
+ def __exit__(self, type, value, traceback):
241
+ self.zipfile.write_end_of_file()
242
+
243
+
244
+ T = TypeVar("T")
245
+
246
+
247
+ def chunked_cross_entropy(
248
+ logits: Union[torch.Tensor, List[torch.Tensor]], targets: torch.Tensor, chunk_size: int = 128
249
+ ) -> torch.Tensor:
250
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
251
+ # the memory usage in fine-tuning settings with low number of parameters.
252
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
253
+ # the memory spike's magnitude
254
+
255
+ # lm_head was chunked (we are fine-tuning)
256
+ if isinstance(logits, list):
257
+ # don't want to chunk cross entropy
258
+ if chunk_size == 0:
259
+ logits = torch.cat(logits, dim=1)
260
+ logits = logits.reshape(-1, logits.size(-1))
261
+ targets = targets.reshape(-1)
262
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
263
+
264
+ # chunk cross entropy
265
+ logit_chunks = [logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits]
266
+ target_chunks = [
267
+ target_chunk.reshape(-1) for target_chunk in targets.split(logits[0].size(1), dim=1)
268
+ ]
269
+ loss_chunks = [
270
+ torch.nn.functional.cross_entropy(
271
+ logit_chunk, target_chunk, ignore_index=-1, reduction="none"
272
+ )
273
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
274
+ ]
275
+ return torch.cat(loss_chunks).mean()
276
+
277
+ # no chunking at all
278
+ logits = logits.reshape(-1, logits.size(-1))
279
+ targets = targets.reshape(-1)
280
+ if chunk_size == 0:
281
+ return torch.nn.functional.cross_entropy(logits, targets, ignore_index=-1)
282
+
283
+ # lm_head wasn't chunked, chunk cross entropy
284
+ logit_chunks = logits.split(chunk_size)
285
+ target_chunks = targets.split(chunk_size)
286
+ loss_chunks = [
287
+ torch.nn.functional.cross_entropy(
288
+ logit_chunk, target_chunk, ignore_index=-1, reduction="none"
289
+ )
290
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
291
+ ]
292
+ return torch.cat(loss_chunks).mean()
293
+
294
+
295
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
296
+ for checkpoint_name, attribute_name in mapping.items():
297
+ full_checkpoint_name = prefix + checkpoint_name
298
+ if full_checkpoint_name in state_dict:
299
+ full_attribute_name = prefix + attribute_name
300
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
301
+ return state_dict
302
+
303
+
304
+ def get_default_supported_precision(training: bool) -> str:
305
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
306
+
307
+ Args:
308
+ training: `-mixed` or `-true` version of the precision to use
309
+
310
+ Returns:
311
+ default precision that is suitable for the task and is supported by the hardware
312
+ """
313
+ from lightning.fabric.accelerators import MPSAccelerator
314
+
315
+ if MPSAccelerator.is_available() or (
316
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
317
+ ):
318
+ return "16-mixed" if training else "16-true"
319
+ return "bf16-mixed" if training else "bf16-true"
320
+
321
+
322
+ def load_checkpoint(
323
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
324
+ ) -> None:
325
+ if isinstance(fabric.strategy, FSDPStrategy):
326
+ fabric.load_raw(checkpoint_path, model, strict=strict)
327
+ else:
328
+ state_dict = lazy_load(checkpoint_path)
329
+ state_dict = state_dict.get("model", state_dict)
330
+ model.load_state_dict(state_dict, strict=strict)
331
+
332
+
333
+ def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
334
+ flops_per_token = (
335
+ 2 * n_params
336
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
337
+ # this assumes that all samples have a fixed length equal to the block size
338
+ # which is most likely false during finetuning
339
+ flops_per_seq = flops_per_token * max_seq_length
340
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
341
+ return flops_per_seq + attn_flops_per_seq
342
+
343
+
344
+ def estimate_flops(model: "GPT", training: bool) -> int:
345
+ """Measures estimated FLOPs for MFU.
346
+
347
+ Refs:
348
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
349
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
350
+ """
351
+ # using all parameters for this is a naive over estimation because not all model parameters actually contribute to
352
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
353
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
354
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
355
+ n_trainable_params = num_parameters(model, requires_grad=True)
356
+ trainable_flops = flops_per_param(
357
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_trainable_params
358
+ )
359
+ # forward + backward + gradients (assumes no gradient accumulation)
360
+ ops_per_step = 3 if training else 1
361
+ n_frozen_params = num_parameters(model, requires_grad=False)
362
+ frozen_flops = flops_per_param(
363
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
364
+ )
365
+ # forward + backward
366
+ frozen_ops_per_step = 2 if training else 1
367
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops