Upload 14 files
Browse files- .gitattributes +1 -0
- README.md +286 -0
- block.py +470 -0
- config.json +39 -0
- configuration_xlm_roberta.py +69 -0
- embedding.py +62 -0
- mha.py +662 -0
- mlp.py +194 -0
- modeling_xlm_roberta.py +1119 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +51 -0
- stochastic_depth.py +97 -0
- tokenizer.json +3 -0
- tokenizer_config.json +54 -0
- xlm_padding.py +218 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
pipeline_tag: text-classification
|
3 |
+
tags:
|
4 |
+
- transformers
|
5 |
+
- reranker
|
6 |
+
- cross-encoder
|
7 |
+
- transformers.js
|
8 |
+
language:
|
9 |
+
- multilingual
|
10 |
+
inference: false
|
11 |
+
license: cc-by-nc-4.0
|
12 |
+
library_name: transformers
|
13 |
+
---
|
14 |
+
|
15 |
+
<br><br>
|
16 |
+
|
17 |
+
<p align="center">
|
18 |
+
<img src="https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/603763514de52ff951d89793/AFoybzd5lpBQXEBrQHuTt.png?w=200&h=200&f=face" alt="Finetuner logo: Finetuner helps you to create experiments in order to improve embeddings on search tasks. It accompanies you to deliver the last mile of performance-tuning for neural search applications." width="150px">
|
19 |
+
</p>
|
20 |
+
|
21 |
+
<p align="center">
|
22 |
+
<b>Trained by <a href="https://jina.ai/"><b>Jina AI</b></a>.</b>
|
23 |
+
</p>
|
24 |
+
|
25 |
+
# jina-reranker-v2-base-multilingual
|
26 |
+
|
27 |
+
## Intended Usage & Model Info
|
28 |
+
|
29 |
+
The **Jina Reranker v2** (`jina-reranker-v2-base-multilingual`) is a transformer-based model that has been fine-tuned for text reranking task, which is a crucial component in many information retrieval systems. It is a cross-encoder model that takes a query and a document pair as input and outputs a score indicating the relevance of the document to the query. The model is trained on a large dataset of query-document pairs and is capable of reranking documents in multiple languages with high accuracy.
|
30 |
+
|
31 |
+
Compared with the state-of-the-art reranker models, including the previous released `jina-reranker-v1-base-en`, the **Jina Reranker v2** model has demonstrated competitiveness across a series of benchmarks targeting for text retrieval, multilingual capability, function-calling-aware and text-to-SQL-aware reranking, and code retrieval tasks.
|
32 |
+
|
33 |
+
The `jina-reranker-v2-base-multilingual` model is capable of handling long texts with a context length of up to `1024` tokens, enabling the processing of extensive inputs. To enable the model to handle long texts that exceed 1024 tokens, the model uses a sliding window approach to chunk the input text into smaller pieces and rerank each chunk separately.
|
34 |
+
|
35 |
+
The model is also equipped with a flash attention mechanism, which significantly improves the model's performance.
|
36 |
+
|
37 |
+
|
38 |
+
# Usage
|
39 |
+
|
40 |
+
_This model repository is licenced for research and evaluation purposes under CC-BY-NC-4.0. For commercial usage, please refer to Jina AI's APIs, AWS Sagemaker or Azure Marketplace offerings. Please [contact us](https://jina.ai/contact-sales) for any further clarifications._
|
41 |
+
1. The easiest way to use `jina-reranker-v2-base-multilingual` is to call Jina AI's [Reranker API](https://jina.ai/reranker/).
|
42 |
+
|
43 |
+
```bash
|
44 |
+
curl https://api.jina.ai/v1/rerank \
|
45 |
+
-H "Content-Type: application/json" \
|
46 |
+
-H "Authorization: Bearer YOUR_API_KEY" \
|
47 |
+
-d '{
|
48 |
+
"model": "jina-reranker-v2-base-multilingual",
|
49 |
+
"query": "Organic skincare products for sensitive skin",
|
50 |
+
"documents": [
|
51 |
+
"Organic skincare for sensitive skin with aloe vera and chamomile.",
|
52 |
+
"New makeup trends focus on bold colors and innovative techniques",
|
53 |
+
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
|
54 |
+
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
|
55 |
+
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
|
56 |
+
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
|
57 |
+
"针对敏感肌专门设计的天然有机护肤产品",
|
58 |
+
"新的化妆趋势注重鲜艳的颜色和创新的技巧",
|
59 |
+
"敏感肌のために特別に設計された天然有機スキンケア製品",
|
60 |
+
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています"
|
61 |
+
],
|
62 |
+
"top_n": 3
|
63 |
+
}'
|
64 |
+
```
|
65 |
+
|
66 |
+
2. You can also use the `transformers` library to interact with the model programmatically.
|
67 |
+
|
68 |
+
Before you start, install the `transformers` and `einops` libraries:
|
69 |
+
|
70 |
+
```bash
|
71 |
+
pip install transformers einops
|
72 |
+
```
|
73 |
+
|
74 |
+
And then:
|
75 |
+
```python
|
76 |
+
from transformers import AutoModelForSequenceClassification
|
77 |
+
|
78 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
79 |
+
'jinaai/jina-reranker-v2-base-multilingual',
|
80 |
+
torch_dtype="auto",
|
81 |
+
trust_remote_code=True,
|
82 |
+
)
|
83 |
+
|
84 |
+
model.to('cuda') # or 'cpu' if no GPU is available
|
85 |
+
model.eval()
|
86 |
+
|
87 |
+
# Example query and documents
|
88 |
+
query = "Organic skincare products for sensitive skin"
|
89 |
+
documents = [
|
90 |
+
"Organic skincare for sensitive skin with aloe vera and chamomile.",
|
91 |
+
"New makeup trends focus on bold colors and innovative techniques",
|
92 |
+
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
|
93 |
+
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
|
94 |
+
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
|
95 |
+
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
|
96 |
+
"针对敏感肌专门设计的天然有机护肤产品",
|
97 |
+
"新的化妆趋势注重鲜艳的颜色和创新的技巧",
|
98 |
+
"敏感肌のために特別に設計された天然有機スキンケア製品",
|
99 |
+
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
|
100 |
+
]
|
101 |
+
|
102 |
+
# construct sentence pairs
|
103 |
+
sentence_pairs = [[query, doc] for doc in documents]
|
104 |
+
|
105 |
+
scores = model.compute_score(sentence_pairs, max_length=1024)
|
106 |
+
```
|
107 |
+
|
108 |
+
The scores will be a list of floats, where each float represents the relevance score of the corresponding document to the query. Higher scores indicate higher relevance.
|
109 |
+
For instance the returning scores in this case will be:
|
110 |
+
```bash
|
111 |
+
[0.8311430811882019, 0.09401018172502518,
|
112 |
+
0.6334102749824524, 0.08269733935594559,
|
113 |
+
0.7620701193809509, 0.09947021305561066,
|
114 |
+
0.9263036847114563, 0.05834583938121796,
|
115 |
+
0.8418256044387817, 0.11124119907617569]
|
116 |
+
```
|
117 |
+
|
118 |
+
The model gives high relevance scores to the documents that are most relevant to the query regardless of the language of the document.
|
119 |
+
|
120 |
+
Note that by default, the `jina-reranker-v2-base-multilingual` model uses [flash attention](https://github.com/Dao-AILab/flash-attention), which requires certain types of GPU hardware to run.
|
121 |
+
If you encounter any issues, you can try call `AutoModelForSequenceClassification.from_pretrained()` with `use_flash_attn=False`.
|
122 |
+
This will use the standard attention mechanism instead of flash attention.
|
123 |
+
|
124 |
+
If you want to use flash attention for fast inference, you need to install the following packages:
|
125 |
+
```bash
|
126 |
+
pip install ninja # required for flash attention
|
127 |
+
pip install flash-attn --no-build-isolation
|
128 |
+
```
|
129 |
+
Enjoy the 3x-6x speedup with flash attention! ⚡️⚡️⚡️
|
130 |
+
|
131 |
+
|
132 |
+
3. You can also use the `transformers.js` library to run the model directly in JavaScript (in-browser, Node.js, Deno, etc.)!
|
133 |
+
|
134 |
+
If you haven't already, you can install the [Transformers.js](https://huggingface.co/docs/transformers.js) JavaScript library (v3) using:
|
135 |
+
```bash
|
136 |
+
npm i xenova/transformers.js#v3
|
137 |
+
```
|
138 |
+
|
139 |
+
Then, you can use the following code to interact with the model:
|
140 |
+
```js
|
141 |
+
import { AutoTokenizer, XLMRobertaModel } from '@xenova/transformers';
|
142 |
+
|
143 |
+
const model_id = 'jinaai/jina-reranker-v2-base-multilingual';
|
144 |
+
const model = await XLMRobertaModel.from_pretrained(model_id, { dtype: 'fp32' });
|
145 |
+
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
|
146 |
+
|
147 |
+
/**
|
148 |
+
* Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.
|
149 |
+
* @param {string} query A single query
|
150 |
+
* @param {string[]} documents A list of documents
|
151 |
+
* @param {Object} options Options for ranking
|
152 |
+
* @param {number} [options.top_k=undefined] Return the top-k documents. If undefined, all documents are returned.
|
153 |
+
* @param {number} [options.return_documents=false] If true, also returns the documents. If false, only returns the indices and scores.
|
154 |
+
*/
|
155 |
+
async function rank(query, documents, {
|
156 |
+
top_k = undefined,
|
157 |
+
return_documents = false,
|
158 |
+
} = {}) {
|
159 |
+
const inputs = tokenizer(
|
160 |
+
new Array(documents.length).fill(query),
|
161 |
+
{ text_pair: documents, padding: true, truncation: true }
|
162 |
+
)
|
163 |
+
const { logits } = await model(inputs);
|
164 |
+
return logits.sigmoid().tolist()
|
165 |
+
.map(([score], i) => ({
|
166 |
+
corpus_id: i,
|
167 |
+
score,
|
168 |
+
...(return_documents ? { text: documents[i] } : {})
|
169 |
+
})).sort((a, b) => b.score - a.score).slice(0, top_k);
|
170 |
+
}
|
171 |
+
|
172 |
+
// Example usage:
|
173 |
+
const query = "Organic skincare products for sensitive skin"
|
174 |
+
const documents = [
|
175 |
+
"Organic skincare for sensitive skin with aloe vera and chamomile.",
|
176 |
+
"New makeup trends focus on bold colors and innovative techniques",
|
177 |
+
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
|
178 |
+
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
|
179 |
+
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
|
180 |
+
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
|
181 |
+
"针对敏感肌专门设计的天然有机护肤产品",
|
182 |
+
"新的化妆趋势注重鲜艳的颜色和创新的技巧",
|
183 |
+
"敏感肌のために特別に設計された天然有機スキンケア製品",
|
184 |
+
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
|
185 |
+
]
|
186 |
+
|
187 |
+
const results = await rank(query, documents, { return_documents: true, top_k: 3 });
|
188 |
+
console.log(results);
|
189 |
+
```
|
190 |
+
|
191 |
+
|
192 |
+
That's it! You can now use the `jina-reranker-v2-base-multilingual` model in your projects.
|
193 |
+
|
194 |
+
|
195 |
+
In addition to the `compute_score()` function, the `jina-reranker-v2-base-multilingual` model also provides a `model.rerank()` function that can be used to rerank documents based on a query. You can use it as follows:
|
196 |
+
|
197 |
+
```python
|
198 |
+
result = model.rerank(
|
199 |
+
query,
|
200 |
+
documents,
|
201 |
+
max_query_length=512,
|
202 |
+
max_length=1024,
|
203 |
+
top_n=3
|
204 |
+
)
|
205 |
+
```
|
206 |
+
|
207 |
+
Inside the `result` object, you will find the reranked documents along with their scores. You can use this information to further process the documents as needed.
|
208 |
+
|
209 |
+
The `rerank()` function will automatically chunk the input documents into smaller pieces if they exceed the model's maximum input length. This allows you to rerank long documents without running into memory issues.
|
210 |
+
Specifically, the `rerank()` function will split the documents into chunks of size `max_length` and rerank each chunk separately. The scores from all the chunks are then combined to produce the final reranking results. You can control the query length and document length in each chunk by setting the `max_query_length` and `max_length` parameters. The `rerank()` function also supports the `overlap` parameter (default is `80`) which determines how much overlap there is between adjacent chunks. This can be useful when reranking long documents to ensure that the model has enough context to make accurate predictions.
|
211 |
+
|
212 |
+
3. Alternatively, `jina-reranker-v2-base-multilingual` has been integrated with `CrossEncoder` from the `sentence-transformers` library.
|
213 |
+
|
214 |
+
Before you start, install the `sentence-transformers` libraries:
|
215 |
+
|
216 |
+
```bash
|
217 |
+
pip install sentence-transformers
|
218 |
+
```
|
219 |
+
|
220 |
+
The [`CrossEncoder`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html) class supports a [`predict`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html#sentence_transformers.cross_encoder.CrossEncoder.predict) method to get query-document relevance scores, and a [`rank`](https://sbert.net/docs/package_reference/cross_encoder/cross_encoder.html#sentence_transformers.cross_encoder.CrossEncoder.rank) method to rank all documents given your query.
|
221 |
+
|
222 |
+
```python
|
223 |
+
from sentence_transformers import CrossEncoder
|
224 |
+
|
225 |
+
model = CrossEncoder(
|
226 |
+
"jinaai/jina-reranker-v2-base-multilingual",
|
227 |
+
automodel_args={"torch_dtype": "auto"},
|
228 |
+
trust_remote_code=True,
|
229 |
+
)
|
230 |
+
|
231 |
+
# Example query and documents
|
232 |
+
query = "Organic skincare products for sensitive skin"
|
233 |
+
documents = [
|
234 |
+
"Organic skincare for sensitive skin with aloe vera and chamomile.",
|
235 |
+
"New makeup trends focus on bold colors and innovative techniques",
|
236 |
+
"Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille",
|
237 |
+
"Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken",
|
238 |
+
"Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla",
|
239 |
+
"Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras",
|
240 |
+
"针对敏感肌专门设计的天然有机护肤产品",
|
241 |
+
"新的化妆趋势注重鲜艳的颜色和创新的技巧",
|
242 |
+
"敏感肌のために特別に設計された天然有機スキンケア製品",
|
243 |
+
"新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています",
|
244 |
+
]
|
245 |
+
|
246 |
+
# construct sentence pairs
|
247 |
+
sentence_pairs = [[query, doc] for doc in documents]
|
248 |
+
|
249 |
+
scores = model.predict(sentence_pairs, convert_to_tensor=True).tolist()
|
250 |
+
"""
|
251 |
+
[0.828125, 0.0927734375, 0.6328125, 0.08251953125, 0.76171875, 0.099609375, 0.92578125, 0.058349609375, 0.84375, 0.111328125]
|
252 |
+
"""
|
253 |
+
|
254 |
+
rankings = model.rank(query, documents, return_documents=True, convert_to_tensor=True)
|
255 |
+
print(f"Query: {query}")
|
256 |
+
for ranking in rankings:
|
257 |
+
print(f"ID: {ranking['corpus_id']}, Score: {ranking['score']:.4f}, Text: {ranking['text']}")
|
258 |
+
"""
|
259 |
+
Query: Organic skincare products for sensitive skin
|
260 |
+
ID: 6, Score: 0.9258, Text: 针对敏感肌专门设计的天然有机护肤产品
|
261 |
+
ID: 8, Score: 0.8438, Text: 敏感肌のために特別に設計された天然有機スキンケア製品
|
262 |
+
ID: 0, Score: 0.8281, Text: Organic skincare for sensitive skin with aloe vera and chamomile.
|
263 |
+
ID: 4, Score: 0.7617, Text: Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla
|
264 |
+
ID: 2, Score: 0.6328, Text: Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille
|
265 |
+
ID: 9, Score: 0.1113, Text: 新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています
|
266 |
+
ID: 5, Score: 0.0996, Text: Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras
|
267 |
+
ID: 1, Score: 0.0928, Text: New makeup trends focus on bold colors and innovative techniques
|
268 |
+
ID: 3, Score: 0.0825, Text: Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken
|
269 |
+
ID: 7, Score: 0.0583, Text: 新的化妆趋势注重鲜艳的颜色和创新的技巧
|
270 |
+
"""
|
271 |
+
```
|
272 |
+
|
273 |
+
# Evaluation
|
274 |
+
|
275 |
+
We evaluated Jina Reranker v2 on multiple benchmarks to ensure top-tier performance and search relevance.
|
276 |
+
|
277 |
+
| Model Name | Model Size | MKQA(nDCG@10, 26 langs) | BEIR(nDCG@10, 17 datasets) | MLDR(recall@10, 13 langs) | CodeSearchNet (MRR@10, 3 tasks) | AirBench (nDCG@10, zh/en) | ToolBench (recall@3, 3 tasks) | TableSearch (recall@3) |
|
278 |
+
| :-----------------------------: | :----------: | ------------------------- | ---------------------------- | --------------------------- | --------------------------------- | --------------------------- | ------------------------------- | ------------------------ |
|
279 |
+
| jina-reranker-v2-multilingual | 278M | 54.83 | 53.17 | 68.95 | 71.36 | 61.33 | 77.75 | 93.31 |
|
280 |
+
| bge-reranker-v2-m3 | 568M | 54.17 | 53.65 | 59.73 | 62.86 | 61.28 | 78.46 | 74.86 |
|
281 |
+
| mmarco-mMiniLMv2-L12-H384-v1 | 118M | 53.37 | 45.40 | 28.91 | 51.78 | 56.46 | 58.39 | 53.60 |
|
282 |
+
| jina-reranker-v1-base-en | 137M | - | 52.45 | - | - | - | 74.13 | 72.89 |
|
283 |
+
|
284 |
+
Note:
|
285 |
+
- NDCG@10 and MRR@10 measure ranking quality, with higher scores indicating better search results
|
286 |
+
- recall@3 measures the proportion of relevant documents retrieved, with higher scores indicating better search results
|
block.py
ADDED
@@ -0,0 +1,470 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
|
2 |
+
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
3 |
+
|
4 |
+
# Copyright (c) 2024, Tri Dao.
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.fx
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import Tensor
|
14 |
+
|
15 |
+
from .mha import MHA
|
16 |
+
from .mlp import Mlp
|
17 |
+
|
18 |
+
try:
|
19 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
|
20 |
+
except ImportError:
|
21 |
+
layer_norm_fn, RMSNorm = None, None
|
22 |
+
|
23 |
+
|
24 |
+
def stochastic_depth(
|
25 |
+
input: Tensor, p: float, mode: str, training: bool = True
|
26 |
+
) -> Tensor:
|
27 |
+
"""
|
28 |
+
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
|
29 |
+
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
|
30 |
+
branches of residual architectures.
|
31 |
+
Args:
|
32 |
+
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
|
33 |
+
being its batch i.e. a batch with ``N`` rows.
|
34 |
+
p (float): probability of the input to be zeroed.
|
35 |
+
mode (str): ``"batch"`` or ``"row"``.
|
36 |
+
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
|
37 |
+
randomly selected rows from the batch.
|
38 |
+
training: apply stochastic depth if is ``True``. Default: ``True``
|
39 |
+
Returns:
|
40 |
+
Tensor[N, ...]: The randomly zeroed tensor.
|
41 |
+
"""
|
42 |
+
if p < 0.0 or p > 1.0:
|
43 |
+
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
|
44 |
+
if mode not in ["batch", "row"]:
|
45 |
+
raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
|
46 |
+
if not training or p == 0.0:
|
47 |
+
return input
|
48 |
+
|
49 |
+
survival_rate = 1.0 - p
|
50 |
+
if mode == "row":
|
51 |
+
size = [input.shape[0]] + [1] * (input.ndim - 1)
|
52 |
+
else:
|
53 |
+
size = [1] * input.ndim
|
54 |
+
noise = torch.empty(size, dtype=input.dtype, device=input.device)
|
55 |
+
noise = noise.bernoulli_(survival_rate)
|
56 |
+
if survival_rate > 0.0:
|
57 |
+
noise.div_(survival_rate)
|
58 |
+
return input * noise
|
59 |
+
|
60 |
+
|
61 |
+
torch.fx.wrap("stochastic_depth")
|
62 |
+
|
63 |
+
|
64 |
+
class StochasticDepth(nn.Module):
|
65 |
+
"""
|
66 |
+
See :func:`stochastic_depth`.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, p: float, mode: str) -> None:
|
70 |
+
super().__init__()
|
71 |
+
self.p = p
|
72 |
+
self.mode = mode
|
73 |
+
|
74 |
+
def forward(self, input: Tensor) -> Tensor:
|
75 |
+
return stochastic_depth(input, self.p, self.mode, self.training)
|
76 |
+
|
77 |
+
def __repr__(self) -> str:
|
78 |
+
s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
|
79 |
+
return s
|
80 |
+
|
81 |
+
|
82 |
+
class Block(nn.Module):
|
83 |
+
def __init__(
|
84 |
+
self,
|
85 |
+
dim,
|
86 |
+
mixer_cls=None,
|
87 |
+
mlp_cls=None,
|
88 |
+
norm_cls=nn.LayerNorm,
|
89 |
+
dropout_cls=nn.Dropout,
|
90 |
+
prenorm=True,
|
91 |
+
resid_dropout1=0.0,
|
92 |
+
resid_dropout2=0.0,
|
93 |
+
drop_path1=0.0,
|
94 |
+
drop_path2=0.0,
|
95 |
+
fused_dropout_add_ln=False,
|
96 |
+
return_residual=False,
|
97 |
+
residual_in_fp32=False,
|
98 |
+
sequence_parallel=False,
|
99 |
+
mark_shared_params=False,
|
100 |
+
):
|
101 |
+
"""
|
102 |
+
For prenorm=True, this Block has a slightly different structure compared to a regular
|
103 |
+
prenorm Transformer block.
|
104 |
+
The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
|
105 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
106 |
+
Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
|
107 |
+
the hidden_states (output of the MLP) and the residual.
|
108 |
+
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
|
109 |
+
The residual needs to be provided (except for the very first block).
|
110 |
+
|
111 |
+
For prenorm=False, this Block has the same structure as a regular postnorm Transformer
|
112 |
+
block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
|
113 |
+
|
114 |
+
return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
|
115 |
+
This is for performance reason: for post-norm architecture, returning the input allows us
|
116 |
+
to fuse the backward of nn.Linear with the residual connection.
|
117 |
+
"""
|
118 |
+
super().__init__()
|
119 |
+
self.prenorm = prenorm
|
120 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
121 |
+
self.return_residual = return_residual
|
122 |
+
self.residual_in_fp32 = residual_in_fp32
|
123 |
+
if self.residual_in_fp32:
|
124 |
+
assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
|
125 |
+
if mixer_cls is None:
|
126 |
+
mixer_cls = partial(MHA, num_heads=dim // 64)
|
127 |
+
if mlp_cls is None:
|
128 |
+
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
129 |
+
self.mixer = mixer_cls(dim)
|
130 |
+
self.dropout1 = dropout_cls(resid_dropout1)
|
131 |
+
self.drop_path1 = StochasticDepth(drop_path1, mode="row")
|
132 |
+
self.norm1 = norm_cls(dim)
|
133 |
+
self.mlp = mlp_cls(dim)
|
134 |
+
if not isinstance(self.mlp, nn.Identity):
|
135 |
+
self.dropout2 = dropout_cls(resid_dropout2)
|
136 |
+
self.drop_path2 = StochasticDepth(drop_path2, mode="row")
|
137 |
+
self.norm2 = norm_cls(dim)
|
138 |
+
|
139 |
+
if self.fused_dropout_add_ln:
|
140 |
+
assert layer_norm_fn is not None, "Triton is not installed"
|
141 |
+
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
|
142 |
+
self.dropout1, nn.Dropout
|
143 |
+
)
|
144 |
+
|
145 |
+
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
|
146 |
+
# then the input to each worker in the tensor parallel group will be different.
|
147 |
+
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
|
148 |
+
# For now this is not an issue because we always use sequence_parallel=True during training
|
149 |
+
# and only use sequence_parallel=False during inference.
|
150 |
+
|
151 |
+
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
|
152 |
+
if sequence_parallel:
|
153 |
+
for p in self.norm1.parameters():
|
154 |
+
p._sequence_parallel = True
|
155 |
+
if hasattr(self, "norm2"):
|
156 |
+
for p in self.norm2.parameters():
|
157 |
+
p._sequence_parallel = True
|
158 |
+
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
159 |
+
if mark_shared_params:
|
160 |
+
for p in self.norm1.parameters():
|
161 |
+
p._shared_params = True
|
162 |
+
if hasattr(self, "norm2"):
|
163 |
+
for p in self.norm2.parameters():
|
164 |
+
p._shared_params = True
|
165 |
+
|
166 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
167 |
+
return self.mixer.allocate_inference_cache(
|
168 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
169 |
+
)
|
170 |
+
|
171 |
+
def forward(
|
172 |
+
self,
|
173 |
+
hidden_states: Tensor,
|
174 |
+
residual: Optional[Tensor] = None,
|
175 |
+
mixer_subset=None,
|
176 |
+
mixer_kwargs=None,
|
177 |
+
):
|
178 |
+
r"""Pass the input through the encoder layer.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
hidden_states: the sequence to the encoder layer (required).
|
182 |
+
residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
|
183 |
+
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
184 |
+
before applying the query projection. Useful for e.g., ViT where we only care
|
185 |
+
about the CLS token in the last layer.
|
186 |
+
"""
|
187 |
+
if self.prenorm:
|
188 |
+
if not self.fused_dropout_add_ln:
|
189 |
+
dropped = self.drop_path1(self.dropout1(hidden_states))
|
190 |
+
residual = (dropped + residual) if residual is not None else dropped
|
191 |
+
hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
192 |
+
if self.residual_in_fp32:
|
193 |
+
residual = residual.to(torch.float32)
|
194 |
+
else:
|
195 |
+
if self.drop_path1.p == 0 or not self.training:
|
196 |
+
rowscale1 = None
|
197 |
+
else:
|
198 |
+
rowscale1 = self.drop_path1(
|
199 |
+
torch.ones(
|
200 |
+
hidden_states.shape[:-1],
|
201 |
+
device=hidden_states.device,
|
202 |
+
dtype=hidden_states.dtype,
|
203 |
+
)
|
204 |
+
)
|
205 |
+
hidden_states, residual = layer_norm_fn(
|
206 |
+
hidden_states,
|
207 |
+
self.norm1.weight,
|
208 |
+
self.norm1.bias,
|
209 |
+
residual=residual,
|
210 |
+
eps=self.norm1.eps,
|
211 |
+
dropout_p=self.dropout1.p if self.training else 0.0,
|
212 |
+
rowscale=rowscale1,
|
213 |
+
prenorm=True,
|
214 |
+
residual_in_fp32=self.residual_in_fp32,
|
215 |
+
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
216 |
+
)
|
217 |
+
if mixer_kwargs is None:
|
218 |
+
mixer_kwargs = {}
|
219 |
+
if mixer_subset is not None:
|
220 |
+
mixer_kwargs["mixer_subset"] = mixer_subset
|
221 |
+
hidden_states = self.mixer(hidden_states, **mixer_kwargs)
|
222 |
+
if mixer_subset is not None:
|
223 |
+
residual = residual[:, mixer_subset]
|
224 |
+
if not isinstance(self.mlp, nn.Identity):
|
225 |
+
if not self.fused_dropout_add_ln:
|
226 |
+
dropped = self.drop_path2(self.dropout2(hidden_states))
|
227 |
+
residual = (dropped + residual) if residual is not None else dropped
|
228 |
+
hidden_states = self.norm2(
|
229 |
+
residual.to(dtype=self.norm2.weight.dtype)
|
230 |
+
)
|
231 |
+
if self.residual_in_fp32:
|
232 |
+
residual = residual.to(torch.float32)
|
233 |
+
else:
|
234 |
+
if self.drop_path2.p == 0 or not self.training:
|
235 |
+
rowscale2 = None
|
236 |
+
else:
|
237 |
+
rowscale2 = self.drop_path2(
|
238 |
+
torch.ones(
|
239 |
+
hidden_states.shape[:-1],
|
240 |
+
device=hidden_states.device,
|
241 |
+
dtype=hidden_states.dtype,
|
242 |
+
)
|
243 |
+
)
|
244 |
+
hidden_states, residual = layer_norm_fn(
|
245 |
+
hidden_states,
|
246 |
+
self.norm2.weight,
|
247 |
+
self.norm2.bias,
|
248 |
+
residual=residual,
|
249 |
+
eps=self.norm2.eps,
|
250 |
+
dropout_p=self.dropout2.p if self.training else 0.0,
|
251 |
+
rowscale=rowscale2,
|
252 |
+
prenorm=True,
|
253 |
+
residual_in_fp32=self.residual_in_fp32,
|
254 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm),
|
255 |
+
)
|
256 |
+
hidden_states = self.mlp(hidden_states)
|
257 |
+
return hidden_states, residual
|
258 |
+
else:
|
259 |
+
assert residual is None
|
260 |
+
mixer_out = self.mixer(
|
261 |
+
hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {})
|
262 |
+
)
|
263 |
+
if self.return_residual: # mixer out is actually a pair here
|
264 |
+
mixer_out, hidden_states = mixer_out
|
265 |
+
if not self.fused_dropout_add_ln:
|
266 |
+
hidden_states = self.norm1(
|
267 |
+
(self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
|
268 |
+
dtype=self.norm1.weight.dtype
|
269 |
+
)
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
if self.drop_path1.p == 0 or not self.training:
|
273 |
+
rowscale1 = None
|
274 |
+
else:
|
275 |
+
rowscale1 = self.drop_path1(
|
276 |
+
torch.ones(
|
277 |
+
mixer_out.shape[:-1],
|
278 |
+
device=mixer_out.device,
|
279 |
+
dtype=mixer_out.dtype,
|
280 |
+
)
|
281 |
+
)
|
282 |
+
hidden_states = layer_norm_fn(
|
283 |
+
mixer_out,
|
284 |
+
self.norm1.weight,
|
285 |
+
self.norm1.bias,
|
286 |
+
residual=hidden_states,
|
287 |
+
eps=self.norm1.eps,
|
288 |
+
dropout_p=self.dropout1.p if self.training else 0.0,
|
289 |
+
rowscale=rowscale1,
|
290 |
+
prenorm=False,
|
291 |
+
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
292 |
+
)
|
293 |
+
if not isinstance(self.mlp, nn.Identity):
|
294 |
+
mlp_out = self.mlp(hidden_states)
|
295 |
+
if self.return_residual: # mlp out is actually a pair here
|
296 |
+
mlp_out, hidden_states = mlp_out
|
297 |
+
if not self.fused_dropout_add_ln:
|
298 |
+
hidden_states = self.norm2(
|
299 |
+
(self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
|
300 |
+
dtype=self.norm2.weight.dtype
|
301 |
+
)
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
if self.drop_path2.p == 0 or not self.training:
|
305 |
+
rowscale2 = None
|
306 |
+
else:
|
307 |
+
rowscale2 = self.drop_path2(
|
308 |
+
torch.ones(
|
309 |
+
mlp_out.shape[:-1],
|
310 |
+
device=mlp_out.device,
|
311 |
+
dtype=mlp_out.dtype,
|
312 |
+
)
|
313 |
+
)
|
314 |
+
hidden_states = layer_norm_fn(
|
315 |
+
mlp_out,
|
316 |
+
self.norm2.weight,
|
317 |
+
self.norm2.bias,
|
318 |
+
residual=hidden_states,
|
319 |
+
eps=self.norm2.eps,
|
320 |
+
dropout_p=self.dropout2.p if self.training else 0.0,
|
321 |
+
rowscale=rowscale2,
|
322 |
+
prenorm=False,
|
323 |
+
is_rms_norm=isinstance(self.norm2, RMSNorm),
|
324 |
+
)
|
325 |
+
return hidden_states
|
326 |
+
|
327 |
+
|
328 |
+
class ParallelBlock(nn.Module):
|
329 |
+
"""The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX,
|
330 |
+
and PaLM.
|
331 |
+
"""
|
332 |
+
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
dim,
|
336 |
+
mixer_cls=None,
|
337 |
+
mlp_cls=None,
|
338 |
+
norm_cls=nn.LayerNorm,
|
339 |
+
dropout_cls=nn.Dropout,
|
340 |
+
resid_dropout1=0.0,
|
341 |
+
resid_dropout2=0.0,
|
342 |
+
tied_norm=False,
|
343 |
+
fused_dropout_add_ln=False,
|
344 |
+
residual_in_fp32=False,
|
345 |
+
sequence_parallel=False,
|
346 |
+
mark_shared_params=False,
|
347 |
+
):
|
348 |
+
"""
|
349 |
+
This Block has a slightly different structure compared to a regular
|
350 |
+
prenorm Transformer block.
|
351 |
+
The standard block is: LN -> MHA / MLP -> Dropout -> Add.
|
352 |
+
[Ref: https://arxiv.org/abs/2002.04745]
|
353 |
+
Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both
|
354 |
+
the hidden_states (output1 of the MHA / MLP) and the residual.
|
355 |
+
This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
|
356 |
+
The residual needs to be provided (except for the very first block).
|
357 |
+
"""
|
358 |
+
super().__init__()
|
359 |
+
self.tied_norm = tied_norm
|
360 |
+
self.fused_dropout_add_ln = fused_dropout_add_ln
|
361 |
+
self.residual_in_fp32 = residual_in_fp32
|
362 |
+
if mixer_cls is None:
|
363 |
+
mixer_cls = partial(MHA, num_heads=dim // 64)
|
364 |
+
if mlp_cls is None:
|
365 |
+
mlp_cls = partial(Mlp, hidden_features=4 * dim)
|
366 |
+
self.mixer = mixer_cls(dim)
|
367 |
+
self.dropout1 = dropout_cls(resid_dropout1)
|
368 |
+
self.norm1 = norm_cls(dim)
|
369 |
+
self.mlp = mlp_cls(dim)
|
370 |
+
self.dropout2 = dropout_cls(resid_dropout2)
|
371 |
+
if not self.tied_norm:
|
372 |
+
self.norm2 = norm_cls(dim)
|
373 |
+
|
374 |
+
if self.fused_dropout_add_ln:
|
375 |
+
assert layer_norm_fn is not None, "Triton is not installed"
|
376 |
+
assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
|
377 |
+
self.dropout1, nn.Dropout
|
378 |
+
)
|
379 |
+
|
380 |
+
# TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
|
381 |
+
# then the input to each worker in the tensor parallel group will be different.
|
382 |
+
# This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers.
|
383 |
+
# For now this is not an issue because we always use sequence_parallel=True during training
|
384 |
+
# and only use sequence_parallel=False during inference.
|
385 |
+
|
386 |
+
# Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads.
|
387 |
+
if sequence_parallel:
|
388 |
+
for p in self.norm1.parameters():
|
389 |
+
p._sequence_parallel = True
|
390 |
+
if hasattr(self, "norm2"):
|
391 |
+
for p in self.norm2.parameters():
|
392 |
+
p._sequence_parallel = True
|
393 |
+
# Mark the norm parameters as "shared_params" so that we sync their values at init.
|
394 |
+
if mark_shared_params:
|
395 |
+
for p in self.norm1.parameters():
|
396 |
+
p._shared_params = True
|
397 |
+
if hasattr(self, "norm2"):
|
398 |
+
for p in self.norm2.parameters():
|
399 |
+
p._shared_params = True
|
400 |
+
|
401 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
402 |
+
return self.mixer.allocate_inference_cache(
|
403 |
+
batch_size, max_seqlen, dtype=dtype, **kwargs
|
404 |
+
)
|
405 |
+
|
406 |
+
def forward(
|
407 |
+
self,
|
408 |
+
hidden_states1: Tensor,
|
409 |
+
hidden_states2: Optional[Tensor] = None,
|
410 |
+
residual: Optional[Tensor] = None,
|
411 |
+
mixer_kwargs=None,
|
412 |
+
):
|
413 |
+
r"""Pass the input through the encoder layer.
|
414 |
+
|
415 |
+
Args:
|
416 |
+
hidden_states1: the output of the previous attention (mixer) or embedding layer.
|
417 |
+
hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1).
|
418 |
+
residual.
|
419 |
+
"""
|
420 |
+
# TODO: Ideally we should only do the allgather / allreduce once for
|
421 |
+
# the Linear to MLP & Attention
|
422 |
+
if not self.fused_dropout_add_ln:
|
423 |
+
dropped1 = self.dropout1(hidden_states1)
|
424 |
+
# For the very 1st block, we only want 1 dropout, not two different dropouts
|
425 |
+
if hidden_states2 is not None:
|
426 |
+
dropped2 = self.dropout2(hidden_states2)
|
427 |
+
residual = (
|
428 |
+
(residual + dropped1 + dropped2)
|
429 |
+
if residual is not None
|
430 |
+
else dropped1 + dropped2
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
residual = (residual + dropped1) if residual is not None else dropped1
|
434 |
+
hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
|
435 |
+
hidden_states2 = (
|
436 |
+
self.norm2(residual.to(dtype=self.norm2.weight.dtype))
|
437 |
+
if not self.tied_norm
|
438 |
+
else hidden_states1
|
439 |
+
)
|
440 |
+
if self.residual_in_fp32:
|
441 |
+
residual = residual.to(torch.float32)
|
442 |
+
else:
|
443 |
+
weight2, bias2 = (
|
444 |
+
(self.norm2.weight, self.norm2.bias)
|
445 |
+
if not self.tied_norm
|
446 |
+
else (None, None)
|
447 |
+
)
|
448 |
+
hidden_states1, *rest, residual = layer_norm_fn(
|
449 |
+
hidden_states1,
|
450 |
+
self.norm1.weight,
|
451 |
+
self.norm1.bias,
|
452 |
+
residual=residual,
|
453 |
+
x1=hidden_states2,
|
454 |
+
weight1=weight2,
|
455 |
+
bias1=bias2,
|
456 |
+
eps=self.norm1.eps,
|
457 |
+
dropout_p=self.dropout1.p if self.training else 0.0,
|
458 |
+
prenorm=True,
|
459 |
+
residual_in_fp32=self.residual_in_fp32,
|
460 |
+
is_rms_norm=isinstance(self.norm1, RMSNorm),
|
461 |
+
)
|
462 |
+
if self.tied_norm:
|
463 |
+
hidden_states2 = hidden_states1
|
464 |
+
else:
|
465 |
+
(hidden_states2,) = rest
|
466 |
+
if mixer_kwargs is None:
|
467 |
+
mixer_kwargs = {}
|
468 |
+
hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs)
|
469 |
+
hidden_states2 = self.mlp(hidden_states2)
|
470 |
+
return hidden_states1, hidden_states2, residual
|
config.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "jinaai/jina-reranker-v2-base-multilingual",
|
3 |
+
"architectures": ["XLMRobertaForSequenceClassification"],
|
4 |
+
"attention_probs_dropout_prob": 0.1,
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_xlm_roberta.XLMRobertaFlashConfig",
|
7 |
+
"AutoModel": "modeling_xlm_roberta.XLMRobertaModel",
|
8 |
+
"AutoModelForSequenceClassification": "modeling_xlm_roberta.XLMRobertaForSequenceClassification"
|
9 |
+
},
|
10 |
+
"bos_token_id": 0,
|
11 |
+
"classifier_dropout": null,
|
12 |
+
"emb_pooler": null,
|
13 |
+
"eos_token_id": 2,
|
14 |
+
"hidden_act": "gelu",
|
15 |
+
"hidden_dropout_prob": 0.1,
|
16 |
+
"hidden_size": 768,
|
17 |
+
"num_labels": 1,
|
18 |
+
"id2label": {
|
19 |
+
"0": "LABEL_0"
|
20 |
+
},
|
21 |
+
"initializer_range": 0.02,
|
22 |
+
"intermediate_size": 3072,
|
23 |
+
"label2id": {
|
24 |
+
"LABEL_0": 0
|
25 |
+
},
|
26 |
+
"layer_norm_eps": 1e-5,
|
27 |
+
"max_position_embeddings": 1026,
|
28 |
+
"num_attention_heads": 12,
|
29 |
+
"num_hidden_layers": 12,
|
30 |
+
"output_past": true,
|
31 |
+
"pad_token_id": 1,
|
32 |
+
"position_embedding_type": "absolute",
|
33 |
+
"torch_dtype": "bfloat16",
|
34 |
+
"transformers_version": "4.40.0",
|
35 |
+
"type_vocab_size": 1,
|
36 |
+
"use_cache": false,
|
37 |
+
"use_flash_attn": true,
|
38 |
+
"vocab_size": 250002
|
39 |
+
}
|
configuration_xlm_roberta.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class XLMRobertaFlashConfig(PretrainedConfig):
|
5 |
+
def __init__(
|
6 |
+
self,
|
7 |
+
vocab_size=30522,
|
8 |
+
hidden_size=768,
|
9 |
+
num_hidden_layers=12,
|
10 |
+
num_attention_heads=12,
|
11 |
+
intermediate_size=3072,
|
12 |
+
hidden_act="gelu",
|
13 |
+
hidden_dropout_prob=0.1,
|
14 |
+
attention_probs_dropout_prob=0.1,
|
15 |
+
max_position_embeddings=512,
|
16 |
+
type_vocab_size=2,
|
17 |
+
initializer_range=0.02,
|
18 |
+
layer_norm_eps=1e-12,
|
19 |
+
pad_token_id=1,
|
20 |
+
bos_token_id=0,
|
21 |
+
eos_token_id=2,
|
22 |
+
position_embedding_type="absolute",
|
23 |
+
use_cache=True,
|
24 |
+
classifier_dropout=None,
|
25 |
+
lora_adaptations=None,
|
26 |
+
lora_rank=4,
|
27 |
+
lora_dropout_p=0.0,
|
28 |
+
lora_alpha=1,
|
29 |
+
lora_main_params_trainable=False,
|
30 |
+
load_trained_adapters=False,
|
31 |
+
use_flash_attn=True,
|
32 |
+
torch_dtype=None,
|
33 |
+
emb_pooler=None,
|
34 |
+
matryoshka_dimensions=None,
|
35 |
+
truncate_dim=None,
|
36 |
+
**kwargs,
|
37 |
+
):
|
38 |
+
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
39 |
+
|
40 |
+
|
41 |
+
self.vocab_size = vocab_size
|
42 |
+
self.hidden_size = hidden_size
|
43 |
+
self.num_hidden_layers = num_hidden_layers
|
44 |
+
self.num_attention_heads = num_attention_heads
|
45 |
+
self.hidden_act = hidden_act
|
46 |
+
self.intermediate_size = intermediate_size
|
47 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
48 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
49 |
+
self.max_position_embeddings = max_position_embeddings
|
50 |
+
self.type_vocab_size = type_vocab_size
|
51 |
+
self.initializer_range = initializer_range
|
52 |
+
self.layer_norm_eps = layer_norm_eps
|
53 |
+
self.position_embedding_type = position_embedding_type
|
54 |
+
self.use_cache = use_cache
|
55 |
+
self.classifier_dropout = classifier_dropout
|
56 |
+
self.load_trained_adapters = load_trained_adapters
|
57 |
+
self.lora_adaptations = lora_adaptations
|
58 |
+
self.lora_rank = lora_rank
|
59 |
+
self.lora_dropout_p = lora_dropout_p
|
60 |
+
self.lora_alpha = lora_alpha
|
61 |
+
self.lora_main_params_trainable = lora_main_params_trainable
|
62 |
+
self.use_flash_attn = use_flash_attn
|
63 |
+
self.emb_pooler = emb_pooler
|
64 |
+
self.matryoshka_dimensions = matryoshka_dimensions
|
65 |
+
self.truncate_dim = truncate_dim
|
66 |
+
if torch_dtype and hasattr(torch, torch_dtype) and type(getattr(torch, torch_dtype)) is torch.dtype:
|
67 |
+
self.torch_dtype = getattr(torch, torch_dtype)
|
68 |
+
else:
|
69 |
+
self.torch_dtype = torch_dtype
|
embedding.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/embedding.py
|
2 |
+
# Commit id: f1a73d074002226c42ce65a1df170ecff9f022c0
|
3 |
+
|
4 |
+
# Copyright (c) 2022, Tri Dao.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from torch import Tensor
|
10 |
+
|
11 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
|
12 |
+
|
13 |
+
|
14 |
+
class XLMRobertaEmbeddings(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
embed_dim,
|
18 |
+
vocab_size,
|
19 |
+
max_position_embeddings,
|
20 |
+
type_vocab_size,
|
21 |
+
padding_idx=None,
|
22 |
+
device=None,
|
23 |
+
dtype=None,
|
24 |
+
):
|
25 |
+
"""
|
26 |
+
If max_position_embeddings <= 0, there's no position embeddings
|
27 |
+
If type_vocab_size <= 0, there's no token type embeddings
|
28 |
+
"""
|
29 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
30 |
+
super().__init__()
|
31 |
+
self.word_embeddings = nn.Embedding(
|
32 |
+
vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs
|
33 |
+
)
|
34 |
+
self.max_position_embeddings = max_position_embeddings
|
35 |
+
self.type_vocab_size = type_vocab_size
|
36 |
+
if self.max_position_embeddings > 0:
|
37 |
+
self.position_embeddings = nn.Embedding(
|
38 |
+
max_position_embeddings, embed_dim, **factory_kwargs
|
39 |
+
)
|
40 |
+
if self.type_vocab_size > 0:
|
41 |
+
self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
|
42 |
+
|
43 |
+
def forward(self, input_ids, position_ids=None, token_type_ids=None):
|
44 |
+
"""
|
45 |
+
input_ids: (batch, seqlen)
|
46 |
+
position_ids: (batch, seqlen)
|
47 |
+
token_type_ids: (batch, seqlen)
|
48 |
+
"""
|
49 |
+
batch_size, seqlen = input_ids.shape
|
50 |
+
embeddings = self.word_embeddings(input_ids)
|
51 |
+
if self.max_position_embeddings > 0:
|
52 |
+
if position_ids is None:
|
53 |
+
position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
|
54 |
+
# position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
|
55 |
+
position_embeddings = self.position_embeddings(position_ids)
|
56 |
+
embeddings = embeddings + position_embeddings
|
57 |
+
if self.type_vocab_size > 0:
|
58 |
+
if token_type_ids is None:
|
59 |
+
token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
|
60 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
61 |
+
embeddings = embeddings + token_type_embeddings
|
62 |
+
return embeddings
|
mha.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Tri Dao.
|
2 |
+
# Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
|
3 |
+
|
4 |
+
import math
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
try:
|
12 |
+
from flash_attn import (
|
13 |
+
flash_attn_kvpacked_func,
|
14 |
+
flash_attn_qkvpacked_func,
|
15 |
+
flash_attn_varlen_kvpacked_func,
|
16 |
+
flash_attn_varlen_qkvpacked_func,
|
17 |
+
flash_attn_with_kvcache,
|
18 |
+
)
|
19 |
+
except ImportError:
|
20 |
+
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
|
21 |
+
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
|
22 |
+
flash_attn_with_kvcache = None
|
23 |
+
|
24 |
+
try:
|
25 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
|
26 |
+
except ImportError:
|
27 |
+
FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
|
28 |
+
|
29 |
+
|
30 |
+
class FlashSelfAttention(nn.Module):
|
31 |
+
"""Implement the scaled dot product attention with softmax.
|
32 |
+
Arguments
|
33 |
+
---------
|
34 |
+
softmax_scale: The temperature to use for the softmax attention.
|
35 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
36 |
+
runtime)
|
37 |
+
attention_dropout: The dropout rate to apply to the attention
|
38 |
+
(default: 0.0)
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(
|
42 |
+
self,
|
43 |
+
causal=False,
|
44 |
+
softmax_scale=None,
|
45 |
+
attention_dropout=0.0,
|
46 |
+
window_size=(-1, -1),
|
47 |
+
deterministic=False,
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
|
51 |
+
assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
|
52 |
+
self.causal = causal
|
53 |
+
self.softmax_scale = softmax_scale
|
54 |
+
self.drop = nn.Dropout(attention_dropout)
|
55 |
+
self.window_size = window_size
|
56 |
+
self.deterministic = deterministic
|
57 |
+
|
58 |
+
def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
|
59 |
+
"""Implements the multihead softmax attention.
|
60 |
+
Arguments
|
61 |
+
---------
|
62 |
+
qkv: The tensor containing the query, key, and value.
|
63 |
+
If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
|
64 |
+
If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
|
65 |
+
(total, 3, H, D), where total is the sum of the sequence lengths in the batch.
|
66 |
+
causal: if passed, will override self.causal
|
67 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
68 |
+
of the sequences in the batch, used to index into qkv.
|
69 |
+
max_seqlen: int. Maximum sequence length in the batch.
|
70 |
+
Returns:
|
71 |
+
--------
|
72 |
+
out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
|
73 |
+
else (B, S, H, D).
|
74 |
+
"""
|
75 |
+
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
76 |
+
assert qkv.is_cuda
|
77 |
+
causal = self.causal if causal is None else causal
|
78 |
+
unpadded = cu_seqlens is not None
|
79 |
+
|
80 |
+
if unpadded:
|
81 |
+
assert cu_seqlens.dtype == torch.int32
|
82 |
+
assert max_seqlen is not None
|
83 |
+
assert isinstance(max_seqlen, int)
|
84 |
+
return flash_attn_varlen_qkvpacked_func(
|
85 |
+
qkv,
|
86 |
+
cu_seqlens,
|
87 |
+
max_seqlen,
|
88 |
+
self.drop.p if self.training else 0.0,
|
89 |
+
softmax_scale=self.softmax_scale,
|
90 |
+
causal=causal,
|
91 |
+
alibi_slopes=None,
|
92 |
+
window_size=self.window_size,
|
93 |
+
deterministic=self.deterministic,
|
94 |
+
)
|
95 |
+
else:
|
96 |
+
return flash_attn_qkvpacked_func(
|
97 |
+
qkv,
|
98 |
+
self.drop.p if self.training else 0.0,
|
99 |
+
softmax_scale=self.softmax_scale,
|
100 |
+
causal=causal,
|
101 |
+
alibi_slopes=None,
|
102 |
+
window_size=self.window_size,
|
103 |
+
deterministic=self.deterministic,
|
104 |
+
)
|
105 |
+
|
106 |
+
|
107 |
+
class FlashCrossAttention(nn.Module):
|
108 |
+
"""Implement the scaled dot product attention with softmax.
|
109 |
+
Arguments
|
110 |
+
---------
|
111 |
+
softmax_scale: The temperature to use for the softmax attention.
|
112 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
113 |
+
runtime)
|
114 |
+
attention_dropout: The dropout rate to apply to the attention
|
115 |
+
(default: 0.0)
|
116 |
+
"""
|
117 |
+
|
118 |
+
def __init__(
|
119 |
+
self,
|
120 |
+
causal=False,
|
121 |
+
softmax_scale=None,
|
122 |
+
attention_dropout=0.0,
|
123 |
+
window_size=(-1, -1),
|
124 |
+
deterministic=False,
|
125 |
+
):
|
126 |
+
super().__init__()
|
127 |
+
assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
|
128 |
+
assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
|
129 |
+
self.causal = causal
|
130 |
+
self.softmax_scale = softmax_scale
|
131 |
+
self.drop = nn.Dropout(attention_dropout)
|
132 |
+
self.window_size = window_size
|
133 |
+
self.deterministic = deterministic
|
134 |
+
|
135 |
+
def forward(
|
136 |
+
self,
|
137 |
+
q,
|
138 |
+
kv,
|
139 |
+
causal=None,
|
140 |
+
cu_seqlens=None,
|
141 |
+
max_seqlen=None,
|
142 |
+
cu_seqlens_k=None,
|
143 |
+
max_seqlen_k=None,
|
144 |
+
):
|
145 |
+
"""Implements the multihead softmax attention.
|
146 |
+
Arguments
|
147 |
+
---------
|
148 |
+
q: The tensor containing the query. (B, Sq, H, D)
|
149 |
+
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
|
150 |
+
causal: if passed, will override self.causal
|
151 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
152 |
+
of the sequences in the batch, used to index into q.
|
153 |
+
max_seqlen: int. Maximum sequence length in the batch of q.
|
154 |
+
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
155 |
+
of the sequences in the batch, used to index into kv.
|
156 |
+
max_seqlen_k: int. Maximum sequence length in the batch of k and v.
|
157 |
+
"""
|
158 |
+
assert q.dtype in [torch.float16, torch.bfloat16]
|
159 |
+
assert q.is_cuda and kv.is_cuda
|
160 |
+
causal = self.causal if causal is None else causal
|
161 |
+
unpadded = cu_seqlens is not None
|
162 |
+
|
163 |
+
if unpadded:
|
164 |
+
assert cu_seqlens.dtype == torch.int32
|
165 |
+
assert max_seqlen is not None
|
166 |
+
assert isinstance(max_seqlen, int)
|
167 |
+
assert cu_seqlens_k is not None
|
168 |
+
assert cu_seqlens_k.dtype == torch.int32
|
169 |
+
assert max_seqlen_k is not None
|
170 |
+
assert isinstance(max_seqlen, int)
|
171 |
+
return flash_attn_varlen_kvpacked_func(
|
172 |
+
q,
|
173 |
+
kv,
|
174 |
+
cu_seqlens,
|
175 |
+
cu_seqlens_k,
|
176 |
+
max_seqlen,
|
177 |
+
max_seqlen_k,
|
178 |
+
self.drop.p if self.training else 0.0,
|
179 |
+
softmax_scale=self.softmax_scale,
|
180 |
+
causal=causal,
|
181 |
+
alibi_slopes=None,
|
182 |
+
window_size=self.window_size,
|
183 |
+
deterministic=self.deterministic,
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
187 |
+
seqlen_k = kv.shape[1]
|
188 |
+
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
189 |
+
return flash_attn_kvpacked_func(
|
190 |
+
q,
|
191 |
+
kv,
|
192 |
+
self.drop.p if self.training else 0.0,
|
193 |
+
causal=causal,
|
194 |
+
softmax_scale=self.softmax_scale,
|
195 |
+
alibi_slopes=None,
|
196 |
+
window_size=self.window_size,
|
197 |
+
deterministic=self.deterministic,
|
198 |
+
)
|
199 |
+
|
200 |
+
|
201 |
+
class SelfAttention(nn.Module):
|
202 |
+
"""Implement the scaled dot product attention with softmax.
|
203 |
+
Arguments
|
204 |
+
---------
|
205 |
+
softmax_scale: The temperature to use for the softmax attention.
|
206 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
207 |
+
runtime)
|
208 |
+
attention_dropout: The dropout rate to apply to the attention
|
209 |
+
(default: 0.0)
|
210 |
+
"""
|
211 |
+
|
212 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
213 |
+
super().__init__()
|
214 |
+
self.causal = causal
|
215 |
+
self.softmax_scale = softmax_scale
|
216 |
+
self.drop = nn.Dropout(attention_dropout)
|
217 |
+
|
218 |
+
def forward(self, qkv, causal=None, key_padding_mask=None):
|
219 |
+
"""Implements the multihead softmax attention.
|
220 |
+
Arguments
|
221 |
+
---------
|
222 |
+
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
|
223 |
+
causal: if passed, will override self.causal
|
224 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
225 |
+
False means to mask out. (B, S)
|
226 |
+
"""
|
227 |
+
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
|
228 |
+
causal = self.causal if causal is None else causal
|
229 |
+
q, k, v = qkv.unbind(dim=2)
|
230 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
231 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
232 |
+
if key_padding_mask is not None:
|
233 |
+
padding_mask = torch.full(
|
234 |
+
(batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
|
235 |
+
)
|
236 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
237 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
238 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
239 |
+
if causal:
|
240 |
+
# "triu_tril_cuda_template" not implemented for 'BFloat16'
|
241 |
+
# So we have to construct the mask in float
|
242 |
+
causal_mask = torch.triu(
|
243 |
+
torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
|
244 |
+
)
|
245 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
246 |
+
scores = scores + causal_mask.to(dtype=scores.dtype)
|
247 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
248 |
+
attention_drop = self.drop(attention)
|
249 |
+
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
|
250 |
+
return output
|
251 |
+
|
252 |
+
|
253 |
+
class CrossAttention(nn.Module):
|
254 |
+
"""Implement the scaled dot product attention with softmax.
|
255 |
+
Arguments
|
256 |
+
---------
|
257 |
+
softmax_scale: The temperature to use for the softmax attention.
|
258 |
+
(default: 1/sqrt(d_keys) where d_keys is computed at
|
259 |
+
runtime)
|
260 |
+
attention_dropout: The dropout rate to apply to the attention
|
261 |
+
(default: 0.0)
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
|
265 |
+
super().__init__()
|
266 |
+
self.causal = causal
|
267 |
+
self.softmax_scale = softmax_scale
|
268 |
+
self.drop = nn.Dropout(attention_dropout)
|
269 |
+
|
270 |
+
def forward(self, q, kv, causal=None, key_padding_mask=None):
|
271 |
+
"""Implements the multihead softmax attention.
|
272 |
+
Arguments
|
273 |
+
---------
|
274 |
+
q: The tensor containing the query. (B, Sq, H, D)
|
275 |
+
kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
|
276 |
+
causal: if passed, will override self.causal
|
277 |
+
key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
|
278 |
+
False means to mask out. (B, Sk)
|
279 |
+
"""
|
280 |
+
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
281 |
+
causal = self.causal if causal is None else causal
|
282 |
+
seqlen_k = kv.shape[1]
|
283 |
+
assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
|
284 |
+
if kv.shape[3] != q.shape[2]: # MQA/GQA
|
285 |
+
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
|
286 |
+
k, v = kv.unbind(dim=2)
|
287 |
+
softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
|
288 |
+
scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
|
289 |
+
if key_padding_mask is not None:
|
290 |
+
padding_mask = torch.full(
|
291 |
+
(batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
|
292 |
+
)
|
293 |
+
padding_mask.masked_fill_(key_padding_mask, 0.0)
|
294 |
+
# TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
|
295 |
+
scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
|
296 |
+
if causal:
|
297 |
+
# causal mask needs to take into account the difference between seqlen_q and seqlen_k
|
298 |
+
row_idx = rearrange(
|
299 |
+
torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1"
|
300 |
+
)
|
301 |
+
col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
|
302 |
+
sk = (
|
303 |
+
seqlen_k
|
304 |
+
if key_padding_mask is None
|
305 |
+
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
306 |
+
)
|
307 |
+
causal_mask = col_idx > row_idx + sk - seqlen_q
|
308 |
+
scores = scores.masked_fill(causal_mask, -10000.0)
|
309 |
+
attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
|
310 |
+
attention_drop = self.drop(attention)
|
311 |
+
output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
|
312 |
+
return output
|
313 |
+
|
314 |
+
|
315 |
+
class LinearResidual(nn.Linear):
|
316 |
+
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
|
317 |
+
|
318 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
319 |
+
return super().forward(input), input
|
320 |
+
|
321 |
+
|
322 |
+
def _update_kv_cache(kv, inference_params, layer_idx):
|
323 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
324 |
+
# Pre-allocate memory for key-values for inference.
|
325 |
+
num_heads, head_dim = kv.shape[-2:]
|
326 |
+
if layer_idx not in inference_params.key_value_memory_dict:
|
327 |
+
kv_cache = torch.empty(
|
328 |
+
inference_params.max_batch_size,
|
329 |
+
inference_params.max_seqlen,
|
330 |
+
2,
|
331 |
+
num_heads,
|
332 |
+
head_dim,
|
333 |
+
dtype=kv.dtype,
|
334 |
+
device=kv.device,
|
335 |
+
)
|
336 |
+
inference_params.key_value_memory_dict[layer_idx] = kv_cache
|
337 |
+
else:
|
338 |
+
kv_cache = inference_params.key_value_memory_dict[layer_idx]
|
339 |
+
# Adjust key and value for inference
|
340 |
+
batch_start = inference_params.batch_size_offset
|
341 |
+
batch_end = batch_start + kv.shape[0]
|
342 |
+
sequence_start = inference_params.seqlen_offset
|
343 |
+
sequence_end = sequence_start + kv.shape[1]
|
344 |
+
assert batch_end <= kv_cache.shape[0]
|
345 |
+
assert sequence_end <= kv_cache.shape[1]
|
346 |
+
assert kv_cache is not None
|
347 |
+
kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
|
348 |
+
return kv_cache[batch_start:batch_end, :sequence_end, ...]
|
349 |
+
|
350 |
+
|
351 |
+
class MHA(nn.Module):
|
352 |
+
"""Multi-head self-attention and cross-attention"""
|
353 |
+
|
354 |
+
def __init__(
|
355 |
+
self,
|
356 |
+
embed_dim,
|
357 |
+
num_heads,
|
358 |
+
num_heads_kv=None,
|
359 |
+
cross_attn=False,
|
360 |
+
qkv_proj_bias=True,
|
361 |
+
out_proj_bias=True,
|
362 |
+
dropout=0.0,
|
363 |
+
softmax_scale=None,
|
364 |
+
causal=False,
|
365 |
+
layer_idx=None,
|
366 |
+
dwconv=False,
|
367 |
+
window_size=(-1, -1),
|
368 |
+
fused_bias_fc=False,
|
369 |
+
use_flash_attn=False,
|
370 |
+
return_residual=False,
|
371 |
+
checkpointing=False,
|
372 |
+
device=None,
|
373 |
+
dtype=None,
|
374 |
+
) -> None:
|
375 |
+
"""
|
376 |
+
num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
|
377 |
+
return_residual: whether to return the input x along with the output. This is for
|
378 |
+
performance reason: for post-norm architecture, returning the input allows us
|
379 |
+
to fuse the backward of nn.Linear with the residual connection.
|
380 |
+
"""
|
381 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
382 |
+
super().__init__()
|
383 |
+
self.embed_dim = embed_dim
|
384 |
+
self.cross_attn = cross_attn
|
385 |
+
self.causal = causal
|
386 |
+
self.layer_idx = layer_idx
|
387 |
+
self.dwconv = dwconv
|
388 |
+
self.use_flash_attn = use_flash_attn
|
389 |
+
self.return_residual = return_residual
|
390 |
+
self.checkpointing = checkpointing
|
391 |
+
|
392 |
+
if window_size != (-1, -1):
|
393 |
+
assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
|
394 |
+
|
395 |
+
self.num_heads = num_heads
|
396 |
+
self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
|
397 |
+
assert (
|
398 |
+
self.num_heads % self.num_heads_kv == 0
|
399 |
+
), "num_heads must be divisible by num_heads_kv"
|
400 |
+
assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
|
401 |
+
self.head_dim = self.embed_dim // num_heads
|
402 |
+
qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
|
403 |
+
kv_dim = 2 * self.head_dim * self.num_heads_kv
|
404 |
+
|
405 |
+
if fused_bias_fc and FusedDense is None:
|
406 |
+
raise ImportError("fused_dense is not installed")
|
407 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
408 |
+
linear_resid_cls = (
|
409 |
+
LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
|
410 |
+
)
|
411 |
+
wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
|
412 |
+
inner_attn_cls = (
|
413 |
+
partial(FlashSelfAttention, window_size=window_size)
|
414 |
+
if use_flash_attn
|
415 |
+
else SelfAttention
|
416 |
+
)
|
417 |
+
inner_cross_attn_cls = (
|
418 |
+
partial(FlashCrossAttention, window_size=window_size)
|
419 |
+
if use_flash_attn
|
420 |
+
else CrossAttention
|
421 |
+
)
|
422 |
+
if not self.cross_attn:
|
423 |
+
self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
424 |
+
else:
|
425 |
+
self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
|
426 |
+
self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
|
427 |
+
if self.dwconv:
|
428 |
+
if self.num_heads_kv == self.num_heads:
|
429 |
+
self.dwconv_qkv = nn.Conv1d(
|
430 |
+
qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
self.dwconv_q = nn.Conv1d(
|
434 |
+
embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
|
435 |
+
)
|
436 |
+
self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
|
437 |
+
self.inner_attn = inner_attn_cls(
|
438 |
+
causal=causal,
|
439 |
+
softmax_scale=softmax_scale,
|
440 |
+
attention_dropout=dropout,
|
441 |
+
)
|
442 |
+
self.inner_cross_attn = inner_cross_attn_cls(
|
443 |
+
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
|
444 |
+
)
|
445 |
+
self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
|
446 |
+
|
447 |
+
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
|
448 |
+
dtype = self.out_proj.weight.dtype if dtype is None else dtype
|
449 |
+
device = self.out_proj.weight.device
|
450 |
+
return torch.empty(
|
451 |
+
batch_size,
|
452 |
+
max_seqlen,
|
453 |
+
2,
|
454 |
+
self.num_heads_kv,
|
455 |
+
self.head_dim,
|
456 |
+
dtype=dtype,
|
457 |
+
device=device,
|
458 |
+
)
|
459 |
+
|
460 |
+
def _update_kv_cache(self, kv, inference_params):
|
461 |
+
"""kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
|
462 |
+
assert not self.dwconv, "Generation does not support dwconv yet"
|
463 |
+
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
464 |
+
return _update_kv_cache(kv, inference_params, self.layer_idx)
|
465 |
+
|
466 |
+
def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
|
467 |
+
"""
|
468 |
+
Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention.
|
469 |
+
q: (batch_size, seqlen_q, nheads, head_dim)
|
470 |
+
kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim)
|
471 |
+
"""
|
472 |
+
assert inference_params is not None and inference_params.seqlen_offset > 0
|
473 |
+
assert self.use_flash_attn
|
474 |
+
batch = q.shape[0]
|
475 |
+
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
|
476 |
+
cache_seqlens = (
|
477 |
+
inference_params.lengths_per_sample[:batch]
|
478 |
+
if inference_params.lengths_per_sample is not None
|
479 |
+
else inference_params.seqlen_offset
|
480 |
+
)
|
481 |
+
context = flash_attn_with_kvcache(
|
482 |
+
q,
|
483 |
+
kv_cache[:, :, 0],
|
484 |
+
kv_cache[:, :, 1],
|
485 |
+
kv[:, :, 0],
|
486 |
+
kv[:, :, 1],
|
487 |
+
cache_seqlens=cache_seqlens,
|
488 |
+
softmax_scale=self.inner_cross_attn.softmax_scale,
|
489 |
+
causal=self.inner_cross_attn.causal,
|
490 |
+
rotary_interleaved=False,
|
491 |
+
alibi_slopes=None,
|
492 |
+
)
|
493 |
+
return context
|
494 |
+
|
495 |
+
def _update_kvcache_attention(self, q, kv, inference_params):
|
496 |
+
"""Write kv to inference_params, then do attention"""
|
497 |
+
if (
|
498 |
+
inference_params.seqlen_offset == 0
|
499 |
+
or flash_attn_with_kvcache is None
|
500 |
+
or not self.use_flash_attn
|
501 |
+
):
|
502 |
+
# TODO: this only uses seqlen_offset and not lengths_per_sample.
|
503 |
+
kv = self._update_kv_cache(kv, inference_params)
|
504 |
+
return self.inner_cross_attn(q, kv)
|
505 |
+
else:
|
506 |
+
batch = q.shape[0]
|
507 |
+
kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch]
|
508 |
+
cache_seqlens = (
|
509 |
+
inference_params.lengths_per_sample[:batch]
|
510 |
+
if inference_params.lengths_per_sample is not None
|
511 |
+
else inference_params.seqlen_offset
|
512 |
+
)
|
513 |
+
return flash_attn_with_kvcache(
|
514 |
+
q,
|
515 |
+
kv_cache[:, :, 0],
|
516 |
+
kv_cache[:, :, 1],
|
517 |
+
kv[:, :, 0],
|
518 |
+
kv[:, :, 1],
|
519 |
+
cache_seqlens=cache_seqlens,
|
520 |
+
softmax_scale=self.inner_cross_attn.softmax_scale,
|
521 |
+
causal=self.inner_cross_attn.causal,
|
522 |
+
alibi_slopes=None,
|
523 |
+
)
|
524 |
+
|
525 |
+
def forward(
|
526 |
+
self,
|
527 |
+
x,
|
528 |
+
x_kv=None,
|
529 |
+
key_padding_mask=None,
|
530 |
+
cu_seqlens=None,
|
531 |
+
max_seqlen=None,
|
532 |
+
mixer_subset=None,
|
533 |
+
inference_params=None,
|
534 |
+
**kwargs,
|
535 |
+
):
|
536 |
+
"""
|
537 |
+
Arguments:
|
538 |
+
x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
|
539 |
+
cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
|
540 |
+
is the is the sum of the sequence lengths in the batch.
|
541 |
+
x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
|
542 |
+
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
|
543 |
+
of the sequences in the batch, used to index into x. Only applicable when using
|
544 |
+
FlashAttention.
|
545 |
+
max_seqlen: int. Maximum sequence length in the batch.
|
546 |
+
key_padding_mask: boolean mask, True means to keep, False means to mask out.
|
547 |
+
(batch, seqlen). Only applicable when not using FlashAttention.
|
548 |
+
mixer_subset: for cross-attention only. If not None, will take a subset of x
|
549 |
+
before applying the query projection. Useful for e.g., ViT where we only care
|
550 |
+
about the CLS token in the last layer.
|
551 |
+
inference_params: for generation. Adapted from Megatron-LM (and Apex)
|
552 |
+
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
|
553 |
+
"""
|
554 |
+
if cu_seqlens is not None:
|
555 |
+
assert max_seqlen is not None
|
556 |
+
assert key_padding_mask is None
|
557 |
+
assert self.use_flash_attn
|
558 |
+
assert not self.dwconv
|
559 |
+
if key_padding_mask is not None:
|
560 |
+
assert cu_seqlens is None
|
561 |
+
assert max_seqlen is None
|
562 |
+
assert not self.use_flash_attn
|
563 |
+
if inference_params is not None:
|
564 |
+
assert key_padding_mask is None
|
565 |
+
assert cu_seqlens is None and max_seqlen is None
|
566 |
+
assert not self.dwconv
|
567 |
+
|
568 |
+
kwargs = (
|
569 |
+
{"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
|
570 |
+
if self.use_flash_attn
|
571 |
+
else {"key_padding_mask": key_padding_mask, **kwargs}
|
572 |
+
)
|
573 |
+
seqlen_offset = (
|
574 |
+
0
|
575 |
+
if inference_params is None
|
576 |
+
else (
|
577 |
+
inference_params.lengths_per_sample
|
578 |
+
if inference_params.lengths_per_sample is not None
|
579 |
+
else inference_params.seqlen_offset
|
580 |
+
)
|
581 |
+
)
|
582 |
+
rotary_max_seqlen = (
|
583 |
+
inference_params.max_sequence_len if inference_params is not None else max_seqlen
|
584 |
+
)
|
585 |
+
batch, seqlen = x.shape[:2]
|
586 |
+
if not self.cross_attn and self.num_heads_kv == self.num_heads:
|
587 |
+
assert x_kv is None and mixer_subset is None
|
588 |
+
if not self.return_residual:
|
589 |
+
qkv = self.Wqkv(x)
|
590 |
+
else:
|
591 |
+
qkv, x = self.Wqkv(x)
|
592 |
+
if self.dwconv:
|
593 |
+
qkv = rearrange(
|
594 |
+
self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
595 |
+
).contiguous()
|
596 |
+
qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
|
597 |
+
if (
|
598 |
+
inference_params is None
|
599 |
+
or inference_params.seqlen_offset == 0
|
600 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
601 |
+
or not self.use_flash_attn
|
602 |
+
):
|
603 |
+
if inference_params is None:
|
604 |
+
if not self.checkpointing:
|
605 |
+
context = self.inner_attn(qkv, **kwargs)
|
606 |
+
else:
|
607 |
+
context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
|
608 |
+
else:
|
609 |
+
context = self._update_kvcache_attention(
|
610 |
+
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
611 |
+
)
|
612 |
+
else:
|
613 |
+
context = self._apply_rotary_update_kvcache_attention(
|
614 |
+
qkv[:, :, 0], qkv[:, :, 1:], inference_params
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
if self.cross_attn:
|
618 |
+
if not self.return_residual:
|
619 |
+
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
|
620 |
+
kv = self.Wkv(x_kv if x_kv is not None else x)
|
621 |
+
else:
|
622 |
+
if x_kv is not None:
|
623 |
+
kv, x_kv = self.Wkv(x_kv)
|
624 |
+
else:
|
625 |
+
kv, x = self.Wkv(x)
|
626 |
+
q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
|
627 |
+
else:
|
628 |
+
assert self.num_heads_kv != self.num_heads
|
629 |
+
if not self.return_residual:
|
630 |
+
qkv = self.Wqkv(x)
|
631 |
+
else:
|
632 |
+
qkv, x = self.Wqkv(x)
|
633 |
+
q = qkv[..., : self.num_heads * self.head_dim]
|
634 |
+
kv = qkv[..., self.num_heads * self.head_dim :]
|
635 |
+
q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
|
636 |
+
kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
|
637 |
+
if self.dwconv:
|
638 |
+
q = rearrange(
|
639 |
+
self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
640 |
+
).contiguous()
|
641 |
+
kv = rearrange(
|
642 |
+
self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
|
643 |
+
).contiguous()
|
644 |
+
if (
|
645 |
+
inference_params is None
|
646 |
+
or inference_params.seqlen_offset == 0
|
647 |
+
or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0)
|
648 |
+
or not self.use_flash_attn
|
649 |
+
):
|
650 |
+
if inference_params is None:
|
651 |
+
if not self.checkpointing:
|
652 |
+
context = self.inner_cross_attn(q, kv, **kwargs)
|
653 |
+
else:
|
654 |
+
context = torch.utils.checkpoint.checkpoint(
|
655 |
+
self.inner_cross_attn, q, kv, **kwargs
|
656 |
+
)
|
657 |
+
else:
|
658 |
+
context = self._update_kvcache_attention(q, kv, inference_params)
|
659 |
+
else:
|
660 |
+
context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
|
661 |
+
out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
|
662 |
+
return out if not self.return_residual else (out, x)
|
mlp.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mlp.py
|
2 |
+
# Commit id: c3b219665292c61a51153d0ded4473c494296382
|
3 |
+
|
4 |
+
# Copyright (c) 2023, Tri Dao.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from torch.distributed import ProcessGroup
|
10 |
+
|
11 |
+
|
12 |
+
try:
|
13 |
+
from flash_attn.ops.activations import swiglu
|
14 |
+
except ImportError:
|
15 |
+
swiglu = None
|
16 |
+
|
17 |
+
try:
|
18 |
+
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
19 |
+
except ImportError:
|
20 |
+
ColumnParallelLinear, RowParallelLinear = None, None
|
21 |
+
|
22 |
+
try:
|
23 |
+
from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP
|
24 |
+
except ImportError:
|
25 |
+
FusedMLP, ParallelFusedMLP = None, None
|
26 |
+
|
27 |
+
|
28 |
+
class Mlp(nn.Module):
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
in_features,
|
32 |
+
hidden_features=None,
|
33 |
+
out_features=None,
|
34 |
+
activation=F.gelu,
|
35 |
+
bias1=True,
|
36 |
+
bias2=True,
|
37 |
+
return_residual=False,
|
38 |
+
device=None,
|
39 |
+
dtype=None,
|
40 |
+
):
|
41 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
42 |
+
super().__init__()
|
43 |
+
out_features = out_features if out_features is not None else in_features
|
44 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
45 |
+
self.return_residual = return_residual
|
46 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
47 |
+
self.activation = activation
|
48 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
y = self.fc1(x)
|
52 |
+
y = self.activation(y)
|
53 |
+
y = self.fc2(y)
|
54 |
+
return y if not self.return_residual else (y, x)
|
55 |
+
|
56 |
+
|
57 |
+
class ParallelMLP(nn.Module):
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
in_features,
|
61 |
+
hidden_features=None,
|
62 |
+
out_features=None,
|
63 |
+
activation=F.gelu,
|
64 |
+
process_group: ProcessGroup = None,
|
65 |
+
sequence_parallel=True,
|
66 |
+
bias1=True,
|
67 |
+
bias2=True,
|
68 |
+
device=None,
|
69 |
+
dtype=None,
|
70 |
+
):
|
71 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
72 |
+
super().__init__()
|
73 |
+
assert ColumnParallelLinear is not None, "Need to install fused_dense"
|
74 |
+
assert RowParallelLinear is not None, "Need to install fused_dense"
|
75 |
+
out_features = out_features if out_features is not None else in_features
|
76 |
+
hidden_features = hidden_features if hidden_features is not None else in_features * 4
|
77 |
+
self.fc1 = ColumnParallelLinear(
|
78 |
+
in_features,
|
79 |
+
hidden_features,
|
80 |
+
process_group,
|
81 |
+
bias=bias1,
|
82 |
+
sequence_parallel=sequence_parallel,
|
83 |
+
**factory_kwargs,
|
84 |
+
)
|
85 |
+
self.activation = activation
|
86 |
+
self.fc2 = RowParallelLinear(
|
87 |
+
hidden_features,
|
88 |
+
out_features,
|
89 |
+
process_group,
|
90 |
+
bias=bias2,
|
91 |
+
sequence_parallel=sequence_parallel,
|
92 |
+
**factory_kwargs,
|
93 |
+
)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
y = self.fc1(x)
|
97 |
+
y = self.activation(y)
|
98 |
+
y = self.fc2(y)
|
99 |
+
return y
|
100 |
+
|
101 |
+
|
102 |
+
class GatedMlp(nn.Module):
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
in_features,
|
106 |
+
hidden_features=None,
|
107 |
+
out_features=None,
|
108 |
+
activation=F.sigmoid,
|
109 |
+
bias1=True,
|
110 |
+
bias2=True,
|
111 |
+
multiple_of=128,
|
112 |
+
return_residual=False,
|
113 |
+
device=None,
|
114 |
+
dtype=None,
|
115 |
+
):
|
116 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
117 |
+
super().__init__()
|
118 |
+
out_features = out_features if out_features is not None else in_features
|
119 |
+
hidden_features = (
|
120 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
121 |
+
)
|
122 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
123 |
+
self.return_residual = return_residual
|
124 |
+
self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
|
125 |
+
self.activation = activation
|
126 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
127 |
+
|
128 |
+
def forward(self, x):
|
129 |
+
y = self.fc1(x)
|
130 |
+
if self.activation == F.sigmoid: # Special case for GLU
|
131 |
+
y = F.glu(y, dim=-1)
|
132 |
+
elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
|
133 |
+
y, gate = y.chunk(2, dim=-1)
|
134 |
+
y = swiglu(gate, y)
|
135 |
+
else:
|
136 |
+
y, gate = y.chunk(2, dim=-1)
|
137 |
+
y = y * self.activation(gate)
|
138 |
+
y = self.fc2(y)
|
139 |
+
return y if not self.return_residual else (y, x)
|
140 |
+
|
141 |
+
|
142 |
+
class ParallelGatedMlp(nn.Module):
|
143 |
+
"""Parallel GatedMlp"""
|
144 |
+
|
145 |
+
def __init__(
|
146 |
+
self,
|
147 |
+
in_features,
|
148 |
+
process_group,
|
149 |
+
hidden_features=None,
|
150 |
+
out_features=None,
|
151 |
+
activation=F.sigmoid,
|
152 |
+
bias1=True,
|
153 |
+
bias2=True,
|
154 |
+
multiple_of=128,
|
155 |
+
sequence_parallel=True,
|
156 |
+
device=None,
|
157 |
+
dtype=None,
|
158 |
+
):
|
159 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
160 |
+
super().__init__()
|
161 |
+
out_features = out_features if out_features is not None else in_features
|
162 |
+
hidden_features = (
|
163 |
+
hidden_features if hidden_features is not None else int(8 * in_features / 3)
|
164 |
+
)
|
165 |
+
hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
|
166 |
+
if ColumnParallelLinear is None or RowParallelLinear is None:
|
167 |
+
raise ImportError("fused_dense is not installed")
|
168 |
+
self.fc1 = ColumnParallelLinear(
|
169 |
+
in_features,
|
170 |
+
2 * hidden_features,
|
171 |
+
process_group,
|
172 |
+
bias=bias1,
|
173 |
+
sequence_parallel=sequence_parallel,
|
174 |
+
**factory_kwargs,
|
175 |
+
)
|
176 |
+
self.activation = activation
|
177 |
+
self.fc2 = RowParallelLinear(
|
178 |
+
hidden_features,
|
179 |
+
out_features,
|
180 |
+
process_group,
|
181 |
+
bias=bias2,
|
182 |
+
sequence_parallel=sequence_parallel,
|
183 |
+
**factory_kwargs,
|
184 |
+
)
|
185 |
+
|
186 |
+
def forward(self, x):
|
187 |
+
y = self.fc1(x)
|
188 |
+
if self.activation == F.sigmoid: # Special case for GLU
|
189 |
+
y = F.glu(y, dim=-1)
|
190 |
+
else:
|
191 |
+
y, gate = y.chunk(2, dim=-1)
|
192 |
+
y = y * self.activation(gate)
|
193 |
+
y = self.fc2(y)
|
194 |
+
return y
|
modeling_xlm_roberta.py
ADDED
@@ -0,0 +1,1119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This implementation was adopted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/bert.py
|
2 |
+
# Commit id: abbc1311731867310635f9edc2a9ec18317c8c48
|
3 |
+
# Copyright (c) 2022, Tri Dao.
|
4 |
+
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
|
5 |
+
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
|
6 |
+
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
|
7 |
+
|
8 |
+
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
|
9 |
+
|
10 |
+
import importlib.util
|
11 |
+
import logging
|
12 |
+
import re
|
13 |
+
from collections import OrderedDict
|
14 |
+
from collections.abc import Sequence
|
15 |
+
from functools import partial
|
16 |
+
import numpy as np
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.utils.checkpoint
|
22 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
23 |
+
from einops import rearrange
|
24 |
+
from transformers import PretrainedConfig
|
25 |
+
from transformers.modeling_utils import PreTrainedModel
|
26 |
+
from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
|
27 |
+
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
|
28 |
+
|
29 |
+
from transformers.models.bert.modeling_bert import (
|
30 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
31 |
+
BertForPreTrainingOutput,
|
32 |
+
)
|
33 |
+
|
34 |
+
from typing import List, Optional, Tuple, Union
|
35 |
+
|
36 |
+
from .xlm_padding import (
|
37 |
+
index_first_axis,
|
38 |
+
index_first_axis_residual,
|
39 |
+
pad_input,
|
40 |
+
unpad_input,
|
41 |
+
)
|
42 |
+
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
43 |
+
from .block import Block
|
44 |
+
from .embedding import XLMRobertaEmbeddings
|
45 |
+
from .mha import MHA
|
46 |
+
from .mlp import FusedMLP, Mlp
|
47 |
+
|
48 |
+
try:
|
49 |
+
from flash_attn.ops.fused_dense import FusedDense
|
50 |
+
except ImportError:
|
51 |
+
FusedDense = None
|
52 |
+
|
53 |
+
try:
|
54 |
+
from flash_attn.ops.triton.layer_norm import layer_norm_fn
|
55 |
+
except ImportError:
|
56 |
+
layer_norm_fn = None
|
57 |
+
|
58 |
+
|
59 |
+
try:
|
60 |
+
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
61 |
+
except ImportError:
|
62 |
+
CrossEntropyLoss = torch.nn.CrossEntropyLoss
|
63 |
+
|
64 |
+
try:
|
65 |
+
from tqdm.autonotebook import trange
|
66 |
+
except ImportError:
|
67 |
+
trange = None
|
68 |
+
|
69 |
+
|
70 |
+
logger = logging.getLogger(__name__)
|
71 |
+
|
72 |
+
|
73 |
+
def get_use_flash_attn(config: XLMRobertaFlashConfig):
|
74 |
+
if not getattr(config, "use_flash_attn", False):
|
75 |
+
return False
|
76 |
+
if not torch.cuda.is_available():
|
77 |
+
return False
|
78 |
+
if importlib.util.find_spec("flash_attn") is None:
|
79 |
+
logger.warning(
|
80 |
+
'flash_attn is not installed. Using PyTorch native attention implementation.'
|
81 |
+
)
|
82 |
+
return False
|
83 |
+
return True
|
84 |
+
|
85 |
+
|
86 |
+
def create_mixer_cls(config, cross_attn=False, return_residual=False):
|
87 |
+
use_flash_attn = get_use_flash_attn(config)
|
88 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
89 |
+
|
90 |
+
mixer_cls = partial(
|
91 |
+
MHA,
|
92 |
+
num_heads=config.num_attention_heads,
|
93 |
+
cross_attn=cross_attn,
|
94 |
+
dropout=config.attention_probs_dropout_prob,
|
95 |
+
causal=False,
|
96 |
+
fused_bias_fc=fused_bias_fc,
|
97 |
+
use_flash_attn=use_flash_attn,
|
98 |
+
return_residual=return_residual,
|
99 |
+
)
|
100 |
+
return mixer_cls
|
101 |
+
|
102 |
+
|
103 |
+
def create_mlp_cls(config, layer_idx=None, return_residual=False):
|
104 |
+
inner_dim = config.intermediate_size
|
105 |
+
fused_mlp = getattr(config, "fused_mlp", False)
|
106 |
+
if fused_mlp:
|
107 |
+
assert config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"], (
|
108 |
+
"fused_mlp only " "supports approximate gelu"
|
109 |
+
)
|
110 |
+
if not fused_mlp:
|
111 |
+
approximate = (
|
112 |
+
"tanh"
|
113 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
114 |
+
else "none"
|
115 |
+
)
|
116 |
+
mlp_cls = partial(
|
117 |
+
Mlp,
|
118 |
+
hidden_features=inner_dim,
|
119 |
+
activation=partial(F.gelu, approximate=approximate),
|
120 |
+
return_residual=return_residual,
|
121 |
+
)
|
122 |
+
else:
|
123 |
+
if FusedMLP is None:
|
124 |
+
raise ImportError("fused_dense is not installed")
|
125 |
+
mlp_checkpoint_lvl = getattr(config, "mlp_checkpoint_lvl", 0)
|
126 |
+
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
|
127 |
+
if isinstance(mlp_checkpoint_lvl, Sequence):
|
128 |
+
assert layer_idx is not None
|
129 |
+
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
|
130 |
+
mlp_cls = partial(
|
131 |
+
FusedMLP,
|
132 |
+
hidden_features=inner_dim,
|
133 |
+
checkpoint_lvl=mlp_checkpoint_lvl,
|
134 |
+
return_residual=return_residual,
|
135 |
+
)
|
136 |
+
return mlp_cls
|
137 |
+
|
138 |
+
|
139 |
+
def create_block(config, layer_idx=None):
|
140 |
+
last_layer_subset = getattr(config, "last_layer_subset", False)
|
141 |
+
cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
|
142 |
+
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
|
143 |
+
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
|
144 |
+
# one layer) so we just choose not to return residual in this case.
|
145 |
+
return_residual = not cross_attn
|
146 |
+
mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
|
147 |
+
mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
|
148 |
+
norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
|
149 |
+
block = Block(
|
150 |
+
config.hidden_size,
|
151 |
+
mixer_cls,
|
152 |
+
mlp_cls,
|
153 |
+
norm_cls=norm_cls,
|
154 |
+
prenorm=False,
|
155 |
+
resid_dropout1=config.hidden_dropout_prob,
|
156 |
+
resid_dropout2=config.hidden_dropout_prob,
|
157 |
+
fused_dropout_add_ln=getattr(config, "fused_dropout_add_ln", False),
|
158 |
+
return_residual=return_residual,
|
159 |
+
)
|
160 |
+
return block
|
161 |
+
|
162 |
+
|
163 |
+
# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
|
164 |
+
def _init_weights(module, initializer_range=0.02):
|
165 |
+
if isinstance(module, nn.Linear):
|
166 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
167 |
+
if module.bias is not None:
|
168 |
+
nn.init.zeros_(module.bias)
|
169 |
+
elif isinstance(module, nn.Embedding):
|
170 |
+
nn.init.normal_(module.weight, std=initializer_range)
|
171 |
+
if module.padding_idx is not None:
|
172 |
+
nn.init.zeros_(module.weight[module.padding_idx])
|
173 |
+
|
174 |
+
|
175 |
+
class XLMRobertaEncoder(nn.Module):
|
176 |
+
def __init__(self, config: XLMRobertaFlashConfig):
|
177 |
+
super().__init__()
|
178 |
+
self.use_flash_attn = get_use_flash_attn(config)
|
179 |
+
self.layers = nn.ModuleList(
|
180 |
+
[create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
|
181 |
+
)
|
182 |
+
self._grad_checkpointing = False
|
183 |
+
|
184 |
+
@property
|
185 |
+
def gradient_checkpointing(self):
|
186 |
+
return self._grad_checkpointing
|
187 |
+
|
188 |
+
@gradient_checkpointing.setter
|
189 |
+
def gradient_checkpointing(self, value):
|
190 |
+
self._grad_checkpointing = value
|
191 |
+
|
192 |
+
def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
|
193 |
+
"""If subset_mask is not None, we only want output for the subset of the sequence.
|
194 |
+
This means that we only compute the last layer output for these tokens.
|
195 |
+
subset_mask: (batch, seqlen), dtype=torch.bool
|
196 |
+
"""
|
197 |
+
if key_padding_mask is None or not self.use_flash_attn:
|
198 |
+
mixer_kwargs = (
|
199 |
+
{"key_padding_mask": key_padding_mask.bool()}
|
200 |
+
if key_padding_mask is not None
|
201 |
+
else None
|
202 |
+
)
|
203 |
+
for layer in self.layers:
|
204 |
+
if self._grad_checkpointing:
|
205 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
206 |
+
layer,
|
207 |
+
hidden_states,
|
208 |
+
use_reentrant=False,
|
209 |
+
mixer_kwargs=mixer_kwargs,
|
210 |
+
)
|
211 |
+
else:
|
212 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
213 |
+
if subset_mask is not None:
|
214 |
+
hidden_states = hidden_states[subset_mask]
|
215 |
+
else:
|
216 |
+
batch, seqlen = hidden_states.shape[:2]
|
217 |
+
hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
|
218 |
+
hidden_states, key_padding_mask
|
219 |
+
)
|
220 |
+
mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
|
221 |
+
if subset_mask is None:
|
222 |
+
for layer in self.layers:
|
223 |
+
if self._grad_checkpointing:
|
224 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
225 |
+
layer,
|
226 |
+
hidden_states,
|
227 |
+
use_reentrant=False,
|
228 |
+
mixer_kwargs=mixer_kwargs,
|
229 |
+
)
|
230 |
+
else:
|
231 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
232 |
+
hidden_states = pad_input(hidden_states, indices, batch, seqlen)
|
233 |
+
else:
|
234 |
+
for layer in self.layers[:-1]:
|
235 |
+
if self._grad_checkpointing:
|
236 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
237 |
+
layer,
|
238 |
+
hidden_states,
|
239 |
+
use_reentrant=False,
|
240 |
+
mixer_kwargs=mixer_kwargs,
|
241 |
+
)
|
242 |
+
else:
|
243 |
+
hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
|
244 |
+
if key_padding_mask is not None:
|
245 |
+
subset_idx = torch.nonzero(
|
246 |
+
subset_mask[key_padding_mask], as_tuple=False
|
247 |
+
).flatten()
|
248 |
+
subset_seqlens = (subset_mask & key_padding_mask).sum(
|
249 |
+
dim=-1, dtype=torch.int32
|
250 |
+
)
|
251 |
+
subset_cu_seqlens = F.pad(
|
252 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
|
253 |
+
(1, 0),
|
254 |
+
)
|
255 |
+
else:
|
256 |
+
subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
|
257 |
+
subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
|
258 |
+
subset_cu_seqlens = F.pad(
|
259 |
+
torch.cumsum(subset_seqlens, dim=0, dtype=torch.torch.int32),
|
260 |
+
(1, 0),
|
261 |
+
)
|
262 |
+
hidden_states_subset, hidden_states = index_first_axis_residual(
|
263 |
+
hidden_states, subset_idx
|
264 |
+
)
|
265 |
+
# It's ok to set max_seqlen_q to be much larger
|
266 |
+
mixer_kwargs = {
|
267 |
+
"x_kv": hidden_states,
|
268 |
+
"cu_seqlens": subset_cu_seqlens,
|
269 |
+
"max_seqlen": max_seqlen_in_batch,
|
270 |
+
"cu_seqlens_k": cu_seqlens,
|
271 |
+
"max_seqlen_k": max_seqlen_in_batch,
|
272 |
+
}
|
273 |
+
if self._grad_checkpointing:
|
274 |
+
torch.utils.checkpoint.checkpoint(
|
275 |
+
self.layers[-1],
|
276 |
+
hidden_states_subset,
|
277 |
+
use_reentrant=False,
|
278 |
+
mixer_kwargs=mixer_kwargs,
|
279 |
+
)
|
280 |
+
else:
|
281 |
+
hidden_states = self.layers[-1](
|
282 |
+
hidden_states_subset, mixer_kwargs=mixer_kwargs
|
283 |
+
)
|
284 |
+
return hidden_states
|
285 |
+
|
286 |
+
|
287 |
+
class XLMRobertaPooler(nn.Module):
|
288 |
+
def __init__(self, config):
|
289 |
+
super().__init__()
|
290 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
291 |
+
if fused_bias_fc and FusedDense is None:
|
292 |
+
raise ImportError("fused_dense is not installed")
|
293 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
294 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
295 |
+
self.activation = nn.Tanh()
|
296 |
+
|
297 |
+
def forward(self, hidden_states, pool=True):
|
298 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
299 |
+
# to the first token.
|
300 |
+
first_token_tensor = hidden_states[:, 0] if pool else hidden_states
|
301 |
+
pooled_output = self.dense(first_token_tensor)
|
302 |
+
pooled_output = self.activation(pooled_output)
|
303 |
+
return pooled_output
|
304 |
+
|
305 |
+
|
306 |
+
class XLMRobertaPredictionHeadTransform(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super().__init__()
|
309 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
310 |
+
if fused_bias_fc and FusedDense is None:
|
311 |
+
raise ImportError("fused_dense is not installed")
|
312 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
313 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
314 |
+
raise ImportError("Triton is not installed")
|
315 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
316 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
317 |
+
approximate = (
|
318 |
+
"tanh"
|
319 |
+
if config.hidden_act in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
|
320 |
+
else "none"
|
321 |
+
)
|
322 |
+
self.transform_act_fn = nn.GELU(approximate=approximate)
|
323 |
+
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
324 |
+
|
325 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
326 |
+
hidden_states = self.dense(hidden_states)
|
327 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
328 |
+
if not self.fused_dropout_add_ln:
|
329 |
+
hidden_states = self.layer_norm(hidden_states)
|
330 |
+
else:
|
331 |
+
hidden_states = layer_norm_fn(
|
332 |
+
hidden_states,
|
333 |
+
self.layer_norm.weight,
|
334 |
+
self.layer_norm.bias,
|
335 |
+
eps=self.layer_norm.eps,
|
336 |
+
)
|
337 |
+
return hidden_states
|
338 |
+
|
339 |
+
|
340 |
+
class XLMRobertaLMPredictionHead(nn.Module):
|
341 |
+
def __init__(self, config):
|
342 |
+
super().__init__()
|
343 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
344 |
+
if fused_bias_fc and FusedDense is None:
|
345 |
+
raise ImportError("fused_dense is not installed")
|
346 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
347 |
+
|
348 |
+
self.transform = XLMRobertaPredictionHeadTransform(config)
|
349 |
+
|
350 |
+
# The output weights are the same as the input embeddings, but there is
|
351 |
+
# an output-only bias for each token.
|
352 |
+
self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)
|
353 |
+
|
354 |
+
def forward(self, hidden_states):
|
355 |
+
hidden_states = self.transform(hidden_states)
|
356 |
+
hidden_states = self.decoder(hidden_states)
|
357 |
+
return hidden_states
|
358 |
+
|
359 |
+
|
360 |
+
class XLMRobertaPreTrainingHeads(nn.Module):
|
361 |
+
def __init__(self, config):
|
362 |
+
super().__init__()
|
363 |
+
self.predictions = XLMRobertaLMPredictionHead(config)
|
364 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
365 |
+
|
366 |
+
def forward(self, sequence_output, pooled_output):
|
367 |
+
prediction_scores = self.predictions(sequence_output)
|
368 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
369 |
+
return prediction_scores, seq_relationship_score
|
370 |
+
|
371 |
+
|
372 |
+
class XLMRobertaPreTrainedModel(PreTrainedModel):
|
373 |
+
"""An abstract class to handle weights initialization and
|
374 |
+
a simple interface for dowloading and loading pretrained models.
|
375 |
+
"""
|
376 |
+
|
377 |
+
config_class = XLMRobertaFlashConfig
|
378 |
+
base_model_prefix = "roberta"
|
379 |
+
supports_gradient_checkpointing = True
|
380 |
+
|
381 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
382 |
+
if isinstance(module, XLMRobertaEncoder):
|
383 |
+
module.gradient_checkpointing = value
|
384 |
+
|
385 |
+
@classmethod
|
386 |
+
def from_pretrained(
|
387 |
+
cls,
|
388 |
+
*args,
|
389 |
+
**kwargs,
|
390 |
+
):
|
391 |
+
if not 'torch_dtype' in kwargs:
|
392 |
+
kwargs['torch_dtype'] = 'auto'
|
393 |
+
return super().from_pretrained(*args, **kwargs)
|
394 |
+
|
395 |
+
|
396 |
+
|
397 |
+
class XLMRobertaModel(XLMRobertaPreTrainedModel):
|
398 |
+
def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
|
399 |
+
super().__init__(config)
|
400 |
+
self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
|
401 |
+
if config.vocab_size % self.pad_vocab_size_multiple != 0:
|
402 |
+
config.vocab_size += self.pad_vocab_size_multiple - (
|
403 |
+
config.vocab_size % self.pad_vocab_size_multiple
|
404 |
+
)
|
405 |
+
self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
|
406 |
+
if self.fused_dropout_add_ln and layer_norm_fn is None:
|
407 |
+
raise ImportError("Triton is not installed")
|
408 |
+
assert config.hidden_act in [
|
409 |
+
"gelu",
|
410 |
+
"gelu_new",
|
411 |
+
"gelu_fast",
|
412 |
+
"gelu_pytorch_tanh",
|
413 |
+
]
|
414 |
+
|
415 |
+
self.embeddings = XLMRobertaEmbeddings(
|
416 |
+
config.hidden_size,
|
417 |
+
config.vocab_size,
|
418 |
+
config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
|
419 |
+
config.type_vocab_size,
|
420 |
+
padding_idx=config.pad_token_id,
|
421 |
+
)
|
422 |
+
self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
|
423 |
+
self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
424 |
+
self.encoder = XLMRobertaEncoder(config)
|
425 |
+
self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
|
426 |
+
|
427 |
+
self.apply(partial(_init_weights, initializer_range=config.initializer_range))
|
428 |
+
|
429 |
+
|
430 |
+
@torch.inference_mode()
|
431 |
+
def encode(
|
432 |
+
self: 'XLMRobertaModel',
|
433 |
+
sentences: Union[str, List[str]],
|
434 |
+
batch_size: int = 32,
|
435 |
+
show_progress_bar: Optional[bool] = None,
|
436 |
+
output_value: str = 'sentence_embedding',
|
437 |
+
convert_to_numpy: bool = True,
|
438 |
+
convert_to_tensor: bool = False,
|
439 |
+
device: Optional[torch.device] = None,
|
440 |
+
normalize_embeddings: bool = False,
|
441 |
+
truncate_dim: Optional[int] = None,
|
442 |
+
**tokenizer_kwargs,
|
443 |
+
) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
|
444 |
+
"""
|
445 |
+
Computes sentence embeddings
|
446 |
+
Args:
|
447 |
+
sentences(`str` or `List[str]`):
|
448 |
+
Sentence or sentences to be encoded
|
449 |
+
batch_size(`int`, *optional*, defaults to 32):
|
450 |
+
Batch size for the computation
|
451 |
+
show_progress_bar(`bool`, *optional*, defaults to None):
|
452 |
+
Show a progress bar when encoding sentences.
|
453 |
+
If set to None, progress bar is only shown when
|
454 |
+
`logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
|
455 |
+
output_value(`str`, *optional*, defaults to 'sentence_embedding'):
|
456 |
+
Default sentence_embedding, to get sentence embeddings.
|
457 |
+
Can be set to token_embeddings to get wordpiece token embeddings.
|
458 |
+
Set to None, to get all output values
|
459 |
+
convert_to_numpy(`bool`, *optional*, defaults to True):
|
460 |
+
If true, the output is a list of numpy vectors.
|
461 |
+
Else, it is a list of pytorch tensors.
|
462 |
+
convert_to_tensor(`bool`, *optional*, defaults to False):
|
463 |
+
If true, you get one large tensor as return.
|
464 |
+
Overwrites any setting from convert_to_numpy
|
465 |
+
device(`torch.device`, *optional*, defaults to None):
|
466 |
+
Which torch.device to use for the computation
|
467 |
+
normalize_embeddings(`bool`, *optional*, defaults to False):
|
468 |
+
If set to true, returned vectors will have length 1. In that case, the
|
469 |
+
faster dot-product (util.dot_score) instead of cosine similarity can
|
470 |
+
be used.
|
471 |
+
truncate_dim(`int`, *optional*, defaults to None):
|
472 |
+
The dimension to truncate sentence embeddings to. `None` does no truncation.
|
473 |
+
tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
|
474 |
+
Keyword arguments for the tokenizer
|
475 |
+
Returns:
|
476 |
+
By default, a list of tensors is returned.
|
477 |
+
If convert_to_tensor, a stacked tensor is returned.
|
478 |
+
If convert_to_numpy, a numpy matrix is returned.
|
479 |
+
"""
|
480 |
+
from transformers import AutoTokenizer
|
481 |
+
|
482 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
483 |
+
self.name_or_path, trust_remote_code=True
|
484 |
+
)
|
485 |
+
|
486 |
+
is_training = self.training
|
487 |
+
self.eval()
|
488 |
+
|
489 |
+
if show_progress_bar is None:
|
490 |
+
show_progress_bar = (
|
491 |
+
logger.getEffectiveLevel() == logging.INFO
|
492 |
+
or logger.getEffectiveLevel() == logging.DEBUG
|
493 |
+
)
|
494 |
+
|
495 |
+
if convert_to_tensor:
|
496 |
+
convert_to_numpy = False
|
497 |
+
|
498 |
+
if output_value != 'sentence_embedding':
|
499 |
+
convert_to_tensor = False
|
500 |
+
convert_to_numpy = False
|
501 |
+
|
502 |
+
input_was_string = False
|
503 |
+
if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
|
504 |
+
sentences = [sentences]
|
505 |
+
input_was_string = True
|
506 |
+
|
507 |
+
if device is not None:
|
508 |
+
self.to(device)
|
509 |
+
|
510 |
+
permutation = np.argsort([-len(i) for i in sentences])
|
511 |
+
inverse_permutation = np.argsort(permutation)
|
512 |
+
sentences = [sentences[idx] for idx in permutation]
|
513 |
+
|
514 |
+
tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
|
515 |
+
tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
|
516 |
+
'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
|
517 |
+
)
|
518 |
+
tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
|
519 |
+
|
520 |
+
all_embeddings = []
|
521 |
+
|
522 |
+
if trange is not None:
|
523 |
+
range_iter = trange(
|
524 |
+
0,
|
525 |
+
len(sentences),
|
526 |
+
batch_size,
|
527 |
+
desc="Encoding",
|
528 |
+
disable=not show_progress_bar,
|
529 |
+
)
|
530 |
+
else:
|
531 |
+
range_iter = range(0, len(sentences), batch_size)
|
532 |
+
|
533 |
+
for i in range_iter:
|
534 |
+
encoded_input = self.tokenizer(
|
535 |
+
sentences[i : i + batch_size],
|
536 |
+
return_tensors='pt',
|
537 |
+
**tokenizer_kwargs,
|
538 |
+
).to(self.device)
|
539 |
+
token_embs = self.forward(**encoded_input)[0]
|
540 |
+
|
541 |
+
# Accumulate in fp32 to avoid overflow
|
542 |
+
token_embs = token_embs.float()
|
543 |
+
|
544 |
+
if output_value == 'token_embeddings':
|
545 |
+
raise NotImplementedError
|
546 |
+
elif output_value is None:
|
547 |
+
raise NotImplementedError
|
548 |
+
else:
|
549 |
+
if self.config.emb_pooler == 'cls':
|
550 |
+
embeddings = self.cls_pooling(
|
551 |
+
token_embs, encoded_input['attention_mask']
|
552 |
+
)
|
553 |
+
else:
|
554 |
+
embeddings = self.mean_pooling(
|
555 |
+
token_embs, encoded_input['attention_mask']
|
556 |
+
)
|
557 |
+
|
558 |
+
if normalize_embeddings:
|
559 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
560 |
+
|
561 |
+
if convert_to_numpy:
|
562 |
+
embeddings = embeddings.cpu()
|
563 |
+
all_embeddings.extend(embeddings)
|
564 |
+
|
565 |
+
all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
|
566 |
+
|
567 |
+
truncate_dim = truncate_dim or self.config.truncate_dim
|
568 |
+
if truncate_dim:
|
569 |
+
all_embeddings = self.truncate_embeddings(all_embeddings, truncate_dim)
|
570 |
+
|
571 |
+
if convert_to_tensor:
|
572 |
+
all_embeddings = torch.stack(all_embeddings)
|
573 |
+
elif convert_to_numpy:
|
574 |
+
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
|
575 |
+
|
576 |
+
if input_was_string:
|
577 |
+
all_embeddings = all_embeddings[0]
|
578 |
+
|
579 |
+
self.train(is_training)
|
580 |
+
return all_embeddings
|
581 |
+
|
582 |
+
|
583 |
+
def truncate_embeddings(self, embeddings, truncate_dim):
|
584 |
+
if not self.config.matryoshka_dimensions:
|
585 |
+
logger.warning(
|
586 |
+
'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
|
587 |
+
)
|
588 |
+
return embeddings
|
589 |
+
elif truncate_dim in self.config.matryoshka_dimensions:
|
590 |
+
return [tensor[:truncate_dim] for tensor in embeddings]
|
591 |
+
else:
|
592 |
+
raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
|
593 |
+
f'Supported dimensions are {self.config.matryoshka_dimensions}.')
|
594 |
+
|
595 |
+
def mean_pooling(
|
596 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
597 |
+
):
|
598 |
+
input_mask_expanded = (
|
599 |
+
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
|
600 |
+
)
|
601 |
+
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
|
602 |
+
input_mask_expanded.sum(1), min=1e-9
|
603 |
+
)
|
604 |
+
|
605 |
+
|
606 |
+
def cls_pooling(
|
607 |
+
self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
|
608 |
+
):
|
609 |
+
return token_embeddings[:,0]
|
610 |
+
|
611 |
+
|
612 |
+
def forward(
|
613 |
+
self,
|
614 |
+
input_ids,
|
615 |
+
position_ids=None,
|
616 |
+
token_type_ids=None,
|
617 |
+
attention_mask=None,
|
618 |
+
masked_tokens_mask=None,
|
619 |
+
return_dict=None,
|
620 |
+
**kwargs,
|
621 |
+
):
|
622 |
+
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in XLMForPreTraining),
|
623 |
+
we only want the output for the masked tokens. This means that we only compute the last
|
624 |
+
layer output for these tokens.
|
625 |
+
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
|
626 |
+
"""
|
627 |
+
|
628 |
+
if kwargs:
|
629 |
+
for key, value in kwargs.items():
|
630 |
+
if value is not None:
|
631 |
+
logger.warning(
|
632 |
+
'Flash attention implementation does not support kwargs: %s',
|
633 |
+
key,
|
634 |
+
)
|
635 |
+
|
636 |
+
return_dict = (
|
637 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
638 |
+
)
|
639 |
+
|
640 |
+
hidden_states = self.embeddings(
|
641 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids
|
642 |
+
)
|
643 |
+
# TD [2022-12:18]: Don't need to force residual in fp32
|
644 |
+
# BERT puts embedding LayerNorm before embedding dropout.
|
645 |
+
if not self.fused_dropout_add_ln:
|
646 |
+
hidden_states = self.emb_ln(hidden_states)
|
647 |
+
else:
|
648 |
+
hidden_states = layer_norm_fn(
|
649 |
+
hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
|
650 |
+
)
|
651 |
+
hidden_states = self.emb_drop(hidden_states)
|
652 |
+
|
653 |
+
if masked_tokens_mask is not None:
|
654 |
+
batch_size, seqlen = input_ids.shape[:2]
|
655 |
+
# We also need the first column for the CLS token
|
656 |
+
first_col_mask = torch.zeros(
|
657 |
+
batch_size, seqlen, dtype=torch.bool, device=input_ids.device
|
658 |
+
)
|
659 |
+
first_col_mask[:, 0] = True
|
660 |
+
subset_mask = masked_tokens_mask | first_col_mask
|
661 |
+
else:
|
662 |
+
subset_mask = None
|
663 |
+
|
664 |
+
sequence_output = self.encoder(
|
665 |
+
hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask
|
666 |
+
)
|
667 |
+
|
668 |
+
if masked_tokens_mask is None:
|
669 |
+
pooled_output = (
|
670 |
+
self.pooler(sequence_output) if self.pooler is not None else None
|
671 |
+
)
|
672 |
+
else:
|
673 |
+
# TD [2022-03-01]: the indexing here is very tricky.
|
674 |
+
if attention_mask is not None:
|
675 |
+
subset_idx = subset_mask[attention_mask]
|
676 |
+
pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
|
677 |
+
sequence_output = sequence_output[
|
678 |
+
masked_tokens_mask[attention_mask][subset_idx]
|
679 |
+
]
|
680 |
+
else:
|
681 |
+
pool_input = sequence_output[first_col_mask[subset_mask]]
|
682 |
+
sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
|
683 |
+
pooled_output = (
|
684 |
+
self.pooler(pool_input, pool=False) if self.pooler is not None else None
|
685 |
+
)
|
686 |
+
|
687 |
+
if not return_dict:
|
688 |
+
return sequence_output, pooled_output
|
689 |
+
|
690 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
691 |
+
last_hidden_state=sequence_output,
|
692 |
+
pooler_output=pooled_output,
|
693 |
+
)
|
694 |
+
|
695 |
+
|
696 |
+
class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
|
697 |
+
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
|
698 |
+
|
699 |
+
def __init__(self, config):
|
700 |
+
super().__init__(config)
|
701 |
+
|
702 |
+
if config.is_decoder:
|
703 |
+
logger.warning(
|
704 |
+
"If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
|
705 |
+
"bi-directional self-attention."
|
706 |
+
)
|
707 |
+
|
708 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
709 |
+
self.lm_head = XLMRobertaLMHead(config)
|
710 |
+
|
711 |
+
# Initialize weights and apply final processing
|
712 |
+
self.post_init()
|
713 |
+
|
714 |
+
def get_input_embeddings(self):
|
715 |
+
return self.roberta.embeddings.word_embeddings
|
716 |
+
|
717 |
+
def get_output_embeddings(self):
|
718 |
+
return self.lm_head.decoder
|
719 |
+
|
720 |
+
def set_output_embeddings(self, new_embeddings):
|
721 |
+
self.lm_head.decoder = new_embeddings
|
722 |
+
|
723 |
+
def forward(
|
724 |
+
self,
|
725 |
+
input_ids: Optional[torch.LongTensor] = None,
|
726 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
727 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
728 |
+
position_ids: Optional[torch.LongTensor] = None,
|
729 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
730 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
731 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
732 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
733 |
+
labels: Optional[torch.LongTensor] = None,
|
734 |
+
output_attentions: Optional[bool] = None,
|
735 |
+
output_hidden_states: Optional[bool] = None,
|
736 |
+
return_dict: Optional[bool] = None,
|
737 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
738 |
+
r"""
|
739 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
740 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
741 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
742 |
+
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
743 |
+
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
|
744 |
+
Used to hide legacy arguments that have been deprecated.
|
745 |
+
"""
|
746 |
+
return_dict = (
|
747 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
748 |
+
)
|
749 |
+
|
750 |
+
outputs = self.roberta(
|
751 |
+
input_ids,
|
752 |
+
attention_mask=attention_mask,
|
753 |
+
token_type_ids=token_type_ids,
|
754 |
+
position_ids=position_ids,
|
755 |
+
head_mask=head_mask,
|
756 |
+
inputs_embeds=inputs_embeds,
|
757 |
+
encoder_hidden_states=encoder_hidden_states,
|
758 |
+
encoder_attention_mask=encoder_attention_mask,
|
759 |
+
output_attentions=output_attentions,
|
760 |
+
output_hidden_states=output_hidden_states,
|
761 |
+
return_dict=return_dict,
|
762 |
+
)
|
763 |
+
sequence_output = outputs[0]
|
764 |
+
prediction_scores = self.lm_head(sequence_output)
|
765 |
+
|
766 |
+
masked_lm_loss = None
|
767 |
+
if labels is not None:
|
768 |
+
# move labels to correct device to enable model parallelism
|
769 |
+
labels = labels.to(prediction_scores.device)
|
770 |
+
loss_fct = CrossEntropyLoss()
|
771 |
+
masked_lm_loss = loss_fct(
|
772 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
|
773 |
+
)
|
774 |
+
|
775 |
+
if not return_dict:
|
776 |
+
output = (prediction_scores,) + outputs[2:]
|
777 |
+
return (
|
778 |
+
((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
779 |
+
)
|
780 |
+
|
781 |
+
return MaskedLMOutput(
|
782 |
+
loss=masked_lm_loss,
|
783 |
+
logits=prediction_scores,
|
784 |
+
hidden_states=outputs.hidden_states,
|
785 |
+
attentions=outputs.attentions,
|
786 |
+
)
|
787 |
+
|
788 |
+
|
789 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta
|
790 |
+
class XLMRobertaClassificationHead(nn.Module):
|
791 |
+
"""Head for sentence-level classification tasks."""
|
792 |
+
|
793 |
+
def __init__(self, config):
|
794 |
+
super().__init__()
|
795 |
+
fused_bias_fc = getattr(config, "fused_bias_fc", False)
|
796 |
+
if fused_bias_fc and FusedDense is None:
|
797 |
+
raise ImportError("fused_dense is not installed")
|
798 |
+
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
|
799 |
+
self.dense = linear_cls(config.hidden_size, config.hidden_size)
|
800 |
+
classifier_dropout = (
|
801 |
+
config.classifier_dropout
|
802 |
+
if config.classifier_dropout is not None
|
803 |
+
else config.hidden_dropout_prob
|
804 |
+
)
|
805 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
806 |
+
self.out_proj = linear_cls(config.hidden_size, config.num_labels)
|
807 |
+
|
808 |
+
def forward(self, features, **kwargs):
|
809 |
+
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
|
810 |
+
x = self.dropout(x)
|
811 |
+
x = self.dense(x)
|
812 |
+
x = torch.tanh(x)
|
813 |
+
x = self.dropout(x)
|
814 |
+
x = self.out_proj(x)
|
815 |
+
return x
|
816 |
+
|
817 |
+
|
818 |
+
# Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA
|
819 |
+
class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
|
820 |
+
def __init__(self, config):
|
821 |
+
super().__init__(config)
|
822 |
+
self.num_labels = config.num_labels
|
823 |
+
self.config = config
|
824 |
+
|
825 |
+
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
|
826 |
+
self.classifier = XLMRobertaClassificationHead(config)
|
827 |
+
|
828 |
+
# Initialize weights and apply final processing
|
829 |
+
self.post_init()
|
830 |
+
|
831 |
+
def forward(
|
832 |
+
self,
|
833 |
+
input_ids: Optional[torch.LongTensor] = None,
|
834 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
835 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
836 |
+
position_ids: Optional[torch.LongTensor] = None,
|
837 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
838 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
839 |
+
labels: Optional[torch.LongTensor] = None,
|
840 |
+
output_attentions: Optional[bool] = None,
|
841 |
+
output_hidden_states: Optional[bool] = None,
|
842 |
+
return_dict: Optional[bool] = None,
|
843 |
+
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
|
844 |
+
r"""
|
845 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
846 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
847 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
848 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
849 |
+
"""
|
850 |
+
return_dict = (
|
851 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
852 |
+
)
|
853 |
+
|
854 |
+
outputs = self.roberta(
|
855 |
+
input_ids,
|
856 |
+
attention_mask=attention_mask,
|
857 |
+
token_type_ids=token_type_ids,
|
858 |
+
position_ids=position_ids,
|
859 |
+
head_mask=head_mask,
|
860 |
+
inputs_embeds=inputs_embeds,
|
861 |
+
output_attentions=output_attentions,
|
862 |
+
output_hidden_states=output_hidden_states,
|
863 |
+
return_dict=return_dict,
|
864 |
+
)
|
865 |
+
sequence_output = outputs[0]
|
866 |
+
logits = self.classifier(sequence_output)
|
867 |
+
|
868 |
+
loss = None
|
869 |
+
if labels is not None:
|
870 |
+
# move labels to correct device to enable model parallelism
|
871 |
+
labels = labels.to(logits.device)
|
872 |
+
if self.config.problem_type is None:
|
873 |
+
if self.num_labels == 1:
|
874 |
+
self.config.problem_type = "regression"
|
875 |
+
elif self.num_labels > 1 and (
|
876 |
+
labels.dtype == torch.long or labels.dtype == torch.int
|
877 |
+
):
|
878 |
+
self.config.problem_type = "single_label_classification"
|
879 |
+
else:
|
880 |
+
self.config.problem_type = "multi_label_classification"
|
881 |
+
|
882 |
+
if self.config.problem_type == "regression":
|
883 |
+
loss_fct = MSELoss()
|
884 |
+
if self.num_labels == 1:
|
885 |
+
loss = loss_fct(logits.squeeze(), labels.squeeze())
|
886 |
+
else:
|
887 |
+
loss = loss_fct(logits, labels)
|
888 |
+
elif self.config.problem_type == "single_label_classification":
|
889 |
+
loss_fct = CrossEntropyLoss()
|
890 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
891 |
+
elif self.config.problem_type == "multi_label_classification":
|
892 |
+
loss_fct = BCEWithLogitsLoss()
|
893 |
+
loss = loss_fct(logits, labels)
|
894 |
+
|
895 |
+
if not return_dict:
|
896 |
+
output = (logits,) + outputs[2:]
|
897 |
+
return ((loss,) + output) if loss is not None else output
|
898 |
+
|
899 |
+
return SequenceClassifierOutput(
|
900 |
+
loss=loss,
|
901 |
+
logits=logits,
|
902 |
+
hidden_states=outputs.hidden_states,
|
903 |
+
attentions=outputs.attentions,
|
904 |
+
)
|
905 |
+
|
906 |
+
|
907 |
+
@torch.inference_mode()
|
908 |
+
def compute_score(
|
909 |
+
self,
|
910 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
911 |
+
batch_size: int = 32,
|
912 |
+
max_length: Optional[int] = None,
|
913 |
+
) -> List[float]:
|
914 |
+
|
915 |
+
if not hasattr(self, "_tokenizer"):
|
916 |
+
from transformers import AutoTokenizer
|
917 |
+
|
918 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
919 |
+
self.name_or_path, trust_remote_code=True
|
920 |
+
)
|
921 |
+
|
922 |
+
assert isinstance(sentence_pairs, list)
|
923 |
+
if isinstance(sentence_pairs[0], str):
|
924 |
+
sentence_pairs = [sentence_pairs]
|
925 |
+
|
926 |
+
all_scores = []
|
927 |
+
for start_index in range(
|
928 |
+
0, len(sentence_pairs), batch_size
|
929 |
+
):
|
930 |
+
sentences_batch = sentence_pairs[
|
931 |
+
start_index : start_index + batch_size
|
932 |
+
]
|
933 |
+
inputs = self._tokenizer(
|
934 |
+
sentences_batch,
|
935 |
+
padding=True,
|
936 |
+
truncation=True,
|
937 |
+
return_tensors='pt',
|
938 |
+
max_length=max_length,
|
939 |
+
).to(self.device)
|
940 |
+
scores = (
|
941 |
+
self.forward(**inputs, return_dict=True)
|
942 |
+
.logits.view(
|
943 |
+
-1,
|
944 |
+
)
|
945 |
+
.float()
|
946 |
+
)
|
947 |
+
scores = torch.sigmoid(scores)
|
948 |
+
all_scores.extend(scores.cpu().numpy().tolist())
|
949 |
+
|
950 |
+
if len(all_scores) == 1:
|
951 |
+
return all_scores[0]
|
952 |
+
return all_scores
|
953 |
+
|
954 |
+
def predict(
|
955 |
+
self,
|
956 |
+
sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]],
|
957 |
+
batch_size: int = 32,
|
958 |
+
max_length: Optional[int] = None,
|
959 |
+
) -> List[float]:
|
960 |
+
# used for beir evaluation
|
961 |
+
return self.compute_score(sentence_pairs, batch_size=batch_size, max_length=max_length)
|
962 |
+
|
963 |
+
def rerank(
|
964 |
+
self,
|
965 |
+
query: str,
|
966 |
+
documents: List[str],
|
967 |
+
batch_size: int = 32,
|
968 |
+
max_length: int = 1024,
|
969 |
+
max_query_length: int = 512,
|
970 |
+
overlap_tokens: int = 80,
|
971 |
+
top_n: Optional[int] = None,
|
972 |
+
**kwargs,
|
973 |
+
):
|
974 |
+
assert max_length >= max_query_length * 2, (
|
975 |
+
f'max_length ({max_length}) must be greater than or equal to '
|
976 |
+
f'max_query_length ({max_query_length}) * 2'
|
977 |
+
)
|
978 |
+
|
979 |
+
if not hasattr(self, "_tokenizer"):
|
980 |
+
from transformers import AutoTokenizer
|
981 |
+
|
982 |
+
self._tokenizer = AutoTokenizer.from_pretrained(
|
983 |
+
self.name_or_path, trust_remote_code=True
|
984 |
+
)
|
985 |
+
|
986 |
+
# preproc of tokenization
|
987 |
+
sentence_pairs, sentence_pairs_pids = reranker_tokenize_preproc(
|
988 |
+
query,
|
989 |
+
documents,
|
990 |
+
tokenizer=self._tokenizer,
|
991 |
+
max_length=max_length,
|
992 |
+
max_query_length=max_query_length,
|
993 |
+
overlap_tokens=overlap_tokens,
|
994 |
+
)
|
995 |
+
|
996 |
+
tot_scores = []
|
997 |
+
with torch.no_grad():
|
998 |
+
for k in range(0, len(sentence_pairs), batch_size):
|
999 |
+
batch = self._tokenizer.pad(
|
1000 |
+
sentence_pairs[k : k + batch_size],
|
1001 |
+
padding=True,
|
1002 |
+
max_length=max_length,
|
1003 |
+
pad_to_multiple_of=None,
|
1004 |
+
return_tensors="pt",
|
1005 |
+
)
|
1006 |
+
batch_on_device = {k: v.to(self.device) for k, v in batch.items()}
|
1007 |
+
scores = (
|
1008 |
+
self.forward(**batch_on_device, return_dict=True)
|
1009 |
+
.logits.view(
|
1010 |
+
-1,
|
1011 |
+
)
|
1012 |
+
.float()
|
1013 |
+
)
|
1014 |
+
scores = torch.sigmoid(scores)
|
1015 |
+
tot_scores.extend(scores.cpu().numpy().tolist())
|
1016 |
+
|
1017 |
+
# ranking
|
1018 |
+
merge_scores = [0 for _ in range(len(documents))]
|
1019 |
+
for pid, score in zip(sentence_pairs_pids, tot_scores):
|
1020 |
+
merge_scores[pid] = max(merge_scores[pid], score)
|
1021 |
+
|
1022 |
+
merge_scores_argsort = np.argsort(merge_scores)[::-1]
|
1023 |
+
sorted_documents = []
|
1024 |
+
sorted_scores = []
|
1025 |
+
for mid in merge_scores_argsort:
|
1026 |
+
sorted_scores.append(merge_scores[mid])
|
1027 |
+
sorted_documents.append(documents[mid])
|
1028 |
+
|
1029 |
+
top_n = min(top_n or len(sorted_documents), len(sorted_documents))
|
1030 |
+
|
1031 |
+
return [
|
1032 |
+
{
|
1033 |
+
'document': sorted_documents[i],
|
1034 |
+
'relevance_score': sorted_scores[i],
|
1035 |
+
'index': merge_scores_argsort[i],
|
1036 |
+
}
|
1037 |
+
for i in range(top_n)
|
1038 |
+
]
|
1039 |
+
|
1040 |
+
|
1041 |
+
def reranker_tokenize_preproc(
|
1042 |
+
query: str,
|
1043 |
+
passages: List[str],
|
1044 |
+
tokenizer=None,
|
1045 |
+
max_length: int = 1024,
|
1046 |
+
max_query_length: int = 512,
|
1047 |
+
overlap_tokens: int = 80,
|
1048 |
+
):
|
1049 |
+
from copy import deepcopy
|
1050 |
+
|
1051 |
+
assert tokenizer is not None, "Please provide a valid tokenizer for tokenization!"
|
1052 |
+
sep_id = tokenizer.sep_token_id
|
1053 |
+
|
1054 |
+
def _merge_inputs(chunk1_raw, chunk2):
|
1055 |
+
chunk1 = deepcopy(chunk1_raw)
|
1056 |
+
chunk1['input_ids'].append(sep_id)
|
1057 |
+
chunk1['input_ids'].extend(chunk2['input_ids'])
|
1058 |
+
chunk1['input_ids'].append(sep_id)
|
1059 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][0])
|
1060 |
+
chunk1['attention_mask'].extend(chunk2['attention_mask'])
|
1061 |
+
chunk1['attention_mask'].append(chunk2['attention_mask'][-1])
|
1062 |
+
if 'token_type_ids' in chunk1:
|
1063 |
+
token_type_ids = [1 for _ in range(len(chunk2['token_type_ids']) + 2)]
|
1064 |
+
chunk1['token_type_ids'].extend(token_type_ids)
|
1065 |
+
return chunk1
|
1066 |
+
|
1067 |
+
# Note: the long query will be truncated to 256 tokens by default
|
1068 |
+
query_inputs = tokenizer.encode_plus(
|
1069 |
+
query, truncation=True, padding=False, max_length=max_query_length
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
max_passage_inputs_length = max_length - len(query_inputs['input_ids']) - 2
|
1073 |
+
# assert (
|
1074 |
+
# max_passage_inputs_length > 100
|
1075 |
+
# ), "Your query is too long! Please make sure your query less than 500 tokens!"
|
1076 |
+
|
1077 |
+
overlap_tokens_implt = min(overlap_tokens, max_passage_inputs_length // 4)
|
1078 |
+
|
1079 |
+
res_merge_inputs = []
|
1080 |
+
res_merge_inputs_pids = []
|
1081 |
+
for pid, passage in enumerate(passages):
|
1082 |
+
passage_inputs = tokenizer.encode_plus(
|
1083 |
+
passage,
|
1084 |
+
truncation=False,
|
1085 |
+
padding=False,
|
1086 |
+
add_special_tokens=False,
|
1087 |
+
max_length=0,
|
1088 |
+
)
|
1089 |
+
passage_inputs_length = len(passage_inputs['input_ids'])
|
1090 |
+
|
1091 |
+
if passage_inputs_length <= max_passage_inputs_length:
|
1092 |
+
qp_merge_inputs = _merge_inputs(query_inputs, passage_inputs)
|
1093 |
+
res_merge_inputs.append(qp_merge_inputs)
|
1094 |
+
res_merge_inputs_pids.append(pid)
|
1095 |
+
else:
|
1096 |
+
start_id = 0
|
1097 |
+
while start_id < passage_inputs_length:
|
1098 |
+
end_id = start_id + max_passage_inputs_length
|
1099 |
+
# make sure the length of the last chunk is `max_passage_inputs_length`
|
1100 |
+
if end_id >= passage_inputs_length:
|
1101 |
+
sub_passage_inputs = {
|
1102 |
+
k: v[-max_passage_inputs_length:]
|
1103 |
+
for k, v in passage_inputs.items()
|
1104 |
+
}
|
1105 |
+
else:
|
1106 |
+
sub_passage_inputs = {
|
1107 |
+
k: v[start_id:end_id] for k, v in passage_inputs.items()
|
1108 |
+
}
|
1109 |
+
start_id = (
|
1110 |
+
end_id - overlap_tokens_implt
|
1111 |
+
if end_id < passage_inputs_length
|
1112 |
+
else end_id
|
1113 |
+
)
|
1114 |
+
|
1115 |
+
qp_merge_inputs = _merge_inputs(query_inputs, sub_passage_inputs)
|
1116 |
+
res_merge_inputs.append(qp_merge_inputs)
|
1117 |
+
res_merge_inputs_pids.append(pid)
|
1118 |
+
|
1119 |
+
return res_merge_inputs, res_merge_inputs_pids
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:318b11c3ce6d8d34e5034d001166a857934c0811c4fc5fb4a40328477ccaaaf9
|
3 |
+
size 561622266
|
special_tokens_map.json
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"cls_token": {
|
10 |
+
"content": "<s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"eos_token": {
|
17 |
+
"content": "</s>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"mask_token": {
|
24 |
+
"content": "<mask>",
|
25 |
+
"lstrip": true,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
},
|
30 |
+
"pad_token": {
|
31 |
+
"content": "<pad>",
|
32 |
+
"lstrip": false,
|
33 |
+
"normalized": false,
|
34 |
+
"rstrip": false,
|
35 |
+
"single_word": false
|
36 |
+
},
|
37 |
+
"sep_token": {
|
38 |
+
"content": "</s>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false
|
43 |
+
},
|
44 |
+
"unk_token": {
|
45 |
+
"content": "<unk>",
|
46 |
+
"lstrip": false,
|
47 |
+
"normalized": false,
|
48 |
+
"rstrip": false,
|
49 |
+
"single_word": false
|
50 |
+
}
|
51 |
+
}
|
stochastic_depth.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Implementation modified from torchvision:
|
2 |
+
# https://github.com/pytorch/vision/blob/main/torchvision/ops/stochastic_depth.py
|
3 |
+
#
|
4 |
+
# License:
|
5 |
+
# BSD 3-Clause License
|
6 |
+
#
|
7 |
+
# Copyright (c) Soumith Chintala 2016,
|
8 |
+
# All rights reserved.
|
9 |
+
#
|
10 |
+
# Redistribution and use in source and binary forms, with or without
|
11 |
+
# modification, are permitted provided that the following conditions are met:
|
12 |
+
#
|
13 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
14 |
+
# list of conditions and the following disclaimer.
|
15 |
+
#
|
16 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
17 |
+
# this list of conditions and the following disclaimer in the documentation
|
18 |
+
# and/or other materials provided with the distribution.
|
19 |
+
#
|
20 |
+
# * Neither the name of the copyright holder nor the names of its
|
21 |
+
# contributors may be used to endorse or promote products derived from
|
22 |
+
# this software without specific prior written permission.
|
23 |
+
#
|
24 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
25 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
26 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
27 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
28 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
29 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
30 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
31 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
32 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
33 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
34 |
+
|
35 |
+
import torch
|
36 |
+
import torch.fx
|
37 |
+
from torch import nn, Tensor
|
38 |
+
|
39 |
+
|
40 |
+
def stochastic_depth(
|
41 |
+
input: Tensor, p: float, mode: str, training: bool = True
|
42 |
+
) -> Tensor:
|
43 |
+
"""
|
44 |
+
Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
|
45 |
+
<https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
|
46 |
+
branches of residual architectures.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
|
50 |
+
being its batch i.e. a batch with ``N`` rows.
|
51 |
+
p (float): probability of the input to be zeroed.
|
52 |
+
mode (str): ``"batch"`` or ``"row"``.
|
53 |
+
``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
|
54 |
+
randomly selected rows from the batch.
|
55 |
+
training: apply stochastic depth if is ``True``. Default: ``True``
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
Tensor[N, ...]: The randomly zeroed tensor.
|
59 |
+
"""
|
60 |
+
if p < 0.0 or p > 1.0:
|
61 |
+
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
|
62 |
+
if mode not in ["batch", "row"]:
|
63 |
+
raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
|
64 |
+
if not training or p == 0.0:
|
65 |
+
return input
|
66 |
+
|
67 |
+
survival_rate = 1.0 - p
|
68 |
+
if mode == "row":
|
69 |
+
size = [input.shape[0]] + [1] * (input.ndim - 1)
|
70 |
+
else:
|
71 |
+
size = [1] * input.ndim
|
72 |
+
noise = torch.empty(size, dtype=input.dtype, device=input.device)
|
73 |
+
noise = noise.bernoulli_(survival_rate)
|
74 |
+
if survival_rate > 0.0:
|
75 |
+
noise.div_(survival_rate)
|
76 |
+
return input * noise
|
77 |
+
|
78 |
+
|
79 |
+
torch.fx.wrap("stochastic_depth")
|
80 |
+
|
81 |
+
|
82 |
+
class StochasticDepth(nn.Module):
|
83 |
+
"""
|
84 |
+
See :func:`stochastic_depth`.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(self, p: float, mode: str) -> None:
|
88 |
+
super().__init__()
|
89 |
+
self.p = p
|
90 |
+
self.mode = mode
|
91 |
+
|
92 |
+
def forward(self, input: Tensor) -> Tensor:
|
93 |
+
return stochastic_depth(input, self.p, self.mode, self.training)
|
94 |
+
|
95 |
+
def __repr__(self) -> str:
|
96 |
+
s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
|
97 |
+
return s
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3a56def25aa40facc030ea8b0b87f3688e4b3c39eb8b45d5702b3a1300fe2a20
|
3 |
+
size 17082734
|
tokenizer_config.json
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<s>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<pad>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "</s>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
},
|
27 |
+
"3": {
|
28 |
+
"content": "<unk>",
|
29 |
+
"lstrip": false,
|
30 |
+
"normalized": false,
|
31 |
+
"rstrip": false,
|
32 |
+
"single_word": false,
|
33 |
+
"special": true
|
34 |
+
},
|
35 |
+
"250001": {
|
36 |
+
"content": "<mask>",
|
37 |
+
"lstrip": true,
|
38 |
+
"normalized": false,
|
39 |
+
"rstrip": false,
|
40 |
+
"single_word": false,
|
41 |
+
"special": true
|
42 |
+
}
|
43 |
+
},
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"clean_up_tokenization_spaces": true,
|
46 |
+
"cls_token": "<s>",
|
47 |
+
"eos_token": "</s>",
|
48 |
+
"mask_token": "<mask>",
|
49 |
+
"model_max_length": 1024,
|
50 |
+
"pad_token": "<pad>",
|
51 |
+
"sep_token": "</s>",
|
52 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
53 |
+
"unk_token": "<unk>"
|
54 |
+
}
|
xlm_padding.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/block.py
|
2 |
+
# Commit id: c94cd09744d20f0ac587a351ff6ff2e8ad11ae1b
|
3 |
+
|
4 |
+
# Previously adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from einops import rearrange, repeat
|
9 |
+
|
10 |
+
|
11 |
+
class IndexFirstAxis(torch.autograd.Function):
|
12 |
+
@staticmethod
|
13 |
+
def forward(ctx, input, indices):
|
14 |
+
ctx.save_for_backward(indices)
|
15 |
+
assert input.ndim >= 2
|
16 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
17 |
+
second_dim = other_shape.numel()
|
18 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
19 |
+
# return input[indices]
|
20 |
+
return torch.gather(
|
21 |
+
rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
|
22 |
+
).reshape(-1, *other_shape)
|
23 |
+
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, grad_output):
|
26 |
+
(indices,) = ctx.saved_tensors
|
27 |
+
assert grad_output.ndim >= 2
|
28 |
+
other_shape = grad_output.shape[1:]
|
29 |
+
grad_output = rearrange(grad_output, "b ... -> b (...)")
|
30 |
+
grad_input = torch.zeros(
|
31 |
+
[ctx.first_axis_dim, grad_output.shape[1]],
|
32 |
+
device=grad_output.device,
|
33 |
+
dtype=grad_output.dtype,
|
34 |
+
)
|
35 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
36 |
+
# grad_input[indices] = grad_output
|
37 |
+
grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
|
38 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
39 |
+
|
40 |
+
|
41 |
+
index_first_axis = IndexFirstAxis.apply
|
42 |
+
|
43 |
+
|
44 |
+
class IndexPutFirstAxis(torch.autograd.Function):
|
45 |
+
@staticmethod
|
46 |
+
def forward(ctx, values, indices, first_axis_dim):
|
47 |
+
ctx.save_for_backward(indices)
|
48 |
+
assert indices.ndim == 1
|
49 |
+
assert values.ndim >= 2
|
50 |
+
output = torch.zeros(
|
51 |
+
first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
|
52 |
+
)
|
53 |
+
# TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
|
54 |
+
output[indices] = values
|
55 |
+
# output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
|
56 |
+
return output
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def backward(ctx, grad_output):
|
60 |
+
(indices,) = ctx.saved_tensors
|
61 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
62 |
+
grad_values = grad_output[indices]
|
63 |
+
# grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
|
64 |
+
return grad_values, None, None
|
65 |
+
|
66 |
+
|
67 |
+
index_put_first_axis = IndexPutFirstAxis.apply
|
68 |
+
|
69 |
+
|
70 |
+
class IndexFirstAxisResidual(torch.autograd.Function):
|
71 |
+
@staticmethod
|
72 |
+
def forward(ctx, input, indices):
|
73 |
+
ctx.save_for_backward(indices)
|
74 |
+
assert input.ndim >= 2
|
75 |
+
ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
|
76 |
+
second_dim = other_shape.numel()
|
77 |
+
# TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
|
78 |
+
output = input[indices]
|
79 |
+
# We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
|
80 |
+
# memory format to channel_first. In other words, input might not be contiguous.
|
81 |
+
# If we don't detach, Pytorch complains about output being a view and is being modified inplace
|
82 |
+
return output, input.detach()
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def backward(ctx, grad_output, grad_residual):
|
86 |
+
(indices,) = ctx.saved_tensors
|
87 |
+
assert grad_output.ndim >= 2
|
88 |
+
other_shape = grad_output.shape[1:]
|
89 |
+
assert grad_residual.shape[1:] == other_shape
|
90 |
+
grad_input = grad_residual
|
91 |
+
# grad_input[indices] += grad_output
|
92 |
+
indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
|
93 |
+
indices = indices.expand_as(grad_output)
|
94 |
+
grad_input.scatter_add_(0, indices, grad_output)
|
95 |
+
return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
|
96 |
+
|
97 |
+
|
98 |
+
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
+
|
100 |
+
|
101 |
+
def unpad_input(hidden_states, attention_mask):
|
102 |
+
"""
|
103 |
+
Arguments:
|
104 |
+
hidden_states: (batch, seqlen, ...)
|
105 |
+
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
|
106 |
+
Return:
|
107 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
108 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
109 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
110 |
+
max_seqlen_in_batch: int
|
111 |
+
"""
|
112 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
113 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
114 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
116 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
117 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
118 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
119 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
120 |
+
# so we write custom forward and backward to make it a bit faster.
|
121 |
+
return (
|
122 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
123 |
+
indices,
|
124 |
+
cu_seqlens,
|
125 |
+
max_seqlen_in_batch,
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
|
130 |
+
"""
|
131 |
+
Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
|
132 |
+
The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
|
133 |
+
|
134 |
+
For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
|
135 |
+
```
|
136 |
+
[
|
137 |
+
[2, 3, 0, 0, 0, 0],
|
138 |
+
[3, 2, 0, 0, 0, 0],
|
139 |
+
[6, 0, 0, 0, 0, 0]
|
140 |
+
]
|
141 |
+
```
|
142 |
+
, which refers to the 3D-attention mask:
|
143 |
+
```
|
144 |
+
[
|
145 |
+
[
|
146 |
+
[1, 0, 0, 0, 0, 0],
|
147 |
+
[1, 1, 0, 0, 0, 0],
|
148 |
+
[0, 0, 1, 0, 0, 0],
|
149 |
+
[0, 0, 1, 1, 0, 0],
|
150 |
+
[0, 0, 1, 1, 1, 0],
|
151 |
+
[0, 0, 0, 0, 0, 1]
|
152 |
+
],
|
153 |
+
[
|
154 |
+
[1, 0, 0, 0, 0, 0],
|
155 |
+
[1, 1, 0, 0, 0, 0],
|
156 |
+
[1, 1, 1, 0, 0, 0],
|
157 |
+
[0, 0, 0, 1, 0, 0],
|
158 |
+
[0, 0, 0, 1, 1, 0],
|
159 |
+
[0, 0, 0, 0, 0, 1]
|
160 |
+
],
|
161 |
+
[
|
162 |
+
[1, 0, 0, 0, 0, 0],
|
163 |
+
[1, 1, 0, 0, 0, 0],
|
164 |
+
[1, 1, 1, 0, 0, 0],
|
165 |
+
[1, 1, 1, 1, 0, 0],
|
166 |
+
[1, 1, 1, 1, 1, 0],
|
167 |
+
[1, 1, 1, 1, 1, 1]
|
168 |
+
]
|
169 |
+
]
|
170 |
+
```.
|
171 |
+
|
172 |
+
Arguments:
|
173 |
+
hidden_states: (batch, seqlen, ...)
|
174 |
+
attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
|
175 |
+
Return:
|
176 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
177 |
+
indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
|
178 |
+
cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
|
179 |
+
max_seqlen_in_batch: int
|
180 |
+
"""
|
181 |
+
length = attention_mask_in_length.sum(dim=-1)
|
182 |
+
seqlen = attention_mask_in_length.size(-1)
|
183 |
+
attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
|
184 |
+
seqlen) < length.unsqueeze(
|
185 |
+
1)
|
186 |
+
real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
|
187 |
+
seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
|
188 |
+
indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
|
189 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
190 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
191 |
+
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
192 |
+
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
193 |
+
# times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
|
194 |
+
# index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
|
195 |
+
# so we write custom forward and backward to make it a bit faster.
|
196 |
+
return (
|
197 |
+
index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
|
198 |
+
indices,
|
199 |
+
cu_seqlens,
|
200 |
+
max_seqlen_in_batch,
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
def pad_input(hidden_states, indices, batch, seqlen):
|
205 |
+
"""
|
206 |
+
Arguments:
|
207 |
+
hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
|
208 |
+
indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
|
209 |
+
batch: int, batch size for the padded sequence.
|
210 |
+
seqlen: int, maximum sequence length for the padded sequence.
|
211 |
+
Return:
|
212 |
+
hidden_states: (batch, seqlen, ...)
|
213 |
+
"""
|
214 |
+
dim = hidden_states.shape[-1]
|
215 |
+
# output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
|
216 |
+
# output[indices] = hidden_states
|
217 |
+
output = index_put_first_axis(hidden_states, indices, batch * seqlen)
|
218 |
+
return rearrange(output, "(b s) ... -> b s ...", b=batch)
|