dinalt committed on
Commit
3f599c2
1 Parent(s): dff1e8a

Upload model

Files changed (5)
  1. README.md +201 -0
  2. config.json +59 -0
  3. generation_config.json +12 -0
  4. model.safetensors +3 -0
  5. modelling_walsh.py +949 -0
README.md ADDED
@@ -0,0 +1,201 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
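+ A minimal sketch for loading this checkpoint, assuming the placeholder repository id below is replaced with this model's actual Hub id and that a compatible tokenizer is available (no tokenizer files are part of this commit). Because `config.json` maps `AutoConfig`/`AutoModelForCausalLM` to `modelling_walsh.py` via `auto_map`, loading requires `trust_remote_code=True`.
+
+ ```python
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ repo_id = "<hub-user>/<repo-name>"  # placeholder -- substitute the actual repository id
+
+ # The architecture is defined in modelling_walsh.py, so remote code must be trusted.
+ model = AutoModelForCausalLM.from_pretrained(
+     repo_id,
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(repo_id)
+
+ inputs = tokenizer("Hello, my name is", return_tensors="pt")
+ # Sampling defaults (temperature, top_k, top_p, typical_p, ...) come from generation_config.json.
+ outputs = model.generate(**inputs, max_new_tokens=64)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
+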
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+
201
+
config.json ADDED
@@ -0,0 +1,59 @@
1
+ {
2
+ "_name_or_path": "/home/dinalt/ai_assets/models/walsh_instruct",
3
+ "activation_args": {},
4
+ "activation_cls": "torch.nn.GELU",
5
+ "architectures": [
6
+ "HFCausalModel"
7
+ ],
8
+ "attention_args": {
9
+ "beta": 0.25,
10
+ "dropout": 0.1
11
+ },
12
+ "attention_cls": ".CausalSelfAttention",
13
+ "auto_map": {
14
+ "AutoConfig": "modelling_walsh.Config",
15
+ "AutoModelForCausalLM": "modelling_walsh.HFCausalModel"
16
+ },
17
+ "d_embed": 2048,
18
+ "dim_feedforward": 8192,
19
+ "dropout": 0.1,
20
+ "embdding_cls": "torch.nn.Embedding",
21
+ "embedding_args": {},
22
+ "feedforward_args": {
23
+ "beta": 0.25,
24
+ "bias": true
25
+ },
26
+ "feedforward_cls": ".FeedforwardLayer",
27
+ "head_args": {},
28
+ "head_cls": ".Transformer",
29
+ "init_gain": 1.0,
30
+ "layer_args": {
31
+ "alpha": 2.828427124746
32
+ },
33
+ "layer_cls": ".DeepnetLayer",
34
+ "layer_stack_args": {},
35
+ "layer_stack_cls": ".TransformerLayerStack",
36
+ "loss_function": ".causal_loss",
37
+ "max_sequence_length": 16384,
38
+ "model_type": "walsh-causal-v1",
39
+ "norm_args": {
40
+ "normalized_shape": 2084
41
+ },
42
+ "norm_cls": "torch.nn.LayerNorm",
43
+ "num_attention_heads": 32,
44
+ "num_hidden_layers": 32,
45
+ "output_proj_args": {},
46
+ "output_proj_cls": "torch.nn.Linear",
47
+ "pad_index": null,
48
+ "positional_encoder_args": {
49
+ "d_embed": 2048,
50
+ "gain": 0.3333,
51
+ "max_seq": 16384
52
+ },
53
+ "positional_encoder_cls": ".RSWalshPositionalEncoder",
54
+ "torch_dtype": "bfloat16",
55
+ "transformer_args": {},
56
+ "transformer_cls": ".Transformer",
57
+ "transformers_version": "4.37.2",
58
+ "vocab_size": 32000
59
+ }
generation_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "do_sample": true,
3
+ "eos_token_id": 3,
4
+ "max_new_tokens": 512,
5
+ "pad_token_id": 0,
6
+ "repetition_penalty": 1.01,
7
+ "temperature": 0.87,
8
+ "top_k": 85,
9
+ "top_p": 0.99,
10
+ "transformers_version": "4.37.2",
11
+ "typical_p": 0.68
12
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edcfe732c6a1b9c12adfe505c30d06962f238a1ebf1c9489790a04b87ff37f6b
3
+ size 3485189432
modelling_walsh.py ADDED
@@ -0,0 +1,949 @@
1
+ # See: https://huggingface.co/docs/transformers/custom_models
2
+ from typing import Optional, Tuple, Union
3
+ import math
4
+ import copy
5
+ import sys
6
+ from importlib import import_module
7
+
8
+ import torch
9
+ from torch import nn, Tensor
10
+ import torch.nn.init as init
11
+ from torch.nn import functional as F
12
+ from transformers.modeling_outputs import CausalLMOutput
13
+ from transformers import (
14
+ PreTrainedModel,
15
+ PretrainedConfig,
16
+ AutoConfig,
17
+ AutoModel,
18
+ AutoModelForCausalLM,
19
+ )
20
+
21
+ from transformers.utils import (
22
+ is_flash_attn_2_available,
23
+ is_flash_attn_greater_or_equal_2_10,
24
+ )
25
+
26
+ if is_flash_attn_2_available():
27
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
28
+
29
+ # The model type string to bind.
30
+ model_type = "walsh-causal-v1"
31
+
32
+ class Config(PretrainedConfig):
33
+ model_type = model_type
34
+
35
+ attribute_map = {
36
+ "hidden_size": "d_embed",
37
+ }
38
+
39
+ def __init__(
40
+ # All of these MUST have defaults, even if unused.
41
+ self,
42
+ vocab_size=16000,
43
+ pad_index=None,
44
+ hidden_size=1024,
45
+ num_attention_heads=8,
46
+ num_hidden_layers=6,
47
+ max_sequence_length=2048,
48
+ dim_feedforward = 4096,
49
+ dropout=0.1,
50
+ loss_function = "causal_loss",
51
+
52
+ # Default class to use for each of these components.
53
+ positional_encoder_cls='.PositionalEncoder',
54
+ attention_cls='.CausalSelfAttention',
55
+ activation_cls='torch.nn.ReLU',
56
+ feedforward_cls='.FeedforwardLayer',
57
+ layer_stack_cls='.TransformerLayerStack',
58
+ layer_cls='.PostLayerNorm',
59
+ transformer_cls='.Transformer',
60
+ norm_cls='torch.nn.LayerNorm',
61
+ embdding_cls='torch.nn.Embedding',
62
+ output_proj_cls='torch.nn.Linear',
63
+
64
+ positional_encoder_args={
65
+ 'd_model': 1024,
66
+ 'max_seq_len': 2048,
67
+ },
68
+
69
+ # Arg groups, passed to factory classes above.
70
+ transformer_args=dict(),
71
+ attention_args=dict(),
72
+ feedforward_args=dict(),
73
+ activation_args=dict(),
74
+ norm_args={
75
+ 'normalized_shape': 1024,
76
+ },
77
+ layer_stack_args=dict(),
78
+ layer_args=dict(),
79
+ embedding_args=dict(),
80
+ output_proj_args=dict(),
81
+
82
+ **kwargs,
83
+ ):
84
+ self.vocab_size = vocab_size
85
+ self.pad_index = pad_index
86
+ self.hidden_size = hidden_size
87
+ self.num_attention_heads = num_attention_heads
88
+ self.num_hidden_layers = num_hidden_layers
89
+ self.max_sequence_length = max_sequence_length
90
+ self.loss_function = loss_function
91
+
92
+ self.dim_feedforward = dim_feedforward
93
+ self.dropout = dropout
94
+
95
+ self.positional_encoder_cls = positional_encoder_cls
96
+ self.attention_cls = attention_cls
97
+ self.activation_cls = activation_cls
98
+ self.feedforward_cls = feedforward_cls
99
+ self.layer_stack_cls = layer_stack_cls
100
+ self.layer_cls = layer_cls
101
+ self.transformer_cls = transformer_cls
102
+ self.norm_cls = norm_cls
103
+ self.embdding_cls = embdding_cls
104
+ self.output_proj_cls = output_proj_cls
105
+
106
+ self.positional_encoder_args = positional_encoder_args
107
+ self.transformer_args = transformer_args
108
+ self.attention_args = attention_args
109
+ self.feedforward_args = feedforward_args
110
+ self.activation_args = activation_args
111
+ self.norm_args = norm_args
112
+ self.layer_stack_args = layer_stack_args
113
+ self.layer_args = layer_args
114
+ self.embedding_args = embedding_args
115
+ self.output_proj_args = output_proj_args
116
+
117
+ super().__init__(**kwargs)
118
+
119
+ def causal_loss(logits: Tensor, labels: Tensor, input_ids: Tensor, ignore_index=-100) -> Tensor:
120
+ """
121
+ Compute and return the loss using logits and labels.
122
+ """
123
+ # Shift so that tokens < n predict n
124
+ shift_logits = logits[..., :-1, :].contiguous()
125
+ shift_labels = labels[..., 1:].contiguous()
126
+
127
+ loss = torch.nn.functional.cross_entropy(
128
+ shift_logits.view(-1, shift_logits.size(-1)),
129
+ shift_labels.view(-1),
130
+ ignore_index=ignore_index,
131
+ reduction='mean',
132
+ )
133
+
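+ # With reduction='mean', cross_entropy returns NaN when every target equals ignore_index
+ # (i.e. there are no valid tokens); nan_to_num() below maps that edge case to a zero loss.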
134
+ return loss.nan_to_num()
135
+
136
+ # Learning to Break the Loop: Analyzing and Mitigating Repetitions for Neural Text Generation
137
+ # https://arxiv.org/abs/2206.02369
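+ #
+ # Note on the label format (as read from the body below): each labels[i] packs three integers,
+ # [context_len, sentence_len, n_repeats], rather than token ids. input_ids[i] is assumed to hold
+ # context_len context tokens, then one sentence of sentence_len tokens, then that sentence
+ # repeated; the repetition span after the first sentence is the DITTO target region.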
138
+ def ditto_loss(logits: Tensor, labels: Tensor, input_ids: Tensor) -> Tensor:
139
+ batch_size, seq_len, vocab_size = logits.shape
140
+ rep_reduce_gamma = 0.5
141
+ ditto_weight = 1.0e5
142
+
143
+ probs = torch.softmax(logits, dim=-1)
144
+ total_loss = None
145
+ for i in range(batch_size):
146
+ context_len = labels[i, 0].item()
147
+ sentence_len = labels[i, 1].item()
148
+ n_repeats = labels[i, 2].item()
149
+
150
+ # For readability
151
+ context_end = context_len
152
+ sentence_start = context_len
153
+ sentence_end = sentence_start + sentence_len
154
+ target_start = sentence_end
155
+
156
+ # Get causal loss for context tokens
157
+ causal_ids = input_ids[i:i+1, :context_end]
158
+ c_loss = causal_loss(
159
+ logits=logits[i:i+1, :context_end],
160
+ labels=causal_ids,
161
+ input_ids=causal_ids
162
+ )
163
+
164
+ # Slice out target probabilities
165
+ target_probs = probs[i , target_start:, :]
166
+
167
+ # Slice out the first instance of the repeated sentence, detach it (prevents back-prop), repeat it N times,
168
+ # and trim to length of target_probs.
169
+ baseline_probs = probs[i, sentence_start:sentence_end, :].detach().repeat(n_repeats, 1)[:target_probs.size(0), :]
170
+
171
+ # Compute DITTO loss.
172
+ one_minus_probs = torch.clamp((1.0 - torch.abs((target_probs - baseline_probs * rep_reduce_gamma))), min=1e-20)
173
+ r_loss = -torch.log(one_minus_probs).mean() * ditto_weight
174
+
175
+ # Combine repetition and causal loss
176
+ loss = c_loss + r_loss
177
+
178
+ # Add this to the total
179
+ if total_loss is None:
180
+ total_loss = loss
181
+ else:
182
+ total_loss += loss
183
+
184
+ return total_loss / batch_size
185
+
186
+ # Dynamically look up a class by name and return it (to be used as a factory).
187
+ def get_dynamic_class(name):
188
+ try:
189
+ module_path, class_name = name.rsplit('.', 1)
190
+ if module_path == "":
191
+ return getattr(sys.modules[__name__], class_name)
192
+ module = import_module(module_path)
193
+ return getattr(module, class_name)
194
+ except (ImportError, AttributeError) as e:
195
+ raise ImportError(name) from e
196
+
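+ # The *_cls strings in config.json follow this convention: a leading dot (e.g. '.CausalSelfAttention')
+ # resolves to a class defined in this module, while a fully qualified dotted path
+ # (e.g. 'torch.nn.GELU') is imported from the named module. Illustrative usage:
+ #
+ #     attention_cls = get_dynamic_class('.CausalSelfAttention')   # class from this file
+ #     norm_cls = get_dynamic_class('torch.nn.LayerNorm')          # class imported from torch.nn
+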
197
+ # An easily extensible dynamic transformer class
198
+ # Many variations can be specified entirely in the configuration, without touching this code.
199
+ class HFCausalModel(PreTrainedModel):
200
+ config_class = Config
201
+ model_type = 'Transformer'
202
+ supports_gradient_checkpointing = True
203
+ # Presently needs to be manually set to match transformer layer class...
204
+ _no_split_modules = ["DeepnetLayer"]
205
+ _supports_flash_attn_2 = True
206
+ _supports_sdpa = True
207
+
208
+ def __init__(self, config):
209
+ super().__init__(config)
210
+
211
+ self.d_model = config.hidden_size
212
+ self.transformer_head = self._make_transformer(config)
213
+ self.loss_function = get_dynamic_class(config.loss_function)
214
+ self.gradient_checkpointing = False
215
+ self.post_init()
216
+
217
+ def forward(
218
+ self,
219
+ input_ids: Optional[torch.LongTensor] = None,
220
+ attention_mask: Optional[torch.FloatTensor] = None,
221
+ token_type_ids: Optional[torch.LongTensor] = None,
222
+ position_ids: Optional[torch.LongTensor] = None,
223
+ labels: Optional[torch.LongTensor] = None,
224
+ output_attentions: Optional[bool] = None,
225
+ output_hidden_states: Optional[bool] = None,
226
+ return_dict: Optional[bool] = None,
227
+ **kwargs,
228
+ ) -> CausalLMOutput:
229
+
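+ # attention_mask, token_type_ids and position_ids are accepted for compatibility with the HF
+ # generation API, but only input_ids is forwarded to the transformer stack; the attention
+ # layers apply their own causal mask internally.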
230
+ if self.gradient_checkpointing and self.training:
231
+ gradient_checkpointing_func = self._gradient_checkpointing_func
232
+ else:
233
+ gradient_checkpointing_func = None
234
+
235
+ logits, attentions = self.transformer_head(
236
+ input_ids=input_ids,
237
+ need_weights=output_attentions,
238
+ gradient_checkpointing_func=gradient_checkpointing_func,
239
+ )
240
+
241
+ # Compute loss.
242
+ if labels is not None:
243
+ loss = self.loss_function(logits=logits, labels=labels, input_ids=input_ids)
244
+ else:
245
+ loss = None
246
+
247
+ return CausalLMOutput(loss=loss, logits=logits, attentions=attentions)
248
+
249
+ # Needed for generate() method.
250
+ def prepare_inputs_for_generation(self, input_ids, **kwargs):
251
+ attention_mask = kwargs.get("attention_mask", None)
252
+ model_inputs = {
253
+ "input_ids": input_ids,
254
+ "attention_mask": attention_mask,
255
+ }
256
+ return model_inputs
257
+
258
+ def _make_embedding(self, config):
259
+ embedding_cls = get_dynamic_class(config.embdding_cls)
260
+ return embedding_cls(config.vocab_size, self.d_model, config.pad_index, **config.embedding_args)
261
+
262
+ def _make_pos_encoder(self, config):
263
+ pos_enc_cls = get_dynamic_class(config.positional_encoder_cls)
264
+ return pos_enc_cls(**config.positional_encoder_args)
265
+
266
+ def _make_output_projection(self, config):
267
+ output_proj_cls = get_dynamic_class(config.output_proj_cls)
268
+ return output_proj_cls(self.d_model, config.vocab_size, **config.output_proj_args)
269
+
270
+ def _make_dropout(self, config):
271
+ return nn.Dropout(config.dropout)
272
+
273
+ def _make_activation(self, config):
274
+ activation_cls = get_dynamic_class(config.activation_cls)
275
+ return activation_cls(**config.activation_args)
276
+
277
+ def _make_norm(self, config):
278
+ norm_cls = get_dynamic_class(config.norm_cls)
279
+ return norm_cls(self.d_model)
280
+
281
+ def _make_self_attention(self, config):
282
+ attention_cls = get_dynamic_class(config.attention_cls)
283
+ # Map HF _attn_implementation to attn_type
284
+ match config._attn_implementation:
285
+ case "flash_attention_2":
286
+ if is_flash_attn_2_available():
287
+ if not is_flash_attn_greater_or_equal_2_10():
288
+ raise Exception("flash_attn_2 >= 2.10 is required")
289
+ attn_type = "flash2"
290
+ else:
291
+ attn_type = "torch"
292
+ case "sdpa":
293
+ attn_type = "torch"
294
+ case "eager":
295
+ attn_type = "native"
296
+ case _:
297
+ raise Exception(f"Unimplemented attention type '{config._attn_implementation}'")
298
+ return attention_cls(
299
+ d_model=self.d_model,
300
+ num_heads=config.num_attention_heads,
301
+ attn_type=attn_type,
302
+ **config.attention_args,
303
+ )
304
+
305
+ def _make_feedforward(self, config):
306
+ feedforward_cls = get_dynamic_class(config.feedforward_cls)
307
+ return feedforward_cls(
308
+ d_model=self.d_model,
309
+ feedforward_dim=config.dim_feedforward,
310
+ dropout=config.dropout,
311
+ activation=self._make_activation(config),
312
+ **config.feedforward_args,
313
+ )
314
+
315
+ def _make_layer(self, config):
316
+ layer_cls = get_dynamic_class(config.layer_cls)
317
+ return layer_cls(
318
+ d_model=self.d_model,
319
+ dropout=self._make_dropout(config),
320
+ attention=self._make_self_attention(config),
321
+ feedforward=self._make_feedforward(config),
322
+ norm1=self._make_norm(config),
323
+ norm2=self._make_norm(config),
324
+ **config.layer_args,
325
+ )
326
+
327
+ def _make_layer_stack(self, config):
328
+ layer_stack_cls = get_dynamic_class(config.layer_stack_cls)
329
+ return layer_stack_cls(
330
+ layers=nn.ModuleList([
331
+ self._make_layer(config) for _ in range(config.num_hidden_layers)
332
+ ]),
333
+ **config.layer_stack_args,
334
+ )
335
+
336
+ def _make_transformer(self, config):
337
+ transformer_cls = get_dynamic_class(config.transformer_cls)
338
+ return transformer_cls(
339
+ d_model=self.d_model,
340
+ embedding=self._make_embedding(config),
341
+ positional_encoder=self._make_pos_encoder(config),
342
+ layer_stack=self._make_layer_stack(config),
343
+ output_projection=self._make_output_projection(config),
344
+ **config.transformer_args,
345
+ )
346
+
347
+ @torch.no_grad()
348
+ def _init_weights(self, module):
349
+ pass
350
+
351
+ # Register model type and configuration
352
+ AutoConfig.register(model_type, Config)
353
+ AutoModelForCausalLM.register(Config, HFCausalModel)
354
+
355
+ # A generic container class for standard transformer components.
356
+ class Transformer(nn.Module):
357
+ def __init__(self, d_model, embedding, positional_encoder, layer_stack, output_projection, **kwargs):
358
+ super().__init__()
359
+ self.embedding = embedding
360
+ self.positional_encoder = positional_encoder
361
+ self.layer_stack = layer_stack
362
+ self.output_projection = output_projection
363
+ self.d_model = d_model
364
+ self.sqrt_d_model = d_model**0.5
365
+ self.reset_parameters()
366
+
367
+ def forward(self, input_ids, need_weights, gradient_checkpointing_func):
368
+ x = self.positional_encoder(self.embedding(input_ids) * self.sqrt_d_model)
369
+
370
+ x, attentions = self.layer_stack(
371
+ x,
372
+ need_weights,
373
+ gradient_checkpointing_func,
374
+ )
375
+
376
+ # Translate output embeddings to logits.
377
+ logits = self.output_projection(x)
378
+ return logits, attentions
379
+
380
+ def reset_parameters(self):
381
+ init.xavier_uniform_(self.output_projection.weight)
382
+ init.constant_(self.output_projection.bias, 0.)
383
+ init.normal_(self.embedding.weight, std=self.d_model**-0.5)
384
+
385
+ # A vanilla positional encoder
386
+ class PositionalEncoder(nn.Module):
387
+ def __init__(self, d_embed, max_seq):
388
+ super().__init__()
389
+ self.d_embed = d_embed
390
+ self.max_seq = max_seq
391
+
392
+ weight = torch.zeros(max_seq, d_embed)
393
+ position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(1)
394
+ div_term = torch.exp(torch.arange(0, d_embed, 2).float() * (-math.log(10000.0) / d_embed))
395
+ weight[:, 0::2] = torch.sin(position * div_term)
396
+ weight[:, 1::2] = torch.cos(position * div_term)
397
+ weight = weight.unsqueeze(0)
398
+ self.register_buffer('weight', weight)
399
+
400
+ def forward(self, x):
401
+ seq_len = x.size(-2)
402
+ return x + self.weight[:, :seq_len]
403
+
404
+ # Converts a tensor of integers into their equivalent binary codes.
405
+ def binary_tensor(x, bits):
406
+ mask = 2**torch.arange(bits).to(x.device, x.dtype)
407
+ return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte()
408
+
409
+ def hadamard_walsh_matrix(k: int):
410
+ # k: The dimension of the matrix is 2^k
411
+ assert k > 0
412
+
413
+ # Start with Hadamard H2^1 matrix.
414
+ h1 = torch.tensor([[1, 1], [1, -1]], dtype=torch.float)
415
+
416
+ # The series of matrices can be computed by recursively applying the Kronecker product,
417
+ # starting with h1.
418
+ #
419
+ # This will produce the series of Hadamard-Walsh matrices in natural order.
420
+ w = h1
421
+ for _ in range(k-1):
422
+ w = torch.kron(h1, w)
423
+
424
+ return w
425
+
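+ # Worked example of the two helpers above (values shown for illustration):
+ #
+ #     binary_tensor(torch.arange(4), 2)   # -> [[0,0],[1,0],[0,1],[1,1]]  (LSB first)
+ #     hadamard_walsh_matrix(2)            # -> [[ 1,  1,  1,  1],
+ #                                         #     [ 1, -1,  1, -1],
+ #                                         #     [ 1,  1, -1, -1],
+ #                                         #     [ 1, -1, -1,  1]]
+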
426
+ # This positional encoder adds absolute binary positions to the embedding, encoded via
427
+ # Hadamard-Walsh matrix.
428
+ # See: https://en.wikipedia.org/wiki/Hadamard_code
429
+ # Each bit in the binary code word is encoded via a row of the Hadamard-Walsh matrix, with a
430
+ # 1 being encoded by the presence of the row and a 0 by its absence. While training, the base
431
+ # sequence offset is randomly selected, which appears to allow the model to generalize to
432
+ # sequences longer than it was trained on. This is similar to what is described here:
433
+ # https://arxiv.org/pdf/2305.16843.pdf
434
+ # I have tried that approach and found that this one generalizes better.
435
+ #
436
+ # Note: Without random shifting, the early performance of this encoder is exceptionally good.
437
+ # The drawback is that the model can't generalize to longer sequences than it was trained on
438
+ # and can't easily accommodate additional bits later in the training process.
439
+ class RSWalshPositionalEncoder(nn.Module):
440
+ def __init__(self, d_embed, max_seq, gain=0.333):
441
+ super().__init__()
442
+ self.max_seq = max_seq
443
+ self.d_embed = d_embed
444
+
445
+ # Hadamard-Walsh k, where the dimension of the matrix is 2^k
446
+ k = math.ceil(math.log2(d_embed))
447
+
448
+ # The number of bits required to encode max_seq
449
+ bits = math.ceil(math.log2(max_seq))
450
+
451
+ # Gain controls the weight given to the encodings.
452
+ # When a trainable parameter, the value appears to settle at around 0.333.
453
+ self.gain = gain
454
+
455
+ assert bits <= d_embed, "max_seq exceeds n-bits available for d_embed"
456
+
457
+ # Generate sequential binary codes for absolute positions.
458
+ # The implementation originally used Gray codes, in which successive symbols
459
+ # differ by only one bit. See: https://en.wikipedia.org/wiki/Gray_code
460
+ # This, along with a few other coding schemes, was tested, with a simple
461
+ # binary code having the best performance.
462
+ binary_code = binary_tensor(torch.arange(0, max_seq, 1), bits)
463
+ self.register_buffer('binary_code', binary_code, persistent=False)
464
+
465
+ # Each bit is encoded via a row of a Hadamard-Walsh matrix.
466
+ # We slice off the unused rows and columns -- ideally, d_embed should be
467
+ # the same dimension as the matrix.
468
+ walsh = hadamard_walsh_matrix(k)[:bits,:d_embed] * self.gain
469
+
470
+ # This alternative appears superior to the original.
471
+ # If starting from scratch, use this.
472
+ # walsh = (hadamard_walsh_matrix(k)[:bits,:d_embed] -0.5) * self.gain
473
+ self.register_buffer('walsh', walsh, persistent=False)
474
+
475
+ def forward(self, x):
476
+ seq_len = x.size(-2)
477
+
478
+ # Get sequence of binary codes...
479
+ # We use a random base offset when training.
480
+ # This results in slower initial gains, but appears to allow the model to generalize to
481
+ # the value of max_seq, even if never trained with sequences of this length. I also have
482
+ # a suspicion that this has a regularizing effect on training, similar to dropout. Models with
483
+ # random base offset shifting, despite slower initial improvement, appear to perform better in the long-run.
484
+ # TODO: Setup a controlled experiment to test this hypothesis.
485
+ if self.training:
486
+ shift = torch.randint(self.max_seq - seq_len + 1, (1,)).item()
487
+ seq = self.binary_code[shift:seq_len + shift,:]
488
+
489
+ # Disable shifting when not training. This does not appear to change the evaluation loss, but
490
+ # it does make predictions easier to analyse when the attention weights are not shifting with each step.
491
+ else:
492
+ seq = self.binary_code[:seq_len,:]
493
+
494
+ # For reasons I have yet to identify, when the model is running in Textgenwebui, the matrix appears
495
+ # to evade conversion to bfloat16, despite everything else having been converted.
496
+ # This is a workaround for that issue.
497
+ self.walsh = self.walsh.to(dtype=x.dtype)
498
+
499
+ # Encode binary sequence with Hadamard-Walsh codes and apply to embeddings.
500
+ # If nothing else, the Walsh encodings make the positional information exceptionally
501
+ # robust with respect to dropout and other adversities. They can still be easily detected
502
+ # at the final layer.
503
+ return x + (seq.to(dtype=x.dtype) @ self.walsh)
504
+
505
+ # A generic stack of transformer layers.
506
+ class TransformerLayerStack(nn.Module):
507
+ def __init__(self, layers):
508
+ super().__init__()
509
+ self.layers = layers
510
+
511
+ def forward(self, x, need_weights, gradient_checkpointing_func=None):
512
+ attentions = []
513
+ for layer in self.layers:
514
+ if gradient_checkpointing_func is not None:
515
+ x, attention_weights = gradient_checkpointing_func(
516
+ layer.__call__,
517
+ x,
518
+ need_weights,
519
+ use_reentrant=False
520
+ )
521
+ else:
522
+ x, attention_weights = layer(x, need_weights=need_weights)
523
+ if need_weights:
524
+ attentions.append(attention_weights)
525
+
526
+ return x, attentions
527
+
528
+ # DeepNet: Scaling Transformers to 1,000 Layers
529
+ # https://arxiv.org/abs/2203.00555
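+ #
+ # For a decoder-only model with N layers, the paper suggests alpha = (2N)^(1/4) for residual
+ # scaling and beta = (8N)^(-1/4) for the initialization gain. The values in config.json appear
+ # to follow this: with N = 32, (2*32)^(1/4) ≈ 2.828 (layer_args.alpha) and (8*32)^(-1/4) = 0.25
+ # (the beta passed to the attention and feedforward layers).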
530
+ class DeepnetLayer(nn.Module):
531
+ def __init__(
532
+ self,
533
+ d_model,
534
+ attention,
535
+ feedforward,
536
+ norm1,
537
+ norm2,
538
+ dropout,
539
+ alpha=1.0,
540
+ ):
541
+ super().__init__()
542
+ self.d_model = d_model
543
+ self.attention = attention
544
+ self.feedforward = feedforward
545
+ self.norm1 = norm1
546
+ self.norm2 = norm2
547
+ self.dropout = dropout
548
+ # Deepnet alpha
549
+ self.alpha = alpha
550
+
551
+ def forward(self, x, need_weights=False):
552
+ # Keep input as residual
553
+ residual = x * self.alpha
554
+
555
+ # Compute attention
556
+ x, attention_weights = self.attention(x, need_weights)
557
+
558
+ # Add attention with residual and normalize.
559
+ x = self.norm1(residual + self.dropout(x))
560
+
561
+ # Keep output as next residual.
562
+ residual = x * self.alpha
563
+
564
+ # Pass through feedforward network.
565
+ x = self.feedforward(x)
566
+
567
+ # Combine residual and ff output, then normalize again.
568
+ x = self.norm2(residual + self.dropout(x))
569
+
570
+ return x, attention_weights
571
+
572
+ # A vanilla transformer MLP (feedforward) layer.
573
+ class FeedforwardLayer(nn.Module):
574
+ def __init__(
575
+ self,
576
+ d_model: int,
577
+ feedforward_dim: int,
578
+ dropout,
579
+ activation=nn.ReLU(),
580
+ beta=1.0,
581
+ bias=True,
582
+ ):
583
+ super().__init__()
584
+ self.d_model = d_model
585
+ self.beta = beta
586
+ self.activation = activation
587
+ self.linear1 = nn.Linear(d_model, feedforward_dim, bias=bias)
588
+ self.linear2 = nn.Linear(feedforward_dim, d_model, bias=bias)
589
+ self.dropout = nn.Dropout(dropout)
590
+ self.reset_parameters()
591
+
592
+ def forward(self, x):
593
+ return self.linear2(self.dropout(self.activation(self.linear1(x))))
594
+
595
+ def reset_parameters(self):
596
+ init.xavier_uniform_(self.linear1.weight, gain=self.beta)
597
+ init.xavier_uniform_(self.linear2.weight, gain=self.beta)
598
+ init.constant_(self.linear1.bias, 0.)
599
+ init.constant_(self.linear2.bias, 0.)
600
+
601
+ # GLU Variants Improve Transformer
602
+ # https://arxiv.org/pdf/2002.05202v1.pdf
603
+ class SwiGLUFeedforwardLayer(nn.Module):
604
+ def __init__(
605
+ self,
606
+ d_model,
607
+ d_feedforward,
608
+ beta=1.0,
609
+ dropout=0.1
610
+ ):
611
+ super().__init__()
612
+ self.d_model = d_model
613
+ self.d_feedforward = d_feedforward
614
+ self.beta = beta
615
+
616
+ self.linear1 = nn.Linear(self.d_model, self.d_feedforward * 2, bias=False)
617
+ self.linear2 = nn.Linear(self.d_feedforward, self.d_model, bias=False)
618
+ self.dropout = nn.Dropout(dropout)
619
+ self.reset_parameters()
620
+
621
+ def forward(self, x):
622
+ x, gate = self.linear1(x).chunk(2, dim=-1)
623
+ x = x * F.silu(gate)
624
+ x = self.dropout(x)
625
+ x = self.linear2(x)
626
+ return x
627
+
628
+ def reset_parameters(self):
629
+ # Deepnet initialization
630
+ # https://arxiv.org/pdf/2203.00555.pdf
631
+ w, g = self.linear1.weight.chunk(2, dim=0)
632
+ init.xavier_uniform_(w, gain=self.beta)
633
+ init.xavier_uniform_(g, gain=self.beta)
634
+ init.xavier_uniform_(self.linear2.weight, gain=self.beta)
635
+
636
+ class CausalSelfAttention(nn.Module):
637
+ def __init__(
638
+ self,
639
+ d_model,
640
+ num_heads,
641
+ # values:
642
+ # native: Use local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
643
+ # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
644
+ # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; least memory usage.
645
+ attn_type,
646
+ beta=1.0,
647
+ dropout=0.1,
648
+ ):
649
+ super().__init__()
650
+ self.d_model = d_model
651
+ self.num_heads = num_heads
652
+ self.beta = beta
653
+ self.attn_type = attn_type
654
+
655
+ assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
656
+
657
+ # The dimension of each head.
658
+ self.d_head = d_model // num_heads
659
+
660
+ # We scale the attention scores by the inverse-square-root of the head dimension
661
+ # this shifts the temperature of softmax.
662
+ self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
663
+
664
+ self.in_proj = nn.Linear(self.d_model, 3 * self.d_model, bias=True)
665
+ self.output_linear = nn.Linear(self.d_model, self.d_model, bias=True)
666
+
667
+ self.dropout = nn.Dropout(dropout)
668
+ self.reset_parameters()
669
+
670
+ def extra_repr(self) -> str:
671
+ return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, dropout={self.dropout}'
672
+
673
+ def reset_parameters(self):
674
+ # Deepnet initialization
675
+ # https://arxiv.org/pdf/2203.00555.pdf
676
+ q, k, v = self.in_proj.weight.chunk(3)
677
+ init.xavier_uniform_(q, gain=1.0)
678
+ init.xavier_uniform_(k, gain=1.0)
679
+ init.xavier_uniform_(v, gain=self.beta)
680
+ init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
681
+ init.constant_(self.in_proj.bias, 0.)
682
+ init.constant_(self.output_linear.bias, 0.)
683
+
684
+ def project_input(self, qkv):
685
+ proj = self.in_proj(qkv)
686
+ return proj.chunk(chunks=3, dim=-1)
687
+
688
+ def forward(self, qkv, need_weights):
689
+ if self.attn_type == "flash2":
690
+ return self.flash2_forward(qkv)
691
+
692
+ # qkv: (batch_size, seq_len, d_embed)
693
+ batch_size, seq_len, d_embed = qkv.shape
694
+
695
+ # Feed the inputs through the K, Q, V matrices.
696
+ query, key, value = self.project_input(qkv)
697
+
698
+ # Split projections into multiple heads and swap position of sequence / heads dimension
699
+ query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
700
+ key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
701
+ value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
702
+
703
+ # Default to returning empty attention weights.
704
+ attention_weights = None
705
+
706
+ if self.attn_type == "torch":
707
+ # This context manager can be used to force which implementation to use.
708
+ #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
709
+ attended_values = F.scaled_dot_product_attention(
710
+ query,
711
+ key,
712
+ value,
713
+ attn_mask=None,
714
+ dropout_p=self.dropout.p if self.training else 0.0,
715
+ is_causal=True,
716
+ scale=self.dot_product_scale
717
+ )
718
+ # "native" scaled-dot-product attention implementation.
719
+ else:
720
+ # Compute attention scores
721
+ scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
722
+
723
+ # Mask future positions from the past
724
+ scores.masked_fill_(
725
+ torch.tril(
726
+ torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device),
727
+ diagonal=0,
728
+ ).logical_not(),
729
+ float('-inf'),
730
+ )
731
+
732
+ # Calculate the attention weights; avoid NaNs that might emerge from zeros in softmax's denominator
733
+ attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
734
+ del scores
735
+
736
+ # Use the attention weights to get a weighted combination of value vectors
737
+ attended_values = torch.matmul(attention_weights, value)
738
+ if not need_weights:
739
+ del attention_weights
740
+ attention_weights = None
741
+
742
+ # Concatenate attention heads and project to original embedding size using the output linear layer
743
+ attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
744
+
745
+ # Project the concatenated output through the output matrix.
746
+ attended_values = self.output_linear(attended_values)
747
+ return attended_values, attention_weights
748
+
749
+ def flash2_forward(self, qkv):
750
+ batch_size, seq_len, d_embed = qkv.shape
751
+
752
+ # Feed the inputs through the K, Q, V matrices.
753
+ # query : (batch_size, seq_len, d_model)
754
+ # qkv : (batch_size, seq_len, 3, num_heads, d_head)
755
+ qkv = self.in_proj(qkv).unflatten(
756
+ -1,
757
+ (3, self.num_heads, self.d_head)
758
+ )
759
+
760
+ attended_values = flash_attn_qkvpacked_func(
761
+ qkv.bfloat16(),
762
+ dropout_p=self.dropout.p if self.training else 0.0,
763
+ softmax_scale=self.dot_product_scale,
764
+ causal=True,
765
+ )
766
+ # attended_values: (batch_size, seqlen, nheads, headdim)
767
+
768
+ # Concatenate heads back into d_embed
769
+ attended_values = attended_values.view(batch_size, seq_len, d_embed)
770
+
771
+ # Project the concatenated output through the output matrix.
772
+ attended_values = self.output_linear(attended_values)
773
+ return attended_values, None
774
+
775
+ # Attention layer with ALiBi relative positional encoding
776
+ # TRAIN SHORT, TEST LONG: ATTENTION WITH LINEAR BIASES ENABLES INPUT LENGTH EXTRAPOLATION
777
+ # https://arxiv.org/pdf/2108.12409.pdf
778
+ def alibi_biases(query_len, key_len, device='cpu'):
779
+ x = torch.arange(key_len, device=device)[None, :]
780
+ y = torch.arange(query_len, device=device)[:, None]
781
+ return x - y
782
+
783
+ class CausalAlibiAttention(nn.Module):
784
+ def __init__(
785
+ self,
786
+ d_model,
787
+ num_heads,
788
+ beta=1.0,
789
+ dropout=0.1,
790
+ # values:
791
+ # native: Use local implementation; slowest option; good for debugging; useful when experimenting with non-standard stuff.
792
+ # torch: Use pytorch "scaled_dot_product_attention()"; faster; generally good compatibility; does not support returning attn weights.
793
+ # flash2: Use Flash-Attention2 implementation; fastest; limited to float16 and bfloat16 types; can't train Alibi weights; least memory usage.
794
+ # Note: You can perform initial training with "torch," then switch to "flash2," after the Alibi weights have settled.
795
+ window_size=None,
796
+ attn_type="native",
797
+ freeze_alibi=True,
798
+ ):
799
+ super().__init__()
800
+ self.d_model = d_model
801
+ self.num_heads = num_heads
802
+ self.beta = beta
803
+ self.attn_type = attn_type
804
+
805
+ assert d_model % num_heads == 0, "d_model must be evenly divisible by num_heads"
806
+
807
+ # The dimension of each head.
808
+ self.d_head = d_model // num_heads
809
+
810
+ # We scale the attention scores by the inverse-square-root of the head dimension
811
+ # this shifts the temperature of softmax.
812
+ self.dot_product_scale = 1.0 / math.sqrt(self.d_head)
813
+
814
+ self.in_proj = nn.Parameter(torch.empty(3 * self.d_model, self.d_model))
815
+ self.output_linear = nn.Linear(self.d_model, self.d_model, bias=False)
816
+
817
+ if window_size is not None:
818
+ self.window_size=(window_size, -1)
819
+ else:
820
+ self.window_size = (-1, -1)
821
+
822
+ self.dropout = nn.Dropout(dropout)
823
+
824
+ # This generates the original slope distribution from the paper.
825
+ # Observations with trainable slopes suggest that the high half of the slopes shift
826
+ # towards / past 1.0 and the low half approach zero or even go slightly negative.
827
+ # alibi_slopes = 1.0 / torch.logspace(1, 8, self.num_heads, base=2, dtype=torch.float)
828
+
829
+ # These appear to work better, as initial values, in practice.
830
+ alibi_slopes = 1.0 / torch.logspace(0, 7, self.num_heads, base=2, dtype=torch.float)
831
+
832
+ # If not trainable, it can improve performance somewhat if the low half are set to zero. Apparently
833
+ # making roughly half of the slopes position-agnostic is somehow closer to optimal?
834
+ # alibi_slopes.masked_fill_(torch.where(torch.arange(0, self.num_heads) >= (self.num_heads / 2), True, False), 0)
835
+
836
+ self.alibi_slopes = nn.Parameter(alibi_slopes)
837
+
838
+ # Optionally, allow/disallow training of ALiBi slopes.
839
+ self.alibi_slopes.requires_grad = (not freeze_alibi)
840
+ self.reset_parameters()
841
+
842
+ def extra_repr(self) -> str:
843
+ return f'd_model={self.d_model}, num_heads={self.num_heads}, beta={self.beta}, attn_type={self.attn_type}, window_size={self.window_size}, dropout={self.dropout}'
844
+
845
+ def reset_parameters(self):
846
+ # Deepnet initialization
847
+ # https://arxiv.org/pdf/2203.00555.pdf
848
+
849
+ q, k, v = self.in_proj.chunk(3)
850
+ init.xavier_uniform_(q, gain=1.0)
851
+ init.xavier_uniform_(k, gain=1.0)
852
+ init.xavier_uniform_(v, gain=self.beta)
853
+ init.xavier_uniform_(self.output_linear.weight, gain=self.beta)
854
+
855
+ def project_input(self, qkv):
856
+ proj = F.linear(qkv, self.in_proj)
857
+ return proj.chunk(chunks=3, dim=-1)
858
+
859
+ def forward(self, qkv, need_weights):
860
+ if self.attn_type == "flash2":
861
+ return self.flash2_forward(qkv)
862
+
863
+ # qkv: (batch_size, seq_len, d_embed)
864
+ batch_size, seq_len, d_embed = qkv.shape
865
+
866
+ # Feed the inputs through the K, Q, V matrices.
867
+ query, key, value = self.project_input(qkv)
868
+
869
+ # Split projections into multiple heads and swap position of sequence / heads dimension
870
+ query = query.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
871
+ key = key.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
872
+ value = value.view(batch_size, seq_len, self.num_heads, self.d_head).transpose(1, 2)
873
+
874
+ # Apply Alibi relative positional biases.
875
+ attn_bias = alibi_biases(seq_len, seq_len, device=query.device) * self.alibi_slopes.view(-1, 1, 1)
876
+
877
+ # Mask future positions from the past
878
+ causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=qkv.device), diagonal=0)
879
+ attn_bias.masked_fill_(causal_mask.logical_not(), float('-inf'))
880
+ del causal_mask
881
+
882
+ # Default to returning empty attention weights.
883
+ attention_weights = None
884
+
885
+ if self.attn_type == "torch":
886
+ # This context manager can be used to force which implementation to use.
887
+ #with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
888
+ attended_values = F.scaled_dot_product_attention(
889
+ query,
890
+ key,
891
+ value,
892
+ attn_mask=attn_bias.to(dtype=query.dtype),
893
+ dropout_p=self.dropout.p if self.training else 0.0,
894
+ is_causal=False,
895
+ scale=self.dot_product_scale
896
+ )
897
+ # "native" scaled-dot-product attention implementation.
898
+ else:
899
+ # Compute attention scores
900
+ scores = torch.matmul(query, key.transpose(-2, -1)) * self.dot_product_scale
901
+
902
+ # Adjust scores with attn_mask
903
+ scores += attn_bias
904
+
905
+ # Calculate the attention weights; avoid NaNs that might emerge from zeros in softmax's denominator
906
+ attention_weights = self.dropout(torch.softmax(scores, dim=-1).clamp(min=1e-10))
907
+
908
+ # Use the attention weights to get a weighted combination of value vectors
909
+ attended_values = torch.matmul(attention_weights, value)
910
+ if not need_weights:
911
+ attention_weights = None
912
+
913
+ # Concatenate attention heads and project to original embedding size using the output linear layer
914
+ attended_values = attended_values.transpose(1, 2).contiguous().view(batch_size, seq_len, d_embed)
915
+
916
+ # Project the concatenated output through the output matrix.
917
+ attended_values = self.output_linear(attended_values)
918
+ return attended_values, attention_weights
919
+
920
+ def flash2_forward(self, qkv):
921
+ batch_size, seq_len, d_embed = qkv.shape
922
+
923
+ # Feed the inputs through the K, Q, V matrices.
924
+ # query : (batch_size, seq_len, d_model)
925
+ # qkv : (batch_size, seq_len, 3, num_heads, d_head)
926
+ qkv = F.linear(
927
+ qkv,
928
+ self.in_proj,
929
+ ).unflatten(
930
+ -1,
931
+ (3, self.num_heads, self.d_head)
932
+ )
933
+
934
+ attended_values = flash_attn_qkvpacked_func(
935
+ qkv.bfloat16(),
936
+ dropout_p=self.dropout.p if self.training else 0.0,
937
+ softmax_scale=self.dot_product_scale,
938
+ causal=True,
939
+ window_size=self.window_size,
940
+ alibi_slopes=self.alibi_slopes.float(),
941
+ ).to(dtype=qkv.dtype)
942
+ # attended_values: (batch_size, seqlen, nheads, headdim)
943
+
944
+ # Concatenate heads back into d_embed
945
+ attended_values = attended_values.view(batch_size, seq_len, d_embed)
946
+
947
+ # Project the concatenated output through the output matrix.
948
+ attended_values = self.output_linear(attended_values)
949
+ return attended_values, None