lv12 committed on
Commit
375ea9e
1 Parent(s): 2fce630

full set E-I triplets

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 768,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": true,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
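The config above enables only `pooling_mode_mean_tokens`, so each sentence embedding is the average of its token embeddings. The following is a minimal pure-Python sketch of that idea, not the library's implementation; the function name `mean_pool` and the toy 2-dimensional vectors are assumptions made for illustration (the real model uses 768 dimensions).

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose mask entry is 1, ignoring padding positions."""
    dim = len(token_embeddings[0])
    total = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            count += 1
            for i, value in enumerate(vector):
                total[i] += value
    count = max(count, 1)  # avoid division by zero for an all-padding input
    return [component / count for component in total]


# Toy input: two real tokens followed by one padding token (hypothetical values).
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
pooled = mean_pool(tokens, mask)
print(pooled)
```

The padding vector is excluded by the mask, so it does not distort the average.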
README.md ADDED
@@ -0,0 +1,877 @@
+ ---
+ language: []
+ library_name: sentence-transformers
+ tags:
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - dataset_size:100K<n<1M
+ - loss:CachedMultipleNegativesRankingLoss
+ base_model: nomic-ai/nomic-embed-text-v1.5
+ metrics:
+ - cosine_accuracy
+ - dot_accuracy
+ - manhattan_accuracy
+ - euclidean_accuracy
+ - max_accuracy
+ widget:
+ - source_sentence: 'search_query: adorime'
+   sentences:
+   - 'search_query: green air scents llc'
+   - 'search_query: dpms sbr accessories'
+   - 'search_query: sweaters cowl neck men'
+ - source_sentence: 'search_query: serving'
+   sentences:
+   - 'search_query: ceramic cups without handles'
+   - 'search_query: 100 mm cigarette case'
+   - 'search_query: toddler girl leopard midi'
+ - source_sentence: 'search_query: haierc'
+   sentences:
+   - 'search_query: homder'
+   - 'search_query: 3d milling metal cnc'
+   - 'search_query: sandals for women'
+ - source_sentence: 'search_query: poppies'
+   sentences:
+   - 'search_query: fake plants without pot'
+   - 'search_query: tonsil stone remover'
+   - 'search_query: vestido corto sexy de mujer'
+ - source_sentence: 'search_query: dab rig'
+   sentences:
+   - 'search_query: volcano weed vaporizer'
+   - 'search_query: 22 gold chain for men'
+   - 'search_query: apple watch screen protector'
+ pipeline_tag: sentence-similarity
+ model-index:
+ - name: SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
+   results:
+   - task:
+       type: triplet
+       name: Triplet
+     dataset:
+       name: triplet esci
+       type: triplet-esci
+     metrics:
+     - type: cosine_accuracy
+       value: 0.7405
+       name: Cosine Accuracy
+     - type: dot_accuracy
+       value: 0.269
+       name: Dot Accuracy
+     - type: manhattan_accuracy
+       value: 0.7432
+       name: Manhattan Accuracy
+     - type: euclidean_accuracy
+       value: 0.7457
+       name: Euclidean Accuracy
+     - type: max_accuracy
+       value: 0.7457
+       name: Max Accuracy
+ ---
70
+
71
+ # SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
72
+
73
+ This is a [sentence-transformers](https://www.SBERT.net) model finetuned from [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5). It maps sentences & paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
74
+
75
+ ## Model Details
76
+
77
+ ### Model Description
78
+ - **Model Type:** Sentence Transformer
79
+ - **Base model:** [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) <!-- at revision 91d2d6bfdddf0b0da840f901b533e99bae30d757 -->
80
+ - **Maximum Sequence Length:** 8192 tokens
81
+ - **Output Dimensionality:** 768 dimensions
82
+ - **Similarity Function:** Cosine Similarity
83
+ <!-- - **Training Dataset:** Unknown -->
84
+ <!-- - **Language:** Unknown -->
85
+ <!-- - **License:** Unknown -->
86
+
87
+ ### Model Sources
88
+
89
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
90
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
91
+ - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
92
+
93
+ ### Full Model Architecture
94
+
+ ```
+ SentenceTransformer(
+   (0): Transformer({'max_seq_length': 8192, 'do_lower_case': False}) with Transformer model: NomicBertModel
+   (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
+ )
+ ```
101
+
102
+ ## Usage
103
+
104
+ ### Direct Usage (Sentence Transformers)
105
+
106
+ First install the Sentence Transformers library:
107
+
+ ```bash
+ pip install -U sentence-transformers
+ ```
111
+
+ Then you can load this model and run inference.
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Download from the 🤗 Hub
+ # The base model (nomic-embed-text-v1.5) ships custom modeling code, so loading
+ # is expected to need trust_remote_code=True.
+ model = SentenceTransformer("sentence_transformers_model_id", trust_remote_code=True)
+ # Run inference
+ sentences = [
+     'search_query: dab rig',
+     'search_query: volcano weed vaporizer',
+     'search_query: 22 gold chain for men',
+ ]
+ embeddings = model.encode(sentences)
+ print(embeddings.shape)
+ # (3, 768)
+
+ # Get the similarity scores for the embeddings
+ similarities = model.similarity(embeddings, embeddings)
+ print(similarities.shape)
+ # torch.Size([3, 3])
+ ```
133
+
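The training samples in this card pair `search_query: ` anchors with `search_document: ` product texts, so retrieval-style use presumably wants the same prefixes at inference time. The sketch below shows one way to rank documents for a query under that assumption. The helper `rank_documents` and the stand-in `toy_encode` are hypothetical names introduced here; with the real model, `model.encode` would be passed in place of `toy_encode`.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_documents(
    encode: Callable[[List[str]], Sequence[Vector]],
    query: str,
    documents: List[str],
) -> List[Tuple[str, float]]:
    """Prefix the texts as in the training data, embed them, and sort by cosine similarity."""
    texts = ["search_query: " + query] + ["search_document: " + d for d in documents]
    vectors = encode(texts)
    query_vec, doc_vecs = vectors[0], vectors[1:]
    scored = [(doc, cosine(query_vec, vec)) for doc, vec in zip(documents, doc_vecs)]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


def toy_encode(texts: List[str]) -> List[List[float]]:
    """Hypothetical stand-in encoder: counts a few keywords instead of running a model."""
    vocab = ["gloves", "kids", "table"]
    return [[float(t.lower().count(word)) for word in vocab] for t in texts]


ranking = rank_documents(toy_encode, "kids gloves", ["Foosball Table", "Kids Winter Gloves"])
print(ranking[0][0])
```

Passing the encoder as a callable keeps the ranking logic testable without downloading any weights.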
134
+ <!--
135
+ ### Direct Usage (Transformers)
136
+
137
+ <details><summary>Click to see the direct usage in Transformers</summary>
138
+
139
+ </details>
140
+ -->
141
+
142
+ <!--
143
+ ### Downstream Usage (Sentence Transformers)
144
+
145
+ You can finetune this model on your own dataset.
146
+
147
+ <details><summary>Click to expand</summary>
148
+
149
+ </details>
150
+ -->
151
+
152
+ <!--
153
+ ### Out-of-Scope Use
154
+
155
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
156
+ -->
157
+
158
+ ## Evaluation
159
+
160
+ ### Metrics
161
+
162
+ #### Triplet
163
+ * Dataset: `triplet-esci`
164
+ * Evaluated with [<code>TripletEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.TripletEvaluator)
165
+
166
+ | Metric | Value |
167
+ |:--------------------|:-----------|
168
+ | **cosine_accuracy** | **0.7405** |
169
+ | dot_accuracy | 0.269 |
170
+ | manhattan_accuracy | 0.7432 |
171
+ | euclidean_accuracy | 0.7457 |
172
+ | max_accuracy | 0.7457 |
173
+
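To make the `cosine_accuracy` figure concrete: a triplet counts as correct when the anchor is closer to its positive than to its negative, and the accuracy is the fraction of correct triplets. The snippet below is a small sketch of that idea in plain Python, not the `TripletEvaluator` source; the function names and the toy 2-dimensional embeddings are assumptions for illustration.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def triplet_cosine_accuracy(triplets: List[Tuple[Vector, Vector, Vector]]) -> float:
    """Fraction of (anchor, positive, negative) triplets where the positive is more similar."""
    correct = sum(
        1
        for anchor, positive, negative in triplets
        if cosine(anchor, positive) > cosine(anchor, negative)
    )
    return correct / len(triplets)


# Two toy triplets: the first is ranked correctly, the second is not.
toy_triplets = [
    ([1.0, 0.0], [1.0, 0.1], [0.0, 1.0]),
    ([1.0, 0.0], [0.0, 1.0], [1.0, 0.1]),
]
accuracy = triplet_cosine_accuracy(toy_triplets)
print(accuracy)
```

On this reading, the reported 0.7405 would mean that roughly 74% of the evaluation triplets place the positive document above the negative one.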
174
+ <!--
175
+ ## Bias, Risks and Limitations
176
+
177
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
178
+ -->
179
+
180
+ <!--
181
+ ### Recommendations
182
+
183
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
184
+ -->
185
+
186
+ ## Training Details
187
+
188
+ ### Training Dataset
189
+
190
+ #### Unnamed Dataset
191
+
192
+
193
+ * Size: 167,039 training samples
194
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
195
+ * Approximate statistics based on the first 1000 samples:
196
+ | | anchor | positive | negative |
197
+ |:--------|:---------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|
198
+ | type | string | string | string |
199
+ | details | <ul><li>min: 7 tokens</li><li>mean: 11.1 tokens</li><li>max: 38 tokens</li></ul> | <ul><li>min: 14 tokens</li><li>mean: 43.23 tokens</li><li>max: 124 tokens</li></ul> | <ul><li>min: 16 tokens</li><li>mean: 43.16 tokens</li><li>max: 97 tokens</li></ul> |
200
+ * Samples:
201
+ | anchor | positive | negative |
202
+ |:--------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
203
+ | <code>search_query: foos ball coffee table</code> | <code>search_document: KICK Vanquish 55" in Foosball Table, KICK, Blue/Gray</code> | <code>search_document: KICK Legend 55" Foosball Table (Black), KICK, Black</code> |
204
+ | <code>search_query: bathroom rugs white washable</code> | <code>search_document: Luxury Bath Mat Floor Towel Set - Absorbent Cotton Hotel Spa Shower/Bathtub Mats [Not a Bathroom Rug] 22"x34" | White | 2 Pack, White Classic, White</code> | <code>search_document: Utopia Towels Cotton Banded Bath Mats, White [Not a Bathroom Rug] 21 x 34 Inches, 100% Ring Spun Cotton - Highly Absorbent and Machine Washable Shower Bathroom Floor Mat (Pack of 2), Utopia Towels, White</code> |
205
+ | <code>search_query: kids gloves</code> | <code>search_document: EvridWear Boys Girls Magic Stretch Gripper Gloves 3 Pair Pack Assortment, Kids One Size Winter Warm Gloves Children (8-14Years, 3 Pairs Camo), Evridwear, 3 Pairs Camo</code> | <code>search_document: Body Glove Little Boys 2-Piece UPF 50+ Rash Guard Swimsuit Set (2 Piece), All Black, Size 5, Body Glove, All Black</code> |
206
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "cos_sim"
+   }
+   ```
213
+
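The parameters above (`scale` 20.0, `cos_sim`) can be read as follows: each anchor is scored against every positive and negative in the batch with scaled cosine similarity, and a cross-entropy loss pushes the anchor's own positive to the top. The snippet below is a hedged, framework-free sketch of that objective; the name `mnr_loss` and the toy vectors are assumptions. As far as the sentence-transformers documentation describes it, the cached variant targets the same objective and differs in how gradients are computed to save memory, which this sketch does not model.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cos_sim(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def mnr_loss(
    anchors: List[Vector],
    positives: List[Vector],
    negatives: List[Vector],
    scale: float = 20.0,
) -> float:
    """Mean cross-entropy where anchor i should rank positive i above all other candidates."""
    candidates = list(positives) + list(negatives)
    losses = []
    for i, anchor in enumerate(anchors):
        logits = [scale * cos_sim(anchor, candidate) for candidate in candidates]
        top = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = top + math.log(sum(math.exp(value - top) for value in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [0.1, 0.9]]
negatives = [[-1.0, 0.2], [0.2, -1.0]]

aligned_loss = mnr_loss(anchors, positives, negatives)
shuffled_loss = mnr_loss(anchors, positives[::-1], negatives)  # positives deliberately mismatched
print(aligned_loss < shuffled_loss)
```

The mismatched batch yields a much larger loss because each anchor's designated positive is no longer its most similar candidate.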
214
+ ### Evaluation Dataset
215
+
216
+ #### Unnamed Dataset
217
+
218
+
219
+ * Size: 10,000 evaluation samples
220
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
221
+ * Approximate statistics based on the first 1000 samples:
222
+ | | anchor | positive | negative |
223
+ |:--------|:----------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|
224
+ | type | string | string | string |
225
+ | details | <ul><li>min: 7 tokens</li><li>mean: 11.44 tokens</li><li>max: 31 tokens</li></ul> | <ul><li>min: 16 tokens</li><li>mean: 42.26 tokens</li><li>max: 92 tokens</li></ul> | <ul><li>min: 16 tokens</li><li>mean: 42.28 tokens</li><li>max: 105 tokens</li></ul> |
226
+ * Samples:
227
+ | anchor | positive | negative |
228
+ |:--------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
229
+ | <code>search_query: defender series iphone 8</code> | <code>search_document: Hand-e Muscle Series Belt Clip Case for Apple iPhone 7 / iPhone 8 / iPhone SE “2020” (4.7”) 2-in-1 Protective Defender w Screen Protector & Holster & Kickstand/Shock & Drop Proof – Camouflage/Orange, Hand-e, Camouflage / Orange</code> | <code>search_document: OtterBox Defender Series Rugged Case for iPhone 8 PLUS & iPhone 7 PLUS - Case Only - Non-Retail Packaging - Dark Lake - With Microbial Defense, OtterBox, Dark Lake</code> |
230
+ | <code>search_query: joy mangano</code> | <code>search_document: Joy by Joy Mangano 11-Piece Complete Luxury Towel Set, Ivory, Joy Mangano, Ivory</code> | <code>search_document: BAGSMART Jewelry Organizer Case Travel Jewelry Storage Bag for Necklace, Earrings, Rings, Bracelet, Soft Pink, BAGSMART, Soft Pink</code> |
231
+ | <code>search_query: cashel fly masks for horses without ears</code> | <code>search_document: Cashel Crusader Designer Horse Fly Mask, Leopard, Weanling, Cashel, Leopard</code> | <code>search_document: Cashel Crusader Designer Horse Fly Mask with Ears, Teal Tribal, Weanling, Cashel, Teal Tribal</code> |
232
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "cos_sim"
+   }
+   ```
239
+
240
+ ### Training Hyperparameters
241
+ #### Non-Default Hyperparameters
242
+
243
+ - `per_device_train_batch_size`: 4
244
+ - `per_device_eval_batch_size`: 4
245
+ - `gradient_accumulation_steps`: 4
246
+ - `learning_rate`: 1e-06
247
+ - `num_train_epochs`: 5
248
+ - `lr_scheduler_type`: cosine_with_restarts
249
+ - `warmup_ratio`: 0.1
250
+ - `dataloader_drop_last`: True
251
+ - `dataloader_num_workers`: 4
252
+ - `dataloader_prefetch_factor`: 2
253
+ - `load_best_model_at_end`: True
254
+ - `batch_sampler`: no_duplicates
255
+
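A quick arithmetic check ties these hyperparameters to the `Epoch` column of the training logs. The sketch assumes a single training device, which the card does not state; under that assumption, and with `dataloader_drop_last` enabled, the computed epoch fractions line up with the logged values.

```python
train_samples = 167_039              # size reported for the training dataset
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_devices = 1                      # assumption: not stated in the card

# Examples consumed per optimizer step.
examples_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_devices

# drop_last discards the final incomplete batch, hence the floor division.
batches_per_epoch = train_samples // (per_device_train_batch_size * num_devices)


def epoch_at(step: int) -> float:
    """Approximate epoch fraction reached after a given optimizer step."""
    return step * gradient_accumulation_steps / batches_per_epoch


print(examples_per_step)             # 16
print(batches_per_epoch)             # 41759
print(round(epoch_at(1000), 4))      # 0.0958
```

The value for step 1000 agrees with the 0.0958 logged at that step, which is consistent with the single-device assumption but does not prove it.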
256
+ #### All Hyperparameters
257
+ <details><summary>Click to expand</summary>
258
+
259
+ - `overwrite_output_dir`: False
260
+ - `do_predict`: False
261
+ - `prediction_loss_only`: True
262
+ - `per_device_train_batch_size`: 4
263
+ - `per_device_eval_batch_size`: 4
264
+ - `per_gpu_train_batch_size`: None
265
+ - `per_gpu_eval_batch_size`: None
266
+ - `gradient_accumulation_steps`: 4
267
+ - `eval_accumulation_steps`: None
268
+ - `learning_rate`: 1e-06
269
+ - `weight_decay`: 0.0
270
+ - `adam_beta1`: 0.9
271
+ - `adam_beta2`: 0.999
272
+ - `adam_epsilon`: 1e-08
273
+ - `max_grad_norm`: 1.0
274
+ - `num_train_epochs`: 5
275
+ - `max_steps`: -1
276
+ - `lr_scheduler_type`: cosine_with_restarts
277
+ - `lr_scheduler_kwargs`: {}
278
+ - `warmup_ratio`: 0.1
279
+ - `warmup_steps`: 0
280
+ - `log_level`: passive
281
+ - `log_level_replica`: warning
282
+ - `log_on_each_node`: True
283
+ - `logging_nan_inf_filter`: True
284
+ - `save_safetensors`: True
285
+ - `save_on_each_node`: False
286
+ - `save_only_model`: False
287
+ - `no_cuda`: False
288
+ - `use_cpu`: False
289
+ - `use_mps_device`: False
290
+ - `seed`: 42
291
+ - `data_seed`: None
292
+ - `jit_mode_eval`: False
293
+ - `use_ipex`: False
294
+ - `bf16`: False
295
+ - `fp16`: False
296
+ - `fp16_opt_level`: O1
297
+ - `half_precision_backend`: auto
298
+ - `bf16_full_eval`: False
299
+ - `fp16_full_eval`: False
300
+ - `tf32`: None
301
+ - `local_rank`: 0
302
+ - `ddp_backend`: None
303
+ - `tpu_num_cores`: None
304
+ - `tpu_metrics_debug`: False
305
+ - `debug`: []
306
+ - `dataloader_drop_last`: True
307
+ - `dataloader_num_workers`: 4
308
+ - `dataloader_prefetch_factor`: 2
309
+ - `past_index`: -1
310
+ - `disable_tqdm`: False
311
+ - `remove_unused_columns`: True
312
+ - `label_names`: None
313
+ - `load_best_model_at_end`: True
314
+ - `ignore_data_skip`: False
315
+ - `fsdp`: []
316
+ - `fsdp_min_num_params`: 0
317
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
318
+ - `fsdp_transformer_layer_cls_to_wrap`: None
319
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True}
320
+ - `deepspeed`: None
321
+ - `label_smoothing_factor`: 0.0
322
+ - `optim`: adamw_torch
323
+ - `optim_args`: None
324
+ - `adafactor`: False
325
+ - `group_by_length`: False
326
+ - `length_column_name`: length
327
+ - `ddp_find_unused_parameters`: None
328
+ - `ddp_bucket_cap_mb`: None
329
+ - `ddp_broadcast_buffers`: False
330
+ - `dataloader_pin_memory`: True
331
+ - `dataloader_persistent_workers`: False
332
+ - `skip_memory_metrics`: True
333
+ - `use_legacy_prediction_loop`: False
334
+ - `push_to_hub`: False
335
+ - `resume_from_checkpoint`: None
336
+ - `hub_model_id`: None
337
+ - `hub_strategy`: every_save
338
+ - `hub_private_repo`: False
339
+ - `hub_always_push`: False
340
+ - `gradient_checkpointing`: False
341
+ - `gradient_checkpointing_kwargs`: None
342
+ - `include_inputs_for_metrics`: False
343
+ - `fp16_backend`: auto
344
+ - `push_to_hub_model_id`: None
345
+ - `push_to_hub_organization`: None
346
+ - `mp_parameters`:
347
+ - `auto_find_batch_size`: False
348
+ - `full_determinism`: False
349
+ - `torchdynamo`: None
350
+ - `ray_scope`: last
351
+ - `ddp_timeout`: 1800
352
+ - `torch_compile`: False
353
+ - `torch_compile_backend`: None
354
+ - `torch_compile_mode`: None
355
+ - `dispatch_batches`: None
356
+ - `split_batches`: None
357
+ - `include_tokens_per_second`: False
358
+ - `include_num_input_tokens_seen`: False
359
+ - `neftune_noise_alpha`: None
360
+ - `batch_sampler`: no_duplicates
361
+ - `multi_dataset_batch_sampler`: proportional
362
+
363
+ </details>
364
+
365
+ ### Training Logs
366
+ <details><summary>Click to expand</summary>
367
+
368
+ | Epoch | Step | Training Loss | loss | triplet-esci_cosine_accuracy |
369
+ |:------:|:-----:|:-------------:|:------:|:----------------------------:|
370
+ | 0.0096 | 100 | 0.6669 | - | - |
371
+ | 0.0192 | 200 | 0.6633 | - | - |
372
+ | 0.0287 | 300 | 0.6575 | - | - |
373
+ | 0.0383 | 400 | 0.6638 | - | - |
374
+ | 0.0479 | 500 | 0.6191 | - | - |
375
+ | 0.0575 | 600 | 0.6464 | - | - |
376
+ | 0.0671 | 700 | 0.6291 | - | - |
377
+ | 0.0766 | 800 | 0.5973 | - | - |
378
+ | 0.0862 | 900 | 0.605 | - | - |
379
+ | 0.0958 | 1000 | 0.6278 | 0.6525 | 0.7269 |
380
+ | 0.1054 | 1100 | 0.6041 | - | - |
381
+ | 0.1149 | 1200 | 0.6077 | - | - |
382
+ | 0.1245 | 1300 | 0.589 | - | - |
383
+ | 0.1341 | 1400 | 0.5811 | - | - |
384
+ | 0.1437 | 1500 | 0.5512 | - | - |
385
+ | 0.1533 | 1600 | 0.5907 | - | - |
386
+ | 0.1628 | 1700 | 0.5718 | - | - |
387
+ | 0.1724 | 1800 | 0.5446 | - | - |
388
+ | 0.1820 | 1900 | 0.546 | - | - |
389
+ | 0.1916 | 2000 | 0.5141 | 0.6105 | 0.7386 |
390
+ | 0.2012 | 2100 | 0.5359 | - | - |
391
+ | 0.2107 | 2200 | 0.5093 | - | - |
392
+ | 0.2203 | 2300 | 0.5384 | - | - |
393
+ | 0.2299 | 2400 | 0.5582 | - | - |
394
+ | 0.2395 | 2500 | 0.5038 | - | - |
395
+ | 0.2490 | 2600 | 0.5031 | - | - |
396
+ | 0.2586 | 2700 | 0.5393 | - | - |
397
+ | 0.2682 | 2800 | 0.4979 | - | - |
398
+ | 0.2778 | 2900 | 0.5221 | - | - |
399
+ | 0.2874 | 3000 | 0.4956 | 0.5852 | 0.7495 |
400
+ | 0.2969 | 3100 | 0.506 | - | - |
401
+ | 0.3065 | 3200 | 0.4962 | - | - |
402
+ | 0.3161 | 3300 | 0.4713 | - | - |
403
+ | 0.3257 | 3400 | 0.5016 | - | - |
404
+ | 0.3353 | 3500 | 0.4749 | - | - |
405
+ | 0.3448 | 3600 | 0.4732 | - | - |
406
+ | 0.3544 | 3700 | 0.4789 | - | - |
407
+ | 0.3640 | 3800 | 0.4825 | - | - |
408
+ | 0.3736 | 3900 | 0.4803 | - | - |
409
+ | 0.3832 | 4000 | 0.4471 | 0.5743 | 0.7546 |
410
+ | 0.3927 | 4100 | 0.4593 | - | - |
411
+ | 0.4023 | 4200 | 0.4481 | - | - |
412
+ | 0.4119 | 4300 | 0.4603 | - | - |
413
+ | 0.4215 | 4400 | 0.4569 | - | - |
414
+ | 0.4310 | 4500 | 0.4807 | - | - |
415
+ | 0.4406 | 4600 | 0.4368 | - | - |
416
+ | 0.4502 | 4700 | 0.4532 | - | - |
417
+ | 0.4598 | 4800 | 0.4432 | - | - |
418
+ | 0.4694 | 4900 | 0.4802 | - | - |
419
+ | 0.4789 | 5000 | 0.4643 | 0.5663 | 0.7593 |
420
+ | 0.4885 | 5100 | 0.4154 | - | - |
421
+ | 0.4981 | 5200 | 0.4441 | - | - |
422
+ | 0.5077 | 5300 | 0.4156 | - | - |
423
+ | 0.5173 | 5400 | 0.4273 | - | - |
424
+ | 0.5268 | 5500 | 0.3988 | - | - |
425
+ | 0.5364 | 5600 | 0.3942 | - | - |
426
+ | 0.5460 | 5700 | 0.4186 | - | - |
427
+ | 0.5556 | 5800 | 0.423 | - | - |
428
+ | 0.5651 | 5900 | 0.434 | - | - |
429
+ | 0.5747 | 6000 | 0.4136 | 0.5704 | 0.7616 |
430
+ | 0.5843 | 6100 | 0.3968 | - | - |
431
+ | 0.5939 | 6200 | 0.4045 | - | - |
432
+ | 0.6035 | 6300 | 0.4122 | - | - |
433
+ | 0.6130 | 6400 | 0.3618 | - | - |
434
+ | 0.6226 | 6500 | 0.341 | - | - |
435
+ | 0.6322 | 6600 | 0.3689 | - | - |
436
+ | 0.6418 | 6700 | 0.3621 | - | - |
437
+ | 0.6514 | 6800 | 0.3774 | - | - |
438
+ | 0.6609 | 6900 | 0.3519 | - | - |
439
+ | 0.6705 | 7000 | 0.3974 | 0.5729 | 0.7644 |
440
+ | 0.6801 | 7100 | 0.3443 | - | - |
441
+ | 0.6897 | 7200 | 0.3665 | - | - |
442
+ | 0.6993 | 7300 | 0.3683 | - | - |
443
+ | 0.7088 | 7400 | 0.3593 | - | - |
444
+ | 0.7184 | 7500 | 0.3419 | - | - |
445
+ | 0.7280 | 7600 | 0.3587 | - | - |
446
+ | 0.7376 | 7700 | 0.3463 | - | - |
447
+ | 0.7471 | 7800 | 0.3417 | - | - |
448
+ | 0.7567 | 7900 | 0.32 | - | - |
449
+ | 0.7663 | 8000 | 0.32 | 0.5735 | 0.7677 |
450
+ | 0.7759 | 8100 | 0.3296 | - | - |
451
+ | 0.7855 | 8200 | 0.3492 | - | - |
452
+ | 0.7950 | 8300 | 0.3022 | - | - |
453
+ | 0.8046 | 8400 | 0.3159 | - | - |
454
+ | 0.8142 | 8500 | 0.3172 | - | - |
455
+ | 0.8238 | 8600 | 0.3157 | - | - |
456
+ | 0.8334 | 8700 | 0.3271 | - | - |
457
+ | 0.8429 | 8800 | 0.337 | - | - |
458
+ | 0.8525 | 8900 | 0.322 | - | - |
459
+ | 0.8621 | 9000 | 0.3187 | 0.5803 | 0.7652 |
460
+ | 0.8717 | 9100 | 0.307 | - | - |
461
+ | 0.8812 | 9200 | 0.2984 | - | - |
462
+ | 0.8908 | 9300 | 0.2727 | - | - |
463
+ | 0.9004 | 9400 | 0.304 | - | - |
464
+ | 0.9100 | 9500 | 0.321 | - | - |
465
+ | 0.9196 | 9600 | 0.304 | - | - |
466
+ | 0.9291 | 9700 | 0.3302 | - | - |
467
+ | 0.9387 | 9800 | 0.3302 | - | - |
468
+ | 0.9483 | 9900 | 0.3134 | - | - |
469
+ | 0.9579 | 10000 | 0.2936 | 0.5858 | 0.7671 |
470
+ | 0.9675 | 10100 | 0.2953 | - | - |
471
+ | 0.9770 | 10200 | 0.3035 | - | - |
472
+ | 0.9866 | 10300 | 0.303 | - | - |
473
+ | 0.9962 | 10400 | 0.2606 | - | - |
474
+ | 1.0058 | 10500 | 0.2615 | - | - |
475
+ | 1.0153 | 10600 | 0.2703 | - | - |
476
+ | 1.0249 | 10700 | 0.2761 | - | - |
477
+ | 1.0345 | 10800 | 0.2559 | - | - |
478
+ | 1.0441 | 10900 | 0.2672 | - | - |
479
+ | 1.0537 | 11000 | 0.2656 | 0.5933 | 0.7676 |
480
+ | 1.0632 | 11100 | 0.2825 | - | - |
481
+ | 1.0728 | 11200 | 0.2484 | - | - |
482
+ | 1.0824 | 11300 | 0.2472 | - | - |
483
+ | 1.0920 | 11400 | 0.2678 | - | - |
484
+ | 1.1016 | 11500 | 0.2443 | - | - |
485
+ | 1.1111 | 11600 | 0.2685 | - | - |
486
+ | 1.1207 | 11700 | 0.2504 | - | - |
487
+ | 1.1303 | 11800 | 0.2431 | - | - |
488
+ | 1.1399 | 11900 | 0.2248 | - | - |
489
+ | 1.1495 | 12000 | 0.2229 | 0.5958 | 0.7688 |
490
+ | 1.1590 | 12100 | 0.228 | - | - |
491
+ | 1.1686 | 12200 | 0.2304 | - | - |
492
+ | 1.1782 | 12300 | 0.2193 | - | - |
493
+ | 1.1878 | 12400 | 0.2238 | - | - |
494
+ | 1.1973 | 12500 | 0.1957 | - | - |
495
+ | 1.2069 | 12600 | 0.2075 | - | - |
496
+ | 1.2165 | 12700 | 0.2014 | - | - |
497
+ | 1.2261 | 12800 | 0.2222 | - | - |
498
+ | 1.2357 | 12900 | 0.2059 | - | - |
499
+ | 1.2452 | 13000 | 0.2051 | 0.6077 | 0.7651 |
500
+ | 1.2548 | 13100 | 0.2076 | - | - |
501
+ | 1.2644 | 13200 | 0.226 | - | - |
502
+ | 1.2740 | 13300 | 0.1941 | - | - |
503
+ | 1.2836 | 13400 | 0.2053 | - | - |
504
+ | 1.2931 | 13500 | 0.2003 | - | - |
505
+ | 1.3027 | 13600 | 0.1947 | - | - |
506
+ | 1.3123 | 13700 | 0.1914 | - | - |
507
+ | 1.3219 | 13800 | 0.1956 | - | - |
508
+ | 1.3314 | 13900 | 0.1862 | - | - |
509
+ | 1.3410 | 14000 | 0.1873 | 0.6110 | 0.7646 |
510
+ | 1.3506 | 14100 | 0.1812 | - | - |
511
+ | 1.3602 | 14200 | 0.1828 | - | - |
512
+ | 1.3698 | 14300 | 0.1696 | - | - |
513
+ | 1.3793 | 14400 | 0.1705 | - | - |
514
+ | 1.3889 | 14500 | 0.1746 | - | - |
515
+ | 1.3985 | 14600 | 0.1756 | - | - |
516
+ | 1.4081 | 14700 | 0.1682 | - | - |
517
+ | 1.4177 | 14800 | 0.1769 | - | - |
518
+ | 1.4272 | 14900 | 0.1795 | - | - |
519
+ | 1.4368 | 15000 | 0.1736 | 0.6278 | 0.7616 |
520
+ | 1.4464 | 15100 | 0.1546 | - | - |
521
+ | 1.4560 | 15200 | 0.1643 | - | - |
522
+ | 1.4656 | 15300 | 0.1903 | - | - |
523
+ | 1.4751 | 15400 | 0.1902 | - | - |
524
+ | 1.4847 | 15500 | 0.1531 | - | - |
525
+ | 1.4943 | 15600 | 0.1711 | - | - |
526
+ | 1.5039 | 15700 | 0.1546 | - | - |
527
+ | 1.5134 | 15800 | 0.1503 | - | - |
528
+ | 1.5230 | 15900 | 0.1429 | - | - |
529
+ | 1.5326 | 16000 | 0.147 | 0.6306 | 0.7623 |
530
+ | 1.5422 | 16100 | 0.1507 | - | - |
531
+ | 1.5518 | 16200 | 0.152 | - | - |
532
+ | 1.5613 | 16300 | 0.1602 | - | - |
533
+ | 1.5709 | 16400 | 0.1541 | - | - |
534
+ | 1.5805 | 16500 | 0.1491 | - | - |
535
+ | 1.5901 | 16600 | 0.1378 | - | - |
536
+ | 1.5997 | 16700 | 0.1505 | - | - |
537
+ | 1.6092 | 16800 | 0.1334 | - | - |
538
+ | 1.6188 | 16900 | 0.1288 | - | - |
539
+ | 1.6284 | 17000 | 0.1168 | 0.6372 | 0.7629 |
540
+ | 1.6380 | 17100 | 0.135 | - | - |
541
+ | 1.6475 | 17200 | 0.1239 | - | - |
542
+ | 1.6571 | 17300 | 0.1398 | - | - |
543
+ | 1.6667 | 17400 | 0.1292 | - | - |
544
+ | 1.6763 | 17500 | 0.1414 | - | - |
545
+ | 1.6859 | 17600 | 0.116 | - | - |
546
+ | 1.6954 | 17700 | 0.1302 | - | - |
547
+ | 1.7050 | 17800 | 0.1194 | - | - |
548
+ | 1.7146 | 17900 | 0.1394 | - | - |
549
+ | 1.7242 | 18000 | 0.1316 | 0.6561 | 0.7592 |
550
+ | 1.7338 | 18100 | 0.1246 | - | - |
551
+ | 1.7433 | 18200 | 0.1277 | - | - |
552
+ | 1.7529 | 18300 | 0.1055 | - | - |
553
+ | 1.7625 | 18400 | 0.1211 | - | - |
554
+ | 1.7721 | 18500 | 0.1107 | - | - |
555
+ | 1.7817 | 18600 | 0.1145 | - | - |
556
+ | 1.7912 | 18700 | 0.1162 | - | - |
557
+ | 1.8008 | 18800 | 0.1114 | - | - |
558
+ | 1.8104 | 18900 | 0.1182 | - | - |
559
+ | 1.8200 | 19000 | 0.1152 | 0.6567 | 0.7591 |
560
+ | 1.8295 | 19100 | 0.1212 | - | - |
561
+ | 1.8391 | 19200 | 0.1253 | - | - |
562
+ | 1.8487 | 19300 | 0.115 | - | - |
563
+ | 1.8583 | 19400 | 0.1292 | - | - |
564
+ | 1.8679 | 19500 | 0.1151 | - | - |
565
+ | 1.8774 | 19600 | 0.1005 | - | - |
566
+ | 1.8870 | 19700 | 0.1079 | - | - |
567
+ | 1.8966 | 19800 | 0.0954 | - | - |
568
+ | 1.9062 | 19900 | 0.1045 | - | - |
569
+ | 1.9158 | 20000 | 0.1086 | 0.6727 | 0.7554 |
570
+ | 1.9253 | 20100 | 0.1174 | - | - |
571
+ | 1.9349 | 20200 | 0.1108 | - | - |
572
+ | 1.9445 | 20300 | 0.0992 | - | - |
573
+ | 1.9541 | 20400 | 0.1168 | - | - |
574
+ | 1.9636 | 20500 | 0.1028 | - | - |
575
+ | 1.9732 | 20600 | 0.1126 | - | - |
576
+ | 1.9828 | 20700 | 0.1113 | - | - |
577
+ | 1.9924 | 20800 | 0.1065 | - | - |
578
+ | 2.0020 | 20900 | 0.078 | - | - |
579
+ | 2.0115 | 21000 | 0.0921 | 0.6727 | 0.7568 |
580
+ | 2.0211 | 21100 | 0.0866 | - | - |
581
+ | 2.0307 | 21200 | 0.0918 | - | - |
582
+ | 2.0403 | 21300 | 0.0893 | - | - |
583
+ | 2.0499 | 21400 | 0.0882 | - | - |
584
+ | 2.0594 | 21500 | 0.0986 | - | - |
585
+ | 2.0690 | 21600 | 0.0923 | - | - |
586
+ | 2.0786 | 21700 | 0.0805 | - | - |
587
+ | 2.0882 | 21800 | 0.0887 | - | - |
588
+ | 2.0978 | 21900 | 0.1 | - | - |
589
+ | 2.1073 | 22000 | 0.0957 | 0.6854 | 0.7539 |
590
+ | 2.1169 | 22100 | 0.0921 | - | - |
591
+ | 2.1265 | 22200 | 0.0892 | - | - |
592
+ | 2.1361 | 22300 | 0.0805 | - | - |
593
+ | 2.1456 | 22400 | 0.0767 | - | - |
594
+ | 2.1552 | 22500 | 0.0715 | - | - |
595
+ | 2.1648 | 22600 | 0.083 | - | - |
596
+ | 2.1744 | 22700 | 0.0755 | - | - |
597
+ | 2.1840 | 22800 | 0.075 | - | - |
598
+ | 2.1935 | 22900 | 0.0724 | - | - |
599
+ | 2.2031 | 23000 | 0.0822 | 0.6913 | 0.7534 |
600
+ | 2.2127 | 23100 | 0.0623 | - | - |
601
+ | 2.2223 | 23200 | 0.0765 | - | - |
602
+ | 2.2319 | 23300 | 0.0755 | - | - |
603
+ | 2.2414 | 23400 | 0.0786 | - | - |
604
+ | 2.2510 | 23500 | 0.0651 | - | - |
605
+ | 2.2606 | 23600 | 0.081 | - | - |
606
+ | 2.2702 | 23700 | 0.0664 | - | - |
607
+ | 2.2797 | 23800 | 0.0906 | - | - |
608
+ | 2.2893 | 23900 | 0.0714 | - | - |
609
+ | 2.2989 | 24000 | 0.0703 | 0.6971 | 0.7536 |
610
+ | 2.3085 | 24100 | 0.0672 | - | - |
611
+ | 2.3181 | 24200 | 0.0754 | - | - |
612
+ | 2.3276 | 24300 | 0.0687 | - | - |
613
+ | 2.3372 | 24400 | 0.0668 | - | - |
614
+ | 2.3468 | 24500 | 0.0616 | - | - |
615
+ | 2.3564 | 24600 | 0.0693 | - | - |
616
+ | 2.3660 | 24700 | 0.0587 | - | - |
617
+ | 2.3755 | 24800 | 0.0612 | - | - |
618
+ | 2.3851 | 24900 | 0.0559 | - | - |
619
+ | 2.3947 | 25000 | 0.0676 | 0.7128 | 0.7497 |
620
+ | 2.4043 | 25100 | 0.0607 | - | - |
621
+ | 2.4139 | 25200 | 0.0727 | - | - |
622
+ | 2.4234 | 25300 | 0.0573 | - | - |
623
+ | 2.4330 | 25400 | 0.0717 | - | - |
624
+ | 2.4426 | 25500 | 0.0493 | - | - |
625
+ | 2.4522 | 25600 | 0.0558 | - | - |
626
+ | 2.4617 | 25700 | 0.0676 | - | - |
627
+ | 2.4713 | 25800 | 0.0757 | - | - |
628
+ | 2.4809 | 25900 | 0.0735 | - | - |
629
+ | 2.4905 | 26000 | 0.056 | 0.7044 | 0.7513 |
630
+ | 2.5001 | 26100 | 0.0687 | - | - |
631
+ | 2.5096 | 26200 | 0.0592 | - | - |
632
+ | 2.5192 | 26300 | 0.057 | - | - |
633
+ | 2.5288 | 26400 | 0.0444 | - | - |
634
+ | 2.5384 | 26500 | 0.0547 | - | - |
635
+ | 2.5480 | 26600 | 0.0605 | - | - |
636
+ | 2.5575 | 26700 | 0.066 | - | - |
637
+ | 2.5671 | 26800 | 0.0631 | - | - |
638
+ | 2.5767 | 26900 | 0.0634 | - | - |
639
+ | 2.5863 | 27000 | 0.0537 | 0.7127 | 0.7512 |
640
+ | 2.5958 | 27100 | 0.0535 | - | - |
641
+ | 2.6054 | 27200 | 0.0572 | - | - |
642
+ | 2.6150 | 27300 | 0.0473 | - | - |
643
+ | 2.6246 | 27400 | 0.0418 | - | - |
644
+ | 2.6342 | 27500 | 0.0585 | - | - |
645
+ | 2.6437 | 27600 | 0.0475 | - | - |
646
+ | 2.6533 | 27700 | 0.0549 | - | - |
647
+ | 2.6629 | 27800 | 0.0452 | - | - |
648
+ | 2.6725 | 27900 | 0.0514 | - | - |
649
+ | 2.6821 | 28000 | 0.0449 | 0.7337 | 0.7482 |
650
+ | 2.6916 | 28100 | 0.0544 | - | - |
651
+ | 2.7012 | 28200 | 0.041 | - | - |
652
+ | 2.7108 | 28300 | 0.0599 | - | - |
653
+ | 2.7204 | 28400 | 0.057 | - | - |
654
+ | 2.7300 | 28500 | 0.0503 | - | - |
655
+ | 2.7395 | 28600 | 0.0487 | - | - |
656
+ | 2.7491 | 28700 | 0.0503 | - | - |
657
+ | 2.7587 | 28800 | 0.0446 | - | - |
658
+ | 2.7683 | 28900 | 0.042 | - | - |
659
+ | 2.7778 | 29000 | 0.0501 | 0.7422 | 0.7469 |
660
+ | 2.7874 | 29100 | 0.0494 | - | - |
661
+ | 2.7970 | 29200 | 0.0423 | - | - |
662
+ | 2.8066 | 29300 | 0.0508 | - | - |
663
+ | 2.8162 | 29400 | 0.0459 | - | - |
664
+ | 2.8257 | 29500 | 0.0514 | - | - |
665
+ | 2.8353 | 29600 | 0.0484 | - | - |
666
+ | 2.8449 | 29700 | 0.0571 | - | - |
667
+ | 2.8545 | 29800 | 0.0558 | - | - |
668
+ | 2.8641 | 29900 | 0.0466 | - | - |
669
+ | 2.8736 | 30000 | 0.0465 | 0.7478 | 0.7447 |
670
+ | 2.8832 | 30100 | 0.0463 | - | - |
671
+ | 2.8928 | 30200 | 0.0362 | - | - |
672
+ | 2.9024 | 30300 | 0.0435 | - | - |
673
+ | 2.9119 | 30400 | 0.0419 | - | - |
674
+ | 2.9215 | 30500 | 0.046 | - | - |
675
+ | 2.9311 | 30600 | 0.0451 | - | - |
676
+ | 2.9407 | 30700 | 0.0458 | - | - |
677
+ | 2.9503 | 30800 | 0.052 | - | - |
678
+ | 2.9598 | 30900 | 0.0454 | - | - |
679
+ | 2.9694 | 31000 | 0.0433 | 0.7580 | 0.745 |
680
+ | 2.9790 | 31100 | 0.0438 | - | - |
681
+ | 2.9886 | 31200 | 0.0537 | - | - |
682
+ | 2.9982 | 31300 | 0.033 | - | - |
683
+ | 3.0077 | 31400 | 0.0384 | - | - |
684
+ | 3.0173 | 31500 | 0.0349 | - | - |
685
+ | 3.0269 | 31600 | 0.0365 | - | - |
686
+ | 3.0365 | 31700 | 0.0397 | - | - |
687
+ | 3.0460 | 31800 | 0.0396 | - | - |
688
+ | 3.0556 | 31900 | 0.0358 | - | - |
689
+ | 3.0652 | 32000 | 0.0443 | 0.7592 | 0.7454 |
690
+ | 3.0748 | 32100 | 0.0323 | - | - |
691
+ | 3.0844 | 32200 | 0.0418 | - | - |
692
+ | 3.0939 | 32300 | 0.0463 | - | - |
693
+ | 3.1035 | 32400 | 0.0397 | - | - |
694
+ | 3.1131 | 32500 | 0.0425 | - | - |
695
+ | 3.1227 | 32600 | 0.0406 | - | - |
696
+ | 3.1323 | 32700 | 0.0454 | - | - |
697
+ | 3.1418 | 32800 | 0.0287 | - | - |
698
+ | 3.1514 | 32900 | 0.0267 | - | - |
699
+ | 3.1610 | 33000 | 0.0341 | 0.7672 | 0.7431 |
700
+ | 3.1706 | 33100 | 0.0357 | - | - |
701
+ | 3.1802 | 33200 | 0.0322 | - | - |
702
+ | 3.1897 | 33300 | 0.0367 | - | - |
703
+ | 3.1993 | 33400 | 0.0419 | - | - |
704
+ | 3.2089 | 33500 | 0.0349 | - | - |
705
+ | 3.2185 | 33600 | 0.0327 | - | - |
706
+ | 3.2280 | 33700 | 0.0377 | - | - |
707
+ | 3.2376 | 33800 | 0.0353 | - | - |
708
+ | 3.2472 | 33900 | 0.0305 | - | - |
709
+ | 3.2568 | 34000 | 0.0362 | 0.7668 | 0.7463 |
710
+ | 3.2664 | 34100 | 0.0311 | - | - |
711
+ | 3.2759 | 34200 | 0.0405 | - | - |
712
+ | 3.2855 | 34300 | 0.0401 | - | - |
713
+ | 3.2951 | 34400 | 0.0361 | - | - |
714
+ | 3.3047 | 34500 | 0.0302 | - | - |
715
+ | 3.3143 | 34600 | 0.0379 | - | - |
716
+ | 3.3238 | 34700 | 0.03 | - | - |
717
+ | 3.3334 | 34800 | 0.039 | - | - |
718
+ | 3.3430 | 34900 | 0.0288 | - | - |
719
+ | 3.3526 | 35000 | 0.0318 | 0.7782 | 0.7436 |
720
+ | 3.3621 | 35100 | 0.0283 | - | - |
721
+ | 3.3717 | 35200 | 0.029 | - | - |
722
+ | 3.3813 | 35300 | 0.0287 | - | - |
723
+ | 3.3909 | 35400 | 0.0343 | - | - |
724
+ | 3.4005 | 35500 | 0.0326 | - | - |
725
+ | 3.4100 | 35600 | 0.031 | - | - |
726
+ | 3.4196 | 35700 | 0.0304 | - | - |
727
+ | 3.4292 | 35800 | 0.0314 | - | - |
728
+ | 3.4388 | 35900 | 0.0286 | - | - |
729
+ | 3.4484 | 36000 | 0.0229 | 0.7978 | 0.7428 |
730
+ | 3.4579 | 36100 | 0.0258 | - | - |
731
+ | 3.4675 | 36200 | 0.043 | - | - |
732
+ | 3.4771 | 36300 | 0.042 | - | - |
733
+ | 3.4867 | 36400 | 0.029 | - | - |
734
+ | 3.4963 | 36500 | 0.0343 | - | - |
735
+ | 3.5058 | 36600 | 0.0317 | - | - |
736
+ | 3.5154 | 36700 | 0.0307 | - | - |
737
+ | 3.5250 | 36800 | 0.0251 | - | - |
738
+ | 3.5346 | 36900 | 0.025 | - | - |
739
+ | 3.5441 | 37000 | 0.0309 | 0.8002 | 0.7446 |
740
+ | 3.5537 | 37100 | 0.031 | - | - |
741
+ | 3.5633 | 37200 | 0.0345 | - | - |
742
+ | 3.5729 | 37300 | 0.0332 | - | - |
743
+ | 3.5825 | 37400 | 0.0346 | - | - |
744
+ | 3.5920 | 37500 | 0.026 | - | - |
745
+ | 3.6016 | 37600 | 0.0293 | - | - |
746
+ | 3.6112 | 37700 | 0.0268 | - | - |
747
+ | 3.6208 | 37800 | 0.0264 | - | - |
748
+ | 3.6304 | 37900 | 0.0259 | - | - |
749
+ | 3.6399 | 38000 | 0.032 | 0.7896 | 0.7438 |
750
+ | 3.6495 | 38100 | 0.0246 | - | - |
751
+ | 3.6591 | 38200 | 0.0279 | - | - |
752
+ | 3.6687 | 38300 | 0.0274 | - | - |
753
+ | 3.6782 | 38400 | 0.0241 | - | - |
754
+ | 3.6878 | 38500 | 0.027 | - | - |
755
+ | 3.6974 | 38600 | 0.022 | - | - |
756
+ | 3.7070 | 38700 | 0.0305 | - | - |
757
+ | 3.7166 | 38800 | 0.0368 | - | - |
758
+ | 3.7261 | 38900 | 0.0304 | - | - |
759
+ | 3.7357 | 39000 | 0.0249 | 0.7978 | 0.7437 |
760
+ | 3.7453 | 39100 | 0.0312 | - | - |
761
+ | 3.7549 | 39200 | 0.0257 | - | - |
762
+ | 3.7645 | 39300 | 0.0273 | - | - |
763
+ | 3.7740 | 39400 | 0.0209 | - | - |
764
+ | 3.7836 | 39500 | 0.0298 | - | - |
765
+ | 3.7932 | 39600 | 0.0282 | - | - |
766
+ | 3.8028 | 39700 | 0.028 | - | - |
767
+ | 3.8124 | 39800 | 0.0279 | - | - |
768
+ | 3.8219 | 39900 | 0.0283 | - | - |
769
+ | 3.8315 | 40000 | 0.0239 | 0.7982 | 0.7424 |
770
+ | 3.8411 | 40100 | 0.0378 | - | - |
771
+ | 3.8507 | 40200 | 0.028 | - | - |
772
+ | 3.8602 | 40300 | 0.0321 | - | - |
773
+ | 3.8698 | 40400 | 0.0289 | - | - |
774
+ | 3.8794 | 40500 | 0.027 | - | - |
775
+ | 3.8890 | 40600 | 0.0224 | - | - |
776
+ | 3.8986 | 40700 | 0.0236 | - | - |
777
+ | 3.9081 | 40800 | 0.0267 | - | - |
778
+ | 3.9177 | 40900 | 0.0228 | - | - |
779
+ | 3.9273 | 41000 | 0.0322 | 0.8101 | 0.7415 |
780
+ | 3.9369 | 41100 | 0.0262 | - | - |
781
+ | 3.9465 | 41200 | 0.0276 | - | - |
782
+ | 3.9560 | 41300 | 0.0292 | - | - |
783
+ | 3.9656 | 41400 | 0.0278 | - | - |
784
+ | 3.9752 | 41500 | 0.0262 | - | - |
785
+ | 3.9848 | 41600 | 0.0306 | - | - |
786
+ | 3.9943 | 41700 | 0.0238 | - | - |
787
+ | 4.0039 | 41800 | 0.0165 | - | - |
788
+ | 4.0135 | 41900 | 0.0241 | - | - |
789
+ | 4.0231 | 42000 | 0.0211 | 0.8092 | 0.742 |
790
+ | 4.0327 | 42100 | 0.0257 | - | - |
791
+ | 4.0422 | 42200 | 0.0236 | - | - |
792
+ | 4.0518 | 42300 | 0.0254 | - | - |
793
+ | 4.0614 | 42400 | 0.0248 | - | - |
794
+ | 4.0710 | 42500 | 0.026 | - | - |
795
+ | 4.0806 | 42600 | 0.0245 | - | - |
796
+ | 4.0901 | 42700 | 0.0325 | - | - |
797
+ | 4.0997 | 42800 | 0.0209 | - | - |
798
+ | 4.1093 | 42900 | 0.033 | - | - |
799
+ | 4.1189 | 43000 | 0.0265 | 0.8105 | 0.7412 |
800
+ | 4.1285 | 43100 | 0.027 | - | - |
801
+ | 4.1380 | 43200 | 0.0208 | - | - |
802
+ | 4.1476 | 43300 | 0.0179 | - | - |
803
+ | 4.1572 | 43400 | 0.0194 | - | - |
804
+ | 4.1668 | 43500 | 0.0217 | - | - |
805
+ | 4.1763 | 43600 | 0.0212 | - | - |
806
+ | 4.1859 | 43700 | 0.0226 | - | - |
807
+ | 4.1955 | 43800 | 0.0252 | - | - |
808
+ | 4.2051 | 43900 | 0.0293 | - | - |
809
+ | 4.2147 | 44000 | 0.0216 | 0.8029 | 0.7414 |
810
+ | 4.2242 | 44100 | 0.029 | - | - |
811
+ | 4.2338 | 44200 | 0.0216 | - | - |
812
+ | 4.2434 | 44300 | 0.0251 | - | - |
813
+ | 4.2530 | 44400 | 0.018 | - | - |
814
+ | 4.2626 | 44500 | 0.025 | - | - |
815
+ | 4.2721 | 44600 | 0.0225 | - | - |
816
+ | 4.2817 | 44700 | 0.0303 | - | - |
817
+ | 4.2913 | 44800 | 0.028 | - | - |
818
+ | 4.3009 | 44900 | 0.0203 | - | - |
819
+ | 4.3104 | 45000 | 0.026 | 0.8081 | 0.7405 |
820
+
821
+ </details>
822
+
823
+ ### Framework Versions
824
+ - Python: 3.10.12
825
+ - Sentence Transformers: 3.0.0
826
+ - Transformers: 4.38.2
827
+ - PyTorch: 2.1.2+cu121
828
+ - Accelerate: 0.27.2
829
+ - Datasets: 2.19.1
830
+ - Tokenizers: 0.15.2
831
+
832
+ ## Citation
833
+
834
+ ### BibTeX
835
+
836
+ #### Sentence Transformers
837
+ ```bibtex
838
+ @inproceedings{reimers-2019-sentence-bert,
839
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
840
+ author = "Reimers, Nils and Gurevych, Iryna",
841
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
842
+ month = "11",
843
+ year = "2019",
844
+ publisher = "Association for Computational Linguistics",
845
+ url = "https://arxiv.org/abs/1908.10084",
846
+ }
847
+ ```
848
+
849
+ #### CachedMultipleNegativesRankingLoss
850
+ ```bibtex
851
+ @misc{gao2021scaling,
852
+ title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
853
+ author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
854
+ year={2021},
855
+ eprint={2101.06983},
856
+ archivePrefix={arXiv},
857
+ primaryClass={cs.LG}
858
+ }
859
+ ```
860
+
861
+ <!--
862
+ ## Glossary
863
+
864
+ *Clearly define terms in order to be accessible across audiences.*
865
+ -->
866
+
867
+ <!--
868
+ ## Model Card Authors
869
+
870
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
871
+ -->
872
+
873
+ <!--
874
+ ## Model Card Contact
875
+
876
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
877
+ -->
config.json ADDED
@@ -0,0 +1,58 @@
1
+ {
2
+ "_name_or_path": "models/nomic-embed-text-esci/checkpoint-45000",
3
+ "activation_function": "swiglu",
4
+ "architectures": [
5
+ "NomicBertModel"
6
+ ],
7
+ "attn_pdrop": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_hf_nomic_bert.NomicBertConfig",
10
+ "AutoModel": "modeling_hf_nomic_bert.NomicBertModel",
11
+ "AutoModelForMaskedLM": "nomic-ai/nomic-bert-2048--modeling_hf_nomic_bert.NomicBertForPreTraining"
12
+ },
13
+ "bos_token_id": null,
14
+ "causal": false,
15
+ "dense_seq_output": true,
16
+ "embd_pdrop": 0.0,
17
+ "eos_token_id": null,
18
+ "fused_bias_fc": true,
19
+ "fused_dropout_add_ln": true,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_epsilon": 1e-12,
22
+ "max_trained_positions": 2048,
23
+ "mlp_fc1_bias": false,
24
+ "mlp_fc2_bias": false,
25
+ "model_type": "nomic_bert",
26
+ "n_embd": 768,
27
+ "n_head": 12,
28
+ "n_inner": 3072,
29
+ "n_layer": 12,
30
+ "n_positions": 8192,
31
+ "pad_vocab_size_multiple": 64,
32
+ "parallel_block": false,
33
+ "parallel_block_tied_norm": false,
34
+ "prenorm": false,
35
+ "qkv_proj_bias": false,
36
+ "reorder_and_upcast_attn": false,
37
+ "resid_pdrop": 0.0,
38
+ "rotary_emb_base": 1000,
39
+ "rotary_emb_fraction": 1.0,
40
+ "rotary_emb_interleaved": false,
41
+ "rotary_emb_scale_base": null,
42
+ "rotary_scaling_factor": null,
43
+ "scale_attn_by_inverse_layer_idx": false,
44
+ "scale_attn_weights": true,
45
+ "summary_activation": null,
46
+ "summary_first_dropout": 0.0,
47
+ "summary_proj_to_labels": true,
48
+ "summary_type": "cls_index",
49
+ "summary_use_proj": true,
50
+ "torch_dtype": "float32",
51
+ "transformers_version": "4.38.2",
52
+ "type_vocab_size": 2,
53
+ "use_cache": true,
54
+ "use_flash_attn": true,
55
+ "use_rms_norm": false,
56
+ "use_xentropy": true,
57
+ "vocab_size": 30528
58
+ }
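The `vocab_size` of 30528 in `config.json` is consistent with padding a smaller tokenizer vocabulary up to the next multiple of `pad_vocab_size_multiple` (64). Below is a minimal sketch of that rounding. It assumes the unpadded vocabulary has 30522 entries (the size of the standard BERT uncased WordPiece vocabulary; this commit does not state it), and `pad_vocab_size` is a hypothetical helper name, not a function from this repository.

```python
def pad_vocab_size(vocab_size: int, multiple: int) -> int:
    """Round vocab_size up to the nearest multiple (hypothetical helper)."""
    if multiple <= 1:
        # A multiple of 1 (or less) means no padding is requested.
        return vocab_size
    return ((vocab_size + multiple - 1) // multiple) * multiple


# Assumed unpadded size: 30522 is the standard BERT uncased vocabulary.
padded = pad_vocab_size(30522, 64)
print(padded)  # 30528, matching "vocab_size" in config.json
```

The six extra rows correspond to token ids that the tokenizer never produces, which is why the remapping code pads the embedding matrix with zeros for them.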
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "2.4.0.dev0",
4
+ "transformers": "4.37.2",
5
+ "pytorch": "2.1.0+cu121"
6
+ },
7
+ "prompts": {},
8
+ "default_prompt_name": null,
9
+ "similarity_fn_name": null
10
+ }
configuration_hf_nomic_bert.py ADDED
@@ -0,0 +1,56 @@
1
+ from transformers import GPT2Config
2
+
3
+
4
+ class NomicBertConfig(GPT2Config):
5
+ model_type = "nomic_bert"
6
+
7
+ def __init__(
8
+ self,
9
+ prenorm=False,
10
+ parallel_block=False,
11
+ parallel_block_tied_norm=False,
12
+ rotary_emb_fraction=0.0,
13
+ fused_dropout_add_ln=False,
14
+ fused_bias_fc=False,
15
+ use_flash_attn=False,
16
+ use_xentropy=False,
17
+ qkv_proj_bias=True,
18
+ rotary_emb_base=10_000,
19
+ rotary_emb_scale_base=None,
20
+ rotary_emb_interleaved=False,
21
+ mlp_fc1_bias=True,
22
+ mlp_fc2_bias=True,
23
+ use_rms_norm=False,
24
+ causal=False,
25
+ type_vocab_size=2,
26
+ dense_seq_output=True,
27
+ pad_vocab_size_multiple=1,
28
+ tie_word_embeddings=True,
29
+ rotary_scaling_factor=None,
30
+ max_trained_positions=2048,
31
+ **kwargs,
32
+ ):
33
+ self.prenorm = prenorm
34
+ self.parallel_block = parallel_block
35
+ self.parallel_block_tied_norm = parallel_block_tied_norm
36
+ self.rotary_emb_fraction = rotary_emb_fraction
37
+ self.tie_word_embeddings = tie_word_embeddings
38
+ self.fused_dropout_add_ln = fused_dropout_add_ln
39
+ self.fused_bias_fc = fused_bias_fc
40
+ self.use_flash_attn = use_flash_attn
41
+ self.use_xentropy = use_xentropy
42
+ self.qkv_proj_bias = qkv_proj_bias
43
+ self.rotary_emb_base = rotary_emb_base
44
+ self.rotary_emb_scale_base = rotary_emb_scale_base
45
+ self.rotary_emb_interleaved = rotary_emb_interleaved
46
+ self.mlp_fc1_bias = mlp_fc1_bias
47
+ self.mlp_fc2_bias = mlp_fc2_bias
48
+ self.use_rms_norm = use_rms_norm
49
+ self.causal = causal
50
+ self.type_vocab_size = type_vocab_size
51
+ self.dense_seq_output = dense_seq_output
52
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
53
+ self.rotary_scaling_factor = rotary_scaling_factor
54
+ self.max_trained_positions = max_trained_positions
55
+
56
+ super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:69fe41349d5efc8669c5d8ac9e0fe86fec944f8f2886d10641b6ab278c7f634b
3
+ size 546938168
modeling_hf_nomic_bert.py ADDED
@@ -0,0 +1,1234 @@
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ import logging
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+ import os
10
+ import re
11
+ from collections import OrderedDict
12
+ from functools import partial
13
+ from typing import List, Optional, Tuple, Union
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from einops import rearrange, repeat
19
+ from safetensors.torch import load_file as safe_load_file
20
+ from transformers import GPT2Config, PreTrainedModel
21
+ from transformers.models.bert.modeling_bert import (
22
+ BaseModelOutputWithPoolingAndCrossAttentions,
23
+ MaskedLMOutput,
24
+ SequenceClassifierOutput,
25
+ )
26
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
27
+ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
28
+
29
+ from .configuration_hf_nomic_bert import NomicBertConfig
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ # adapted from flash attention, added safe serialization option for hf models
35
+ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
36
+ # If not fp32, then we don't want to load directly to the GPU
37
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
38
+ is_sharded = False
39
+ load_safe = False
40
+ resolved_archive_file = None
41
+
42
+ weights_path = os.path.join(model_name, WEIGHTS_NAME)
43
+ weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
44
+ safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
45
+ safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
46
+
47
+ if os.path.isfile(weights_path):
48
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
49
+ elif os.path.isfile(weights_index_path):
50
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
51
+ is_sharded = True
52
+ elif os.path.isfile(safe_weights_path):
53
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
54
+ load_safe = True
55
+ elif os.path.isfile(safe_weights_index_path):
56
+ resolved_archive_file = cached_file(
57
+ model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
58
+ )
59
+ is_sharded = True
60
+ load_safe = True
61
+ else: # Try loading from HF hub instead of from local files
62
+ resolved_archive_file = None
63
+ for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
64
+ resolved_archive_file = cached_file(
65
+ model_name, weight_name, _raise_exceptions_for_missing_entries=False
66
+ )
67
+ if resolved_archive_file is not None:
68
+ if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
69
+ load_safe = True
70
+ if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
71
+ is_sharded = True
72
+ break
73
+
74
+ if resolved_archive_file is None:
75
+ raise EnvironmentError(f"Model name {model_name} was not found.")
76
+
77
+ if load_safe:
78
+ loader = partial(safe_load_file, device=mapped_device)
79
+ else:
80
+ loader = partial(torch.load, map_location=mapped_device)
81
+
82
+ if is_sharded:
83
+ # resolved_archive_file becomes a list of files that point to the different
84
+ # checkpoint shards in this case.
85
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
86
+ state_dict = {}
87
+ for sharded_file in resolved_archive_file:
88
+ state_dict.update(loader(sharded_file))
89
+ else:
90
+ state_dict = loader(resolved_archive_file)
91
+ # Convert dtype before moving to GPU to save memory
92
+ if dtype is not None:
93
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
94
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
95
+ return state_dict
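The local-file branch of `state_dict_from_pretrained` checks four checkpoint file names in a fixed order and sets `load_safe` and `is_sharded` accordingly. The sketch below reproduces only that decision on a set of file names, without touching the filesystem or the Hub. The helper `pick_local_checkpoint` and the `LOCAL_PRIORITY` table are illustrative assumptions; the four file-name constants are written out with the values they have in `transformers.utils`.

```python
from typing import Optional, Set, Tuple

# File names as defined in transformers.utils
WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
SAFE_WEIGHTS_NAME = "model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"

# (file name, load_safe, is_sharded), in the order the local checks run
LOCAL_PRIORITY = [
    (WEIGHTS_NAME, False, False),
    (WEIGHTS_INDEX_NAME, False, True),
    (SAFE_WEIGHTS_NAME, True, False),
    (SAFE_WEIGHTS_INDEX_NAME, True, True),
]


def pick_local_checkpoint(files: Set[str]) -> Optional[Tuple[str, bool, bool]]:
    """Return (file name, load_safe, is_sharded) for the first match, else None."""
    for name, load_safe, is_sharded in LOCAL_PRIORITY:
        if name in files:
            return name, load_safe, is_sharded
    return None


print(pick_local_checkpoint({"config.json", "model.safetensors"}))
```

Note that a directory containing both `pytorch_model.bin` and `model.safetensors` resolves to the `.bin` file under this ordering, since the PyTorch pickle is checked first.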
96
+
97
+
98
+ def filter_shapes(state_dict, model):
99
+ """
100
+ Filters the state dict to match the current model shape.
101
+ """
102
+ filtered_state_dict = {}
103
+ for key, value in state_dict.items():
104
+ if key in model.state_dict():
105
+ if value.shape == model.state_dict()[key].shape:
106
+ filtered_state_dict[key] = value
107
+ return filtered_state_dict
108
+
109
+
110
+ def remap_bert_state_dict(
111
+ state_dict,
112
+ config,
113
+ remove_bert=False,
114
+ remove_cls_weights=False,
115
+ add_pooling_layer=False,
116
+ ):
117
+ """
118
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
119
+ """
120
+
121
+ def add_bert_prefix(key):
122
+ # prepend bert. to the key
123
+ if key.startswith("bert.") or key.startswith("cls."):
124
+ return key
125
+ return f"bert.{key}"
126
+
127
+ state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
128
+
129
+ # LayerNorm
130
+ def key_mapping_ln_gamma_beta(key):
131
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
132
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
133
+ return key
134
+
135
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
136
+
137
+ # Layers
138
+ def key_mapping_layers(key):
139
+ return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
140
+
141
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
142
+
143
+ # LayerNorm
144
+ def key_mapping_ln(key):
145
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
146
+ key = re.sub(
147
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
148
+ r"bert.encoder.layers.\1.norm1.\2",
149
+ key,
150
+ )
151
+ key = re.sub(
152
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
153
+ r"bert.encoder.layers.\1.norm2.\2",
154
+ key,
155
+ )
156
+ key = re.sub(
157
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
158
+ r"cls.predictions.transform.layer_norm.\1",
159
+ key,
160
+ )
161
+ return key
162
+
163
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
164
+
165
+ # MLP
166
+ def key_mapping_mlp(key):
167
+ key = re.sub(
168
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
169
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
170
+ key,
171
+ )
172
+ key = re.sub(
173
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
174
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
175
+ key,
176
+ )
177
+ return key
178
+
179
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
180
+
181
+ # Attention
182
+ last_layer_subset = getattr(config, "last_layer_subset", False)
183
+ for d in range(config.num_hidden_layers):
184
+ if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
185
+ continue
186
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
187
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
188
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
189
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
190
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
191
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
192
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
193
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
194
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
195
+ else:
196
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
197
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
198
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
199
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
200
+
201
+ def key_mapping_attn(key):
202
+ return re.sub(
203
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
204
+ r"bert.encoder.layers.\1.attn.out_proj.\2",
205
+ key,
206
+ )
207
+
208
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
209
+
210
+ def key_mapping_decoder_bias(key):
211
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
212
+
213
+ # remove NSP (next-sentence prediction) weights, which we don't use
214
+ state_dict.pop("cls.seq_relationship.weight", None)
215
+ state_dict.pop("cls.seq_relationship.bias", None)
216
+ state_dict.pop("bert.embeddings.position_ids", None)
217
+
218
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
219
+
220
+ if remove_cls_weights:
221
+ cls_weights = [
222
+ "cls.predictions.decoder.bias",
223
+ "cls.predictions.transform.dense.weight",
224
+ "cls.predictions.transform.dense.bias",
225
+ "cls.predictions.transform.layer_norm.weight",
226
+ "cls.predictions.transform.layer_norm.bias",
227
+ "cls.predictions.decoder.weight",
228
+ ]
229
+ for weight in cls_weights:
230
+ state_dict.pop(weight, None)
231
+
232
+ # Word embedding
233
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
234
+ if pad_vocab_size_multiple > 1:
235
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
236
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
237
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
238
+ )
239
+ if not remove_cls_weights:
240
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
241
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
242
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
243
+ )
244
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
245
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
246
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
247
+ if "cls.predictions.decoder.bias" in state_dict:
248
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
249
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
250
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
251
+ )
252
+
253
+ if add_pooling_layer is False:
254
+ pooler_weights = [
255
+ "bert.pooler.dense.weight",
256
+ "bert.pooler.dense.bias",
257
+ ]
258
+ for key in pooler_weights:
259
+ state_dict.pop(key, None)
260
+
261
+ if remove_bert:
262
+
263
+ def remove_bert_prefix(key):
264
+ key = re.sub(r"^bert.", "", key)
265
+ return key
266
+
267
+ state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
268
+
269
+ return state_dict
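The key renaming in `remap_bert_state_dict` can be tried in isolation. The sketch below applies a subset of the function's regular expressions to plain strings, with no tensors involved. The helper name `remap_key` and the sample keys are illustrative assumptions, not part of this commit; the QKV fusion, vocabulary padding, and prefix handling are deliberately left out.

```python
import re


def remap_key(key: str) -> str:
    """Apply a subset of the renaming rules to a single checkpoint key."""
    # Legacy LayerNorm parameter names
    key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
    key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
    # encoder.layer.N -> encoder.layers.N
    key = re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
    # LayerNorm after attention -> norm1, LayerNorm after the MLP -> norm2
    key = re.sub(
        r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
        r"bert.encoder.layers.\1.norm1.\2",
        key,
    )
    key = re.sub(
        r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
        r"bert.encoder.layers.\1.norm2.\2",
        key,
    )
    # MLP projections: intermediate.dense -> mlp.fc1, output.dense -> mlp.fc2
    key = re.sub(
        r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
        r"bert.encoder.layers.\1.mlp.fc1.\2",
        key,
    )
    key = re.sub(
        r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
        r"bert.encoder.layers.\1.mlp.fc2.\2",
        key,
    )
    return key


for sample in [
    "bert.encoder.layer.3.attention.output.LayerNorm.gamma",
    "bert.encoder.layer.0.intermediate.dense.bias",
    "bert.encoder.layer.11.output.dense.weight",
]:
    print(sample, "->", remap_key(sample))
```

Keeping the rules in the same order as the original matters: the `layer.` to `layers.` rewrite has to run before the per-layer patterns, which all match on `layers.`.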
270
+
271
+
272
+ class NomicBertPreTrainedModel(PreTrainedModel):
273
+ """An abstract class to handle weights initialization and
274
+ a simple interface for downloading and loading pretrained models.
275
+ """
276
+
277
+ config_class = NomicBertConfig
278
+ base_model_prefix = "model"
279
+ supports_gradient_checkpointing = True
280
+ _no_split_modules = ["Block"]
281
+ _skip_keys_device_placement = "past_key_values"
282
+
283
+ def __init__(self, config, *inputs, **kwargs):
284
+ super().__init__(config)
285
+ if not isinstance(config, GPT2Config):
286
+ raise ValueError(
287
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
288
+ "To create a model from a Google pretrained model use "
289
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
290
+ self.__class__.__name__, self.__class__.__name__
291
+ )
292
+ )
293
+ self.config = config
294
+
295
+ @classmethod
296
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
297
+ """
298
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
299
+ Download and cache the pre-trained model file if needed.
300
+
301
+ Params:
302
+ model_name: either:
303
+ - a path or url to a pretrained model archive containing:
304
+ . `bert_config.json` a configuration file for the model
305
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPreTraining instance
306
+ - a path or url to a pretrained model archive containing:
307
+ . `bert_config.json` a configuration file for the model
308
+ . `model.chkpt` a TensorFlow checkpoint
309
+ *inputs, **kwargs: additional input for the specific NomicBert class
310
+ (ex: num_labels for NomicBertForSequenceClassification)
311
+ """
312
+ # Instantiate model.
313
+ if config is None:
314
+ config = cls.config_class.from_pretrained(model_name)
315
+ remove_cls = cls != NomicBertForPreTraining
316
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
317
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
318
+ num_labels = kwargs.pop("num_labels", None)
319
+ rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
320
+ strict = kwargs.pop("strict", True)
321
+ if rotary_scaling_factor:
322
+ config.rotary_scaling_factor = rotary_scaling_factor
323
+
324
+ if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
325
+ config.n_positions = 2048
326
+ if num_labels:
327
+ config.num_labels = num_labels
328
+
329
+ if "add_pooling_layer" in kwargs:
330
+ model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
331
+ else:
332
+ if cls == NomicBertModel:
333
+ model = cls(config, *inputs, add_pooling_layer=False)
334
+ else:
335
+ model = cls(config, *inputs)
336
+ # TODO: fix this
337
+ # Assuming we know what we're doing when loading from disk
338
+ # Prob a bad assumption but i'm tired and want to train this asap
339
+ if os.path.exists(model_name):
340
+ model_path = f"{model_name}/pytorch_model.bin"
341
+ if os.path.exists(model_path):
342
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
343
+ else:
344
+ model_path = f"{model_name}/model.safetensors"
345
+ if not os.path.exists(model_path):
346
+ raise ValueError(f"Model path {model_path} not found")
347
+ state_dict = safe_load_file(model_path)
348
+
349
+ if ignore_mismatched_shapes:
350
+ state_dict = filter_shapes(state_dict, model)
351
+ load_return = model.load_state_dict(state_dict, strict=False)
352
+ else:
353
+ # TODO: can probably check config class and see if we need to remap from a bert model
354
+ state_dict = state_dict_from_pretrained(model_name)
355
+ state_dict = remap_bert_state_dict(
356
+ state_dict,
357
+ config,
358
+ remove_bert=remove_bert_prefix,
359
+ remove_cls_weights=remove_cls,
360
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
361
+ )
362
+ if ignore_mismatched_shapes:
363
+ state_dict = filter_shapes(state_dict, model)
364
+
365
+ load_return = model.load_state_dict(state_dict, strict=strict)
366
+ logger.warning(load_return)
367
+ return model
368
+
369
+ def _set_gradient_checkpointing(self, module, value=False):
370
+ if isinstance(module, NomicBertEncoder):
371
+ module.gradient_checkpointing = value
372
+
373
+
374
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
375
+ def _init_weights(module, initializer_range=0.02):
376
+ if isinstance(module, nn.Linear):
377
+ nn.init.normal_(module.weight, std=initializer_range)
378
+ if module.bias is not None:
379
+ nn.init.zeros_(module.bias)
380
+ elif isinstance(module, nn.Embedding):
381
+ nn.init.normal_(module.weight, std=initializer_range)
382
+ if module.padding_idx is not None:
383
+ nn.init.zeros_(module.weight[module.padding_idx])
384
+
385
+
386
+ class NomicBertEmbeddings(nn.Module):
387
+ def __init__(self, config):
388
+ """
389
+ If max_position_embeddings <= 0, there are no position embeddings
389
+ If type_vocab_size <= 0, there are no token type embeddings
391
+ """
392
+ super().__init__()
393
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
394
+ self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
395
+ self.type_vocab_size = config.type_vocab_size
396
+ if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
397
+ self.position_embeddings = nn.Embedding(
398
+ config.max_position_embeddings,
399
+ config.hidden_size,
400
+ )
401
+ if self.type_vocab_size > 0:
402
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
403
+
404
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
405
+ """
406
+ input_ids: (batch, seqlen)
407
+ position_ids: (batch, seqlen)
408
+ token_type_ids: (batch, seqlen)
409
+ """
410
+ batch_size, seqlen = input_ids.shape
411
+ embeddings = self.word_embeddings(input_ids)
412
+
413
+ if self.type_vocab_size > 0:
414
+ if token_type_ids is None:
415
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
416
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
417
+ embeddings = embeddings + token_type_embeddings
418
+
419
+ if self.max_position_embeddings > 0:
420
+ if position_ids is None:
421
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
422
+ position_embeddings = self.position_embeddings(position_ids)
423
+ embeddings = embeddings + position_embeddings
424
+ return embeddings
425
+
426
+
427
+ class NomicBertMLP(nn.Module):
428
+ def __init__(
429
+ self,
430
+ in_features,
431
+ hidden_features=None,
432
+ out_features=None,
433
+ activation=F.gelu,
434
+ bias1=True,
435
+ bias2=True,
436
+ return_residual=False,
437
+ fused_bias_fc=False,
438
+ ):
439
+ super().__init__()
440
+ out_features = out_features if out_features is not None else in_features
441
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
442
+ self.return_residual = return_residual
443
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
444
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
445
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
446
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
447
+
448
+ def forward(self, x):
449
+ y = self.fc1(x)
450
+ y = self.activation(y)
451
+ y = self.fc2(y)
452
+ return y if not self.return_residual else (y, x)
453
+
454
+
455
+ class NomciBertGatedMLP(nn.Module):
456
+ def __init__(
457
+ self,
458
+ in_features,
459
+ hidden_features=None,
460
+ out_features=None,
461
+ activation=F.sigmoid,
462
+ bias1=True,
463
+ bias2=True,
464
+ multiple_of=256,
465
+ return_residual=False,
466
+ fused_bias_fc=True,
467
+ device=None,
468
+ dtype=None,
469
+ ):
470
+ super().__init__()
471
+ out_features = out_features if out_features is not None else in_features
472
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
473
+ hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
474
+ self.return_residual = return_residual
475
+
476
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
477
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
478
+ self.activation = activation
479
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
480
+
481
+ def forward(self, x):
482
+ y = self.fc11(x)
483
+ gate = self.fc12(x)
484
+ if self.activation == F.sigmoid: # Special case for GLU
485
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
486
+ else:
487
+ y = y * self.activation(gate)
488
+ y = self.fc2(y)
489
+ return y if not self.return_residual else (y, x)
490
+
491
+
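The hidden width of the gated MLP above is not passed through unchanged: it defaults to about 8/3 of the input width and is then rounded up to a multiple of `multiple_of`. The following is a minimal plain-Python sketch of just that arithmetic; the helper name `gated_hidden_size` is hypothetical and does not exist in the file.

```python
def gated_hidden_size(in_features, hidden_features=None, multiple_of=256):
    """Mirror the sizing arithmetic of the gated MLP constructor (hypothetical helper)."""
    if hidden_features is None:
        # Default: roughly 8/3 of the input width, truncated to an integer.
        hidden_features = int(8 * in_features / 3)
    # Round up to the next multiple of `multiple_of`.
    return (hidden_features + multiple_of - 1) // multiple_of * multiple_of


print(gated_hidden_size(768))                        # default width for a 768-d model
print(gated_hidden_size(1000))                       # a width that needs rounding up
print(gated_hidden_size(768, hidden_features=3072))  # explicit width, already aligned
```

Because of the rounding, an explicit `n_inner` that is not a multiple of 256 produces a wider layer than the configured value.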
492
+ def rotate_half(x, interleaved=False):
493
+ if not interleaved:
494
+ x1, x2 = x.chunk(2, dim=-1)
495
+ return torch.cat((-x2, x1), dim=-1)
496
+ else:
497
+ x1, x2 = x[..., ::2], x[..., 1::2]
498
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
499
+
500
+
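To make the two layouts in `rotate_half` concrete, here is a list-based sketch for a single vector. It is an illustration only: it uses plain Python lists instead of tensors and does not depend on `torch` or `einops`.

```python
def rotate_half(x, interleaved=False):
    """List-based analogue of the tensor function above, for one vector."""
    if not interleaved:
        # GPT-NeoX style: pair element i with element i + len(x) // 2.
        half = len(x) // 2
        x1, x2 = x[:half], x[half:]
        return [-v for v in x2] + x1
    # GPT-J style: pair even-indexed elements with the following odd-indexed ones.
    x1, x2 = x[0::2], x[1::2]
    out = []
    for a, b in zip(x1, x2):
        out.extend([-b, a])
    return out


print(rotate_half([1, 2, 3, 4]))
print(rotate_half([1, 2, 3, 4], interleaved=True))
```

In both layouts each pair `(a, b)` becomes `(-b, a)`, which is a 90 degree rotation of that pair; only the choice of which elements form a pair differs.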
501
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
502
+ """
503
+ x: (batch_size, seqlen, nheads, headdim)
504
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
505
+ """
506
+ ro_dim = cos.shape[-1] * 2
507
+ assert ro_dim <= x.shape[-1]
508
+ cos, sin = (
509
+ cos[offset : offset + x.shape[1]],
510
+ sin[offset : offset + x.shape[1]],
511
+ )
512
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
513
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
514
+ return torch.cat(
515
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
516
+ dim=-1,
517
+ )
518
+
519
+
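The point of `apply_rotary_emb` is that attention scores between rotated queries and keys depend only on the distance between their positions. The sketch below checks this for one vector pair in plain Python. It assumes the non-interleaved layout and a rotary dimension equal to the full vector length; the helper names are illustrative, not part of the file.

```python
import math


def inv_freq(dim, base=10000.0):
    # One frequency per pair of dimensions, as in _compute_inv_freq.
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


def rotate_half(x):
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(x, pos, base=10000.0):
    angles = [pos * f for f in inv_freq(len(x), base)]
    # "(2 d)" layout: the cos / sin vector is tiled twice along the last axis.
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rotate_half(x), sin)]


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


q = [0.5, -1.0, 2.0, 0.25]
k = [1.5, 0.3, -0.7, 1.0]
# Same relative offset (4 positions), different absolute positions.
score_near = dot(apply_rotary(q, 3), apply_rotary(k, 7))
score_far = dot(apply_rotary(q, 103), apply_rotary(k, 107))
print(score_near, score_far)
```

The two scores agree up to floating-point error, and the rotation leaves the norm of each vector unchanged.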
520
+ class NomicBertRotaryEmbedding(nn.Module):
521
+ def __init__(
522
+ self,
523
+ dim: int,
524
+ base=10000.0,
525
+ interleaved=False,
526
+ scale_base=None,
527
+ pos_idx_in_fp32=True,
528
+ device=None,
529
+ ):
530
+ """
531
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
532
+ of 1st half and 2nd half (GPT-NeoX style).
533
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
534
+ otherwise they might be in lower precision.
535
+ This option was added because previously (before 2023-07-02), when we construct
536
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
537
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
538
+ self.inv_freq would be bf16, and the position indices are also in bf16.
539
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 1992.0), the
540
+ embeddings for some positions will coincide.
541
+ To maintain compatibility with models previously trained in pure bf16,
542
+ we add this option.
543
+ """
544
+ super().__init__()
545
+ self.dim = dim
546
+ self.base = float(base)
547
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
548
+ # Generate and save the inverse frequency buffer (non trainable)
549
+ inv_freq = self._compute_inv_freq(device)
550
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
551
+ self.interleaved = interleaved
552
+ self.scale_base = scale_base
553
+ scale = (
554
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
555
+ if scale_base is not None
556
+ else None
557
+ )
558
+ self.register_buffer("scale", scale, persistent=False)
559
+
560
+ self._seq_len_cached = 0
561
+ self._cos_cached = None
562
+ self._sin_cached = None
563
+ self._cos_k_cached = None
564
+ self._sin_k_cached = None
565
+
566
+ def _compute_inv_freq(self, device=None):
567
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
568
+
569
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
570
+ # Reset the tables if the sequence length has changed,
571
+ # if we're on a new device (possibly due to tracing for instance),
572
+ # or if we're switching from inference mode to training
573
+ if (
574
+ seqlen > self._seq_len_cached
575
+ or self._cos_cached is None
576
+ or self._cos_cached.device != device
577
+ or self._cos_cached.dtype != dtype
578
+ or (self.training and self._cos_cached.is_inference())
579
+ ):
580
+ self._seq_len_cached = seqlen
581
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
582
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
583
+ # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
584
+ if self.pos_idx_in_fp32:
585
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
586
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
587
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
588
+ # cos & sin output to change significantly.
589
+ # We want to recompute self.inv_freq if it was not loaded in fp32
590
+ if self.inv_freq.dtype != torch.float32:
591
+ inv_freq = self._compute_inv_freq(device=device)
592
+ else:
593
+ inv_freq = self.inv_freq
594
+ else:
595
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
596
+ inv_freq = self.inv_freq
597
+ # Don't do einsum, it converts fp32 to fp16 under AMP
598
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
599
+ freqs = torch.outer(t, inv_freq)
600
+ self._cos_cached = torch.cos(freqs).to(dtype)
601
+ self._sin_cached = torch.sin(freqs).to(dtype)
602
+
603
+ def forward(
604
+ self,
605
+ qkv: torch.Tensor,
606
+ kv: Optional[torch.Tensor] = None,
607
+ seqlen_offset: Union[int, torch.Tensor] = 0,
608
+ max_seqlen: Optional[int] = None,
609
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
610
+ """
611
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
612
+ else it's just q of shape (batch, seqlen, nheads, headdim)
613
+ kv: (batch, seqlen, 2, nheads, headdim)
614
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
615
+ Most commonly used in inference when we have KV cache.
616
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
617
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
618
+ Apply rotary embedding *inplace* to qkv and / or kv.
619
+ """
620
+ seqlen = qkv.shape[1]
621
+ if seqlen > self._seq_len_cached:
622
+ self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
623
+ elif max_seqlen is not None:
624
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
625
+ elif isinstance(seqlen_offset, int):
626
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
627
+
628
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
629
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
630
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
631
+
632
+
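The `pos_idx_in_fp32` option exists because bf16 cannot represent every integer position once positions get large. The sketch below simulates bf16 rounding with the standard library only, by keeping the top 16 bits of the float32 bit pattern with round-to-nearest-even. The helper `to_bf16` is an assumption made for illustration and is not how PyTorch is invoked in the file.

```python
import struct


def to_bf16(x):
    """Round a Python float to the nearest bfloat16 value (round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern; add a bias so that
    # discarding the low 16 bits rounds to nearest, with ties going to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


positions = range(1984, 2016)
distinct = sorted({to_bf16(float(p)) for p in positions})
print(len(list(positions)), "positions map to", len(distinct), "bf16 values:", distinct)
```

Between 1024 and 2048 the representable bf16 values are 8 apart, so neighbouring positions share one value and would receive identical cos / sin entries.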
633
+ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
634
+ def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
635
+ super().__init__(**kwargs)
636
+ self.rotary_scaling_factor = rotary_scaling_factor
637
+ self.max_position_embeddings = max_position_embeddings
638
+
639
+ def _compute_inv_freq(self, base=None, device=None):
640
+ if base is None:
641
+ base = self.base
642
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
643
+
644
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
645
+ # If the sequence is longer than the trained context, rescale the rotary base
646
+ # (dynamic NTK scaling) and recompute the inverse frequencies.
647
+ # The cos / sin tables themselves are rebuilt further down when needed.
648
+ if seqlen > self.max_position_embeddings:
649
+ base = self.base * (
650
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
651
+ ) ** (self.dim / (self.dim - 2))
652
+ inv_freq = self._compute_inv_freq(base=base, device=device)
653
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
654
+
655
+ if (
656
+ seqlen > self._seq_len_cached
657
+ or self._cos_cached is None
658
+ or self._cos_cached.device != device
659
+ or self._cos_cached.dtype != dtype
660
+ or (self.training and self._cos_cached.is_inference())
661
+ ):
662
+ self._seq_len_cached = seqlen
663
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
664
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
665
+ # However, for compatibility reasons, we add an option to use the dtype of self.inv_freq.
666
+ if self.pos_idx_in_fp32:
667
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
668
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
669
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
670
+ # cos & sin output to change significantly.
671
+ # We want to recompute self.inv_freq if it was not loaded in fp32
672
+ if self.inv_freq.dtype != torch.float32:
673
+ if seqlen > self.max_position_embeddings:
674
+ base = self.base * (
675
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
676
+ ) ** (self.dim / (self.dim - 2))
677
+ else:
678
+ base = self.base
679
+ inv_freq = self._compute_inv_freq(device=device, base=base)
680
+ else:
681
+ inv_freq = self.inv_freq
682
+ else:
683
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
684
+ inv_freq = self.inv_freq
685
+ # Don't do einsum, it converts fp32 to fp16 under AMP
686
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
687
+ freqs = torch.outer(t, inv_freq)
688
+ if self.scale is None:
689
+ self._cos_cached = torch.cos(freqs).to(dtype)
690
+ self._sin_cached = torch.sin(freqs).to(dtype)
691
+ else:
692
+ power = (
693
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
694
+ ) / self.scale_base
695
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
696
+ # We want the multiplication by scale to happen in fp32
697
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
698
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
699
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
700
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
701
+
702
+
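The dynamic NTK class above changes only one quantity, the rotary base, once the sequence is longer than the trained context. A minimal plain-Python sketch of that formula follows; the function names are hypothetical, and the numbers in the usage lines are made-up inputs rather than values from this model's config.

```python
def ntk_scaled_base(base, seqlen, max_position_embeddings, rotary_scaling_factor, dim):
    """Rotary base used for a given sequence length (sketch of the formula above)."""
    if seqlen <= max_position_embeddings:
        return base  # within the trained context: leave the base untouched
    factor = (rotary_scaling_factor * seqlen / max_position_embeddings) - (rotary_scaling_factor - 1)
    return base * factor ** (dim / (dim - 2))


def inv_freq(base, dim):
    return [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]


short = ntk_scaled_base(10000.0, 2048, 2048, 2.0, 64)
long = ntk_scaled_base(10000.0, 4096, 2048, 2.0, 64)
print(short, long)
print(inv_freq(short, 64)[-1], inv_freq(long, 64)[-1])
```

A larger base lowers the slowest rotation frequencies, which stretches the longest wavelengths to cover the extra positions while leaving the fastest frequency (the first entry, always 1.0) unchanged.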
703
+ class NomicBertAttention(nn.Module):
704
+ """Multi-head self-attention and cross-attention"""
705
+
706
+ def __init__(
707
+ self,
708
+ config,
709
+ ) -> None:
710
+ """
711
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
712
+ return_residual: whether to return the input x along with the output. This is for
713
+ performance reason: for post-norm architecture, returning the input allows us
714
+ to fuse the backward of nn.Linear with the residual connection.
715
+ """
716
+ super().__init__()
717
+ self.embed_dim = config.n_embd
718
+ self.use_flash_attn = config.use_flash_attn
719
+ self.fused_bias_fc = config.fused_bias_fc
720
+
721
+ self.num_heads = config.n_head
722
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
723
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
724
+ self.head_dim = self.embed_dim // self.num_heads
725
+ # we don't really support mqa / gqa for now
726
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
727
+
728
+ self.register_buffer(
729
+ "norm_factor",
730
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
731
+ persistent=False,
732
+ )
733
+
734
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
735
+ if self.rotary_emb_dim > 0:
736
+ if getattr(config, "rotary_scaling_factor", None):
737
+ self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
738
+ dim=self.rotary_emb_dim,
739
+ base=config.rotary_emb_base,
740
+ scale_base=config.rotary_emb_scale_base,
741
+ interleaved=config.rotary_emb_interleaved,
742
+ rotary_scaling_factor=config.rotary_scaling_factor,
743
+ max_position_embeddings=config.max_trained_positions,
744
+ )
745
+ else:
746
+ self.rotary_emb = NomicBertRotaryEmbedding(
747
+ dim=self.rotary_emb_dim,
748
+ base=config.rotary_emb_base,
749
+ scale_base=config.rotary_emb_scale_base,
750
+ interleaved=config.rotary_emb_interleaved,
751
+ )
752
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
753
+ # uses the head dimension instead of the sequence dimension
754
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
755
+
756
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
757
+
758
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
759
+ self.causal = config.causal
760
+ self.drop = nn.Dropout(config.attn_pdrop)
761
+
762
+ def forward(
763
+ self,
764
+ hidden_states: torch.Tensor,
765
+ attention_mask: Optional[torch.Tensor] = None,
766
+ position_ids: Optional[torch.LongTensor] = None,
767
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
768
+ output_attentions: bool = False,
769
+ use_cache: bool = False,
770
+ is_padded_inputs: Optional[bool] = True,
771
+ cu_seqlens: Optional[torch.Tensor] = None,
772
+ max_seq_len: Optional[int] = None,
773
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
774
+
775
+ has_layer_past = past_key_value is not None
776
+
777
+ if has_layer_past:
778
+ past_len = past_key_value[1]
778
+ past_key_value = past_key_value[0]
780
+ else:
781
+ past_len = 0
782
+
783
+ qkv = self.Wqkv(hidden_states)
784
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
785
+
786
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
787
+
788
+ if self.rotary_emb_dim > 0:
789
+ if self.rotary_head_dim:
790
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
791
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
792
+
793
+ if self.rotary_head_dim:
794
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
795
+
796
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
797
+
798
+ query = query.permute(0, 2, 1, 3)
799
+ key = key.permute(0, 2, 1, 3)
800
+ value = value.permute(0, 2, 1, 3)
801
+
802
+ attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
803
+ if attention_mask is not None:
804
+ attention_scores = attention_scores + attention_mask
805
+
806
+ attentions_probs = F.softmax(attention_scores, dim=-1)
807
+ attentions_probs = self.drop(attentions_probs)
808
+
809
+ attn_output = torch.matmul(attentions_probs, value)
810
+ attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
811
+
812
+ attn_output = self.out_proj(attn_output)
813
+
814
+ return attn_output
815
+
816
+
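The core of the forward pass above is scaled dot-product attention with an additive mask. Here is a single-head sketch in plain Python with dropout and the rotary step omitted; it is meant to show the arithmetic, not to replace the batched tensor code, and the function names are illustrative.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, mask=None):
    """q, k, v: lists of seqlen vectors; mask: seqlen x seqlen additive values."""
    norm = math.sqrt(len(q[0]))  # plays the role of self.norm_factor
    outputs, all_probs = [], []
    for i, qi in enumerate(q):
        scores = [dot(qi, kj) / norm for kj in k]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[i])]
        probs = softmax(scores)
        all_probs.append(probs)
        outputs.append([sum(p * vj[d] for p, vj in zip(probs, v)) for d in range(len(v[0]))])
    return outputs, all_probs


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
# A large negative value hides the third (padding) token from every query.
mask = [[0.0, 0.0, -1e9] for _ in tokens]
out, probs = attention(tokens, tokens, tokens, mask)
print(probs[0])
```

Adding a large negative number before the softmax drives the masked key's probability to zero, which is why the mask is added to the scores rather than multiplied.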
817
+ class NomicBertBlock(NomicBertPreTrainedModel):
818
+ def __init__(
819
+ self,
820
+ config,
821
+ ):
822
+ super().__init__(config=config)
823
+ self.prenorm = config.prenorm
824
+ self.fused_dropout_add_ln = config.fused_dropout_add_ln
825
+
826
+ self.attn = NomicBertAttention(config)
827
+ activation = (
828
+ F.sigmoid
829
+ if config.activation_function == "glu"
830
+ else (F.silu if config.activation_function == "swiglu" else F.gelu)
831
+ )
832
+ if config.activation_function in ["glu", "swiglu", "geglu"]:
833
+ self.mlp = NomciBertGatedMLP(
834
+ config.n_embd,
835
+ hidden_features=config.n_inner,
836
+ bias1=config.mlp_fc1_bias,
837
+ bias2=config.mlp_fc2_bias,
838
+ activation=activation,
839
+ fused_bias_fc=config.fused_bias_fc,
840
+ )
841
+ else:
842
+ self.mlp = NomicBertMLP(
843
+ config.n_embd,
844
+ hidden_features=config.n_inner,
845
+ bias1=config.mlp_fc1_bias,
846
+ bias2=config.mlp_fc2_bias,
847
+ activation=activation,
848
+ fused_bias_fc=config.fused_bias_fc,
849
+ )
850
+
851
+ self.dropout1 = nn.Dropout(config.resid_pdrop)
852
+ self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
853
+ self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
854
+ self.dropout2 = nn.Dropout(config.resid_pdrop)
855
+
856
+ def forward(
857
+ self,
858
+ hidden_states: torch.Tensor,
859
+ hidden_states2: torch.Tensor,
860
+ residual: Optional[torch.Tensor] = None,
861
+ attention_mask: Optional[torch.Tensor] = None,
862
+ position_ids: Optional[torch.LongTensor] = None,
863
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
864
+ is_padded_inputs: Optional[bool] = True,
865
+ output_attentions: Optional[bool] = False,
866
+ use_cache: Optional[bool] = False,
867
+ cu_seqlens: Optional[torch.Tensor] = None,
868
+ max_seq_len: Optional[int] = None,
869
+ ):
870
+ r"""Pass the input through the encoder layer.
871
+
872
+ Args:
873
+ hidden_states: the sequence to the encoder layer (required).
874
+ residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
875
+ mixer_subset: for cross-attention only. If not None, will take a subset of x
876
+ before applying the query projection. Useful for e.g., ViT where we only care
877
+ about the CLS token in the last layer.
878
+ """
879
+ if self.prenorm:
880
+ dropped = self.dropout1(hidden_states)
881
+ residual = (dropped + residual) if residual is not None else dropped
882
+ hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
883
+ hidden_states = self.attn(
884
+ hidden_states,
885
+ attention_mask=attention_mask,
886
+ is_padded_inputs=is_padded_inputs,
887
+ cu_seqlens=cu_seqlens,
888
+ max_seq_len=max_seq_len,
889
+ )
890
+
891
+ dropped = self.dropout2(hidden_states)
892
+ residual = (dropped + residual) if residual is not None else dropped
893
+ hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
894
+ hidden_states = self.mlp(hidden_states)
895
+
896
+ return hidden_states, None, residual
897
+ else:
898
+ assert residual is None
899
+ attn_outputs = self.attn(
900
+ hidden_states,
901
+ attention_mask=attention_mask,
902
+ is_padded_inputs=is_padded_inputs,
903
+ cu_seqlens=cu_seqlens,
904
+ max_seq_len=max_seq_len,
905
+ )
906
+ hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
907
+ mlp_out = self.mlp(hidden_states)
908
+
909
+ hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
910
+ return hidden_states, None, None
911
+
912
+
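The two branches of the block differ in where layer normalization sits relative to the residual addition. The sketch below reproduces that data flow on plain lists, with dropout treated as the identity (as in eval mode) and with stand-in callables for attention and the MLP. All names are illustrative, and the pre-norm helper returns a 2-tuple instead of the 3-tuple used in the file.

```python
import math


def layer_norm(x, eps=1e-12):
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def postnorm_block(x, attn, mlp):
    # Post-norm: normalize after each residual addition.
    h = layer_norm(add(attn(x), x))
    return layer_norm(add(mlp(h), h))


def prenorm_block(hidden, residual, attn, mlp):
    # Pre-norm: keep an un-normalized residual stream and normalize sub-layer inputs.
    residual = add(hidden, residual) if residual is not None else hidden
    hidden = attn(layer_norm(residual))
    residual = add(hidden, residual)
    hidden = mlp(layer_norm(residual))
    return hidden, residual


def toy_attn(x):
    return [2.0 * v for v in x]  # stand-in for self.attn


def toy_mlp(x):
    return [v + 1.0 for v in x]  # stand-in for self.mlp


x = [1.0, 2.0, 3.0, 4.0]
post = postnorm_block(x, toy_attn, toy_mlp)
pre_hidden, pre_residual = prenorm_block(x, None, toy_attn, toy_mlp)
print(post)
print(pre_hidden, pre_residual)
```

In the post-norm path the block output is always normalized, while in the pre-norm path the residual stream is passed on un-normalized to the next block.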
913
+ class NomicBertEncoder(nn.Module):
914
+ def __init__(self, config: GPT2Config):
915
+ super().__init__()
916
+ self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
917
+ self.gradient_checkpointing = False
918
+ self.config = config
919
+
920
+ def forward(
921
+ self,
922
+ hidden_states: torch.LongTensor = None,
923
+ attention_mask: Optional[torch.Tensor] = None,
924
+ position_ids: Optional[torch.LongTensor] = None,
925
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
926
+ inputs_embeds: Optional[torch.FloatTensor] = None,
927
+ use_cache: Optional[bool] = None,
928
+ output_attentions: Optional[bool] = None,
929
+ output_hidden_states: Optional[bool] = None,
930
+ return_dict: Optional[bool] = None,
931
+ is_padded_inputs: Optional[bool] = True,
932
+ ):
933
+ """If subset_mask is not None, we only want output for the subset of the sequence.
934
+ This means that we only compute the last layer output for these tokens.
935
+ subset_mask: (batch, seqlen), dtype=torch.bool
936
+ """
937
+ hidden_states2 = None
938
+ residual = None
939
+
940
+ for _, layer in enumerate(self.layers):
941
+ if self.gradient_checkpointing and self.training:
942
+
943
+ def create_custom_forward(module):
944
+ def custom_forward(*inputs):
945
+ # None for past_key_value
946
+ return module(*inputs)
947
+
948
+ return custom_forward
949
+
950
+ hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
951
+ create_custom_forward(layer),
952
+ hidden_states,
953
+ hidden_states2,
954
+ residual,
955
+ attention_mask,
956
+ None,
957
+ None,
958
+ is_padded_inputs,
959
+ # if you freeze ANY layers, you need `use_reentrant=False`
960
+ # https://github.com/huggingface/transformers/issues/21381
961
+ # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
962
+ use_reentrant=False,
963
+ )
964
+
965
+ else:
966
+ hidden_states, hidden_states2, residual = layer(
967
+ hidden_states,
968
+ hidden_states2,
969
+ residual,
970
+ attention_mask,
971
+ position_ids,
972
+ None,
973
+ is_padded_inputs,
974
+ output_attentions,
975
+ use_cache,
976
+ )
977
+ return hidden_states
978
+
979
+
980
+ class NomicBertPooler(nn.Module):
981
+ def __init__(self, config):
982
+ super().__init__()
983
+ self.dense = nn.Linear(config.n_embd, config.n_embd)
984
+ self.activation = nn.Tanh()
985
+
986
+ def forward(self, hidden_states, pool=True):
987
+ # We "pool" the model by simply taking the hidden state corresponding
988
+ # to the first token.
989
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states
990
+ pooled_output = self.dense(first_token_tensor)
991
+ pooled_output = self.activation(pooled_output)
992
+ return pooled_output
993
+
994
+
995
+ class NomicBertPredictionHeadTransform(nn.Module):
996
+ def __init__(self, config):
997
+ super().__init__()
998
+ self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
999
+ approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
1000
+ if config.activation_function == "swiglu":
1001
+ self.transform_act_fn = F.silu
1002
+ else:
1003
+ self.transform_act_fn = nn.GELU(approximate=approximate)
1004
+
1005
+ self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1006
+
1007
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1008
+ hidden_states = self.dense(hidden_states)
1009
+ hidden_states = self.transform_act_fn(hidden_states)
1010
+ hidden_states = self.layer_norm(hidden_states)
1011
+
1012
+ return hidden_states
1013
+
1014
+
1015
+ class NomicBertLMPredictionHead(nn.Module):
1016
+ def __init__(self, config):
1017
+ super().__init__()
1018
+
1019
+ self.transform = NomicBertPredictionHeadTransform(config)
1020
+
1021
+ self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
1022
+
1023
+ def forward(self, hidden_states):
1024
+ hidden_states = self.transform(hidden_states)
1025
+ hidden_states = self.decoder(hidden_states)
1026
+ return hidden_states
1027
+
1028
+
1029
+ class NomicBertPreTrainingHeads(nn.Module):
1030
+ def __init__(self, config):
1031
+ super().__init__()
1032
+ self.predictions = NomicBertLMPredictionHead(config)
1033
+
1034
+ def forward(self, sequence_output):
1035
+ prediction_scores = self.predictions(sequence_output)
1036
+ return prediction_scores
1037
+
1038
+
1039
+ class NomicBertModel(NomicBertPreTrainedModel):
1040
+ def __init__(self, config: GPT2Config, add_pooling_layer=True):
1041
+ super().__init__(config)
1042
+ self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
1043
+ if config.vocab_size % self.pad_vocab_size_multiple != 0:
1044
+ config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
1045
+
1046
+ assert config.activation_function in [
1047
+ "gelu",
1048
+ "gelu_new",
1049
+ "gelu_fast",
1050
+ "gelu_pytorch_tanh",
1051
+ "swiglu",
1052
+ "geglu",
1053
+ "glu",
1054
+ ]
1055
+
1056
+ self.embeddings = NomicBertEmbeddings(config)
1057
+ self.emb_drop = nn.Dropout(config.resid_pdrop)
1058
+ self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
1059
+ self.encoder = NomicBertEncoder(config)
1060
+ self.pooler = NomicBertPooler(config) if add_pooling_layer else None
1061
+
1062
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1063
+
1064
+ def forward(
1065
+ self,
1066
+ input_ids,
1067
+ attention_mask=None,
1068
+ position_ids=None,
1069
+ token_type_ids=None,
1070
+ return_dict=None,
1071
+ matryoshka_dim=None,
1072
+ ):
1073
+ if token_type_ids is None:
1074
+ token_type_ids = torch.zeros_like(input_ids)
1075
+ hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
1076
+ hidden_states = self.emb_ln(hidden_states)
1077
+ hidden_states = self.emb_drop(hidden_states)
1078
+
1079
+ attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
1080
+ sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
1081
+
1082
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1083
+
1084
+ if matryoshka_dim:
1085
+ sequence_output = sequence_output[:, :matryoshka_dim]
1086
+
1087
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1088
+ last_hidden_state=sequence_output,
1089
+ pooler_output=pooled_output,
1090
+ )
1091
+
1092
+
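`NomicBertModel.__init__` rounds `config.vocab_size` up to a multiple of `pad_vocab_size_multiple` before building the embedding table. A small sketch of that arithmetic follows; the helper name is hypothetical and the inputs are example values, not values read from this repository's config.

```python
def pad_vocab_size(vocab_size, multiple=1):
    """Round vocab_size up to the next multiple, as the constructor above does."""
    if vocab_size % multiple != 0:
        vocab_size += multiple - (vocab_size % multiple)
    return vocab_size


print(pad_vocab_size(30522, 64))  # example vocabulary size, padded up
print(pad_vocab_size(30528, 64))  # already a multiple, unchanged
print(pad_vocab_size(30522))      # default multiple of 1 never changes the size
```

The extra rows correspond to token ids the tokenizer never produces; they only make the embedding and decoder matrices a hardware-friendly size.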
1093
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
1094
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
1095
+
1096
+ def __init__(self, config: GPT2Config):
1097
+ super().__init__(config)
1098
+
1099
+ self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
1100
+ self.cls = NomicBertPreTrainingHeads(config)
1101
+ self.mlm_loss = nn.CrossEntropyLoss()
1102
+
1103
+ # Initialize weights and apply final processing
1104
+ self.apply(partial(_init_weights, initializer_range=config.initializer_range))
1105
+ self.tie_weights()
1106
+
1107
+ def tie_weights(self):
1108
+ self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
1109
+
1110
+ def forward(
1111
+ self,
1112
+ input_ids,
1113
+ position_ids=None,
1114
+ token_type_ids=None,
1115
+ attention_mask=None,
1116
+ labels=None,
1117
+ ):
1118
+ """
1119
+ If labels are provided, they must be -100 for masked out tokens (as specified in the attention
1120
+ mask).
1121
+ Outputs:
1122
+ if `labels` is not `None`:
1123
+ `loss` is the masked language modeling loss (cross-entropy over positions whose
1124
+ label is not -100).
1125
+ if `labels` is `None`:
1126
+ `loss` is `None`.
1127
+ In both cases `logits` holds the masked language modeling logits of shape
1128
+ [batch_size, sequence_length, vocab_size]; there is no next sentence prediction head.
1129
+
1130
+ """
1131
+ outputs = self.bert(
1132
+ input_ids,
1133
+ position_ids=position_ids,
1134
+ token_type_ids=token_type_ids,
1135
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1136
+ )
1137
+ sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
1138
+
1139
+ prediction_scores = self.cls(sequence_output)
1140
+
1141
+ total_loss = None
1142
+ if labels is not None:
1143
+ masked_lm_loss = self.mlm_loss(
1144
+ rearrange(prediction_scores, "... v -> (...) v"),
1145
+ rearrange(labels, "... -> (...)"),
1146
+ )
1147
+ total_loss = masked_lm_loss.float()
1148
+
1149
+ return MaskedLMOutput(
1150
+ loss=total_loss,
1151
+ logits=prediction_scores,
1152
+ hidden_states=outputs.hidden_states,
1153
+ attentions=None,
1154
+ )
1155
+
1156
+
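The pre-training head relies on `nn.CrossEntropyLoss`, whose default `ignore_index` is -100, so positions labelled -100 contribute nothing to the loss. The following plain-Python sketch shows the same computation on flattened logits; `mlm_loss` here is a stand-alone illustrative function, not the module attribute of the same name.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mlm_loss(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # this position was not selected for prediction
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_z - row[label])
    # Assumes at least one labelled position; PyTorch returns NaN otherwise.
    return sum(losses) / len(losses)


logits = [[0.0, 0.0], [5.0, 0.0], [0.0, 3.0]]  # (batch * seqlen, vocab) after flattening
labels = [IGNORE_INDEX, 0, 1]
print(mlm_loss(logits, labels))
```

Changing the logits at an ignored position leaves the loss unchanged, which is what lets the model be trained only on the masked tokens.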
1157
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1158
+ def __init__(self, config):
1159
+ super().__init__(config)
1160
+ self.num_labels = config.num_labels
1161
+ self.config = config
1162
+
1163
+ self.bert = NomicBertModel(config)
1164
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1165
+ self.dropout = nn.Dropout(classifier_dropout)
1166
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
1167
+
1168
+ # Initialize weights and apply final processing
1169
+ self.post_init()
1170
+
1171
+ def forward(
1172
+ self,
1173
+ input_ids: Optional[torch.Tensor] = None,
1174
+ attention_mask: Optional[torch.Tensor] = None,
1175
+ token_type_ids: Optional[torch.Tensor] = None,
1176
+ position_ids: Optional[torch.Tensor] = None,
1177
+ head_mask: Optional[torch.Tensor] = None,
1178
+ inputs_embeds: Optional[torch.Tensor] = None,
1179
+ labels: Optional[torch.Tensor] = None,
1180
+ output_attentions: Optional[bool] = None,
1181
+ output_hidden_states: Optional[bool] = None,
1182
+ return_dict: Optional[bool] = None,
1183
+ ):
1184
+ r"""
1185
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1186
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1187
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
1188
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1189
+ """
1190
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1191
+ outputs = self.bert(
1192
+ input_ids,
1193
+ position_ids=position_ids,
1194
+ token_type_ids=token_type_ids,
1195
+ attention_mask=attention_mask.bool() if attention_mask is not None else None,
1196
+ )
1197
+
1198
+ pooled_output = outputs[1]
1199
+
1200
+ pooled_output = self.dropout(pooled_output)
1201
+ logits = self.classifier(pooled_output)
1202
+
1203
+ loss = None
1204
+ if labels is not None:
1205
+ if self.config.problem_type is None:
1206
+ if self.num_labels == 1:
1207
+ self.config.problem_type = "regression"
1208
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1209
+ self.config.problem_type = "single_label_classification"
1210
+ else:
1211
+ self.config.problem_type = "multi_label_classification"
1212
+
1213
+ if self.config.problem_type == "regression":
1214
+ loss_fct = nn.MSELoss()
1215
+ if self.num_labels == 1:
1216
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1217
+ else:
1218
+ loss = loss_fct(logits, labels)
1219
+ elif self.config.problem_type == "single_label_classification":
1220
+ loss_fct = nn.CrossEntropyLoss()
1221
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1222
+ elif self.config.problem_type == "multi_label_classification":
1223
+ loss_fct = nn.BCEWithLogitsLoss()
1224
+ loss = loss_fct(logits, labels)
1225
+ if not return_dict:
1226
+ output = (logits,) + outputs[2:]
1227
+ return ((loss,) + output) if loss is not None else output
1228
+
1229
+ return SequenceClassifierOutput(
1230
+ loss=loss,
1231
+ logits=logits,
1232
+ hidden_states=outputs.hidden_states,
1233
+ attentions=outputs.attentions,
1234
+ )
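The loss branch in `forward` above selects a loss from `config.problem_type`, inferring it on first use from `num_labels` and the label dtype. The following is a minimal, torch-free sketch of that inference only. The names `infer_problem_type`, `labels_are_integer`, and `LOSS_BY_PROBLEM_TYPE` are hypothetical and are not part of the committed file; `labels_are_integer` stands in for the `labels.dtype == torch.long or labels.dtype == torch.int` check.

```python
from typing import Optional

# Loss class named in each branch of forward(), keyed by problem type.
LOSS_BY_PROBLEM_TYPE = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}


def infer_problem_type(
    num_labels: int,
    labels_are_integer: bool,
    configured: Optional[str] = None,
) -> str:
    """Illustrative mirror of the problem-type inference in forward()."""
    # An explicitly configured problem type is used as-is.
    if configured is not None:
        return configured
    # A single output unit is treated as regression.
    if num_labels == 1:
        return "regression"
    # Several outputs with integer labels: labels are class indices.
    if num_labels > 1 and labels_are_integer:
        return "single_label_classification"
    # Everything else (for example float multi-hot targets).
    return "multi_label_classification"
```

Unlike this sketch, the committed method writes the inferred value back to `self.config.problem_type`, so the first batch that carries labels fixes the loss used for every later call.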
modules.json ADDED
@@ -0,0 +1,14 @@
+[
+  {
+    "idx": 0,
+    "name": "0",
+    "path": "",
+    "type": "sentence_transformers.models.Transformer"
+  },
+  {
+    "idx": 1,
+    "name": "1",
+    "path": "1_Pooling",
+    "type": "sentence_transformers.models.Pooling"
+  }
+]
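`modules.json` chains a Transformer module into the Pooling module stored under `1_Pooling`, whose config in this commit enables only `pooling_mode_mean_tokens`. As a rough illustration of what mean-token pooling computes, here is a plain-Python sketch. The function name `mean_pool` and the list-based inputs are assumptions made for readability; the library itself operates on batched tensors.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the token vectors whose attention-mask entry is non-zero."""
    if len(token_embeddings) != len(attention_mask):
        raise ValueError("need exactly one mask entry per token embedding")
    dim = len(token_embeddings[0]) if token_embeddings else 0
    summed = [0.0] * dim
    kept = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:  # padding positions (mask == 0) are skipped
            kept += 1
            for i, value in enumerate(vector):
                summed[i] += value
    # Guard against division by zero when every position is masked out.
    denominator = max(kept, 1)
    return [s / denominator for s in summed]
```

For a sequence of two real tokens and one padding token, only the first two vectors contribute to the sentence embedding, so padded batches do not dilute the average.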
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
+{
+  "max_seq_length": 8192,
+  "do_lower_case": false
+}
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+{
+  "cls_token": {
+    "content": "[CLS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "[MASK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "[PAD]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "[SEP]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "[UNK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,62 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "[PAD]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "100": {
+      "content": "[UNK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "101": {
+      "content": "[CLS]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "102": {
+      "content": "[SEP]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "103": {
+      "content": "[MASK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_lower_case": true,
+  "mask_token": "[MASK]",
+  "max_length": 8192,
+  "model_max_length": 8192,
+  "pad_to_multiple_of": null,
+  "pad_token": "[PAD]",
+  "pad_token_type_id": 0,
+  "padding_side": "right",
+  "sep_token": "[SEP]",
+  "stride": 0,
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "truncation_side": "right",
+  "truncation_strategy": "longest_first",
+  "unk_token": "[UNK]"
+}
vocab.txt ADDED
The diff for this file is too large to render. See raw diff