initial commit
Browse files- 1_Pooling/config.json +10 -0
- README.md +115 -0
- config.json +100 -0
- config_sentence_transformers.json +9 -0
- configuration_nvembed.py +88 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +311 -0
- modeling_nvembed.py +436 -0
- modules.json +20 -0
- sentence_bert_config.json +4 -0
- special_tokens_map.json +30 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +42 -0
1_Pooling/config.json
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"word_embedding_dimension": 4096,
|
3 |
+
"pooling_mode_cls_token": false,
|
4 |
+
"pooling_mode_mean_tokens": true,
|
5 |
+
"pooling_mode_max_tokens": false,
|
6 |
+
"pooling_mode_mean_sqrt_len_tokens": false,
|
7 |
+
"pooling_mode_weightedmean_tokens": false,
|
8 |
+
"pooling_mode_lasttoken": false,
|
9 |
+
"include_prompt": false
|
10 |
+
}
|
README.md
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
license: cc-by-nc-4.0
|
5 |
+
---
|
6 |
+
## Introduction
|
7 |
+
This is fixed version of `nvidia/NV-Embed-v1` to download model without error.
|
8 |
+
|
9 |
+
## How to use
|
10 |
+
|
11 |
+
Here is an example of how to encode queries and passages using Huggingface-transformer and Sentence-transformer.
|
12 |
+
|
13 |
+
### Usage (HuggingFace Transformers)
|
14 |
+
|
15 |
+
```python
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
from transformers import AutoTokenizer, AutoModel
|
19 |
+
|
20 |
+
# Each query needs to be accompanied by an corresponding instruction describing the task.
|
21 |
+
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question",}
|
22 |
+
|
23 |
+
query_prefix = "Instruct: "+task_name_to_instruct["example"]+"\nQuery: "
|
24 |
+
queries = [
|
25 |
+
'are judo throws allowed in wrestling?',
|
26 |
+
'how to become a radiology technician in michigan?'
|
27 |
+
]
|
28 |
+
|
29 |
+
# No instruction needed for retrieval passages
|
30 |
+
passage_prefix = ""
|
31 |
+
passages = [
|
32 |
+
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
|
33 |
+
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
|
34 |
+
]
|
35 |
+
|
36 |
+
# load model with tokenizer
|
37 |
+
model = AutoModel.from_pretrained('bzantium/NV-Embed-v1', trust_remote_code=True)
|
38 |
+
|
39 |
+
# get the embeddings
|
40 |
+
max_length = 4096
|
41 |
+
query_embeddings = model.encode(queries, instruction=query_prefix, max_length=max_length)
|
42 |
+
passage_embeddings = model.encode(passages, instruction=passage_prefix, max_length=max_length)
|
43 |
+
|
44 |
+
# normalize embeddings
|
45 |
+
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
|
46 |
+
passage_embeddings = F.normalize(passage_embeddings, p=2, dim=1)
|
47 |
+
|
48 |
+
# get the embeddings with DataLoader (spliting the datasets into multiple mini-batches)
|
49 |
+
# batch_size=2
|
50 |
+
# query_embeddings = model._do_encode(queries, batch_size=batch_size, instruction=query_prefix, max_length=max_length, num_workers=32)
|
51 |
+
# passage_embeddings = model._do_encode(passages, batch_size=batch_size, instruction=passage_prefix, max_length=max_length, num_workers=32)
|
52 |
+
|
53 |
+
scores = (query_embeddings @ passage_embeddings.T) * 100
|
54 |
+
print(scores.tolist())
|
55 |
+
#[[77.9402084350586, 0.4248958230018616], [3.757718086242676, 79.60113525390625]]
|
56 |
+
```
|
57 |
+
|
58 |
+
|
59 |
+
### Usage (Sentence-Transformers)
|
60 |
+
|
61 |
+
```python
|
62 |
+
import torch
|
63 |
+
from sentence_transformers import SentenceTransformer
|
64 |
+
|
65 |
+
# Each query needs to be accompanied by an corresponding instruction describing the task.
|
66 |
+
task_name_to_instruct = {"example": "Given a question, retrieve passages that answer the question",}
|
67 |
+
|
68 |
+
query_prefix = "Instruct: "+task_name_to_instruct["example"]+"\nQuery: "
|
69 |
+
queries = [
|
70 |
+
'are judo throws allowed in wrestling?',
|
71 |
+
'how to become a radiology technician in michigan?'
|
72 |
+
]
|
73 |
+
|
74 |
+
# No instruction needed for retrieval passages
|
75 |
+
passages = [
|
76 |
+
"Since you're reading this, you are probably someone from a judo background or someone who is just wondering how judo techniques can be applied under wrestling rules. So without further ado, let's get to the question. Are Judo throws allowed in wrestling? Yes, judo throws are allowed in freestyle and folkstyle wrestling. You only need to be careful to follow the slam rules when executing judo throws. In wrestling, a slam is lifting and returning an opponent to the mat with unnecessary force.",
|
77 |
+
"Below are the basic steps to becoming a radiologic technologist in Michigan:Earn a high school diploma. As with most careers in health care, a high school education is the first step to finding entry-level employment. Taking classes in math and science, such as anatomy, biology, chemistry, physiology, and physics, can help prepare students for their college studies and future careers.Earn an associate degree. Entry-level radiologic positions typically require at least an Associate of Applied Science. Before enrolling in one of these degree programs, students should make sure it has been properly accredited by the Joint Review Committee on Education in Radiologic Technology (JRCERT).Get licensed or certified in the state of Michigan."
|
78 |
+
]
|
79 |
+
|
80 |
+
# load model with tokenizer
|
81 |
+
model = SentenceTransformer('bzantium/NV-Embed-v1', trust_remote_code=True)
|
82 |
+
model.max_seq_length = 4096
|
83 |
+
model.tokenizer.padding_side="right"
|
84 |
+
|
85 |
+
def add_eos(input_examples):
|
86 |
+
input_examples = [input_example + model.tokenizer.eos_token for input_example in input_examples]
|
87 |
+
return input_examples
|
88 |
+
|
89 |
+
# get the embeddings
|
90 |
+
batch_size = 2
|
91 |
+
query_embeddings = model.encode(add_eos(queries), batch_size=batch_size, prompt=query_prefix, normalize_embeddings=True)
|
92 |
+
passage_embeddings = model.encode(add_eos(passages), batch_size=batch_size, normalize_embeddings=True)
|
93 |
+
|
94 |
+
scores = (query_embeddings @ passage_embeddings.T) * 100
|
95 |
+
print(scores.tolist())
|
96 |
+
```
|
97 |
+
|
98 |
+
## Correspondence to
|
99 |
+
Chankyu Lee ([email protected]), Wei Ping ([email protected])
|
100 |
+
|
101 |
+
## Citation
|
102 |
+
If you find this code useful in your research, please consider citing:
|
103 |
+
|
104 |
+
```bibtex
|
105 |
+
@misc{lee2024nvembed,
|
106 |
+
title={NV-Embed: Improved Techniques for Training LLMs as Generalist Embedding Models},
|
107 |
+
author={Chankyu Lee and Rajarshi Roy and Mengyao Xu and Jonathan Raiman and Mohammad Shoeybi and Bryan Catanzaro and Wei Ping},
|
108 |
+
year={2024},
|
109 |
+
eprint={2405.17428},
|
110 |
+
archivePrefix={arXiv},
|
111 |
+
primaryClass={cs.CL}
|
112 |
+
}
|
113 |
+
```
|
114 |
+
## License
|
115 |
+
This model should not be used for any commercial purpose. Refer the [license](https://spdx.org/licenses/CC-BY-NC-4.0) for the detailed terms.
|
config.json
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "bzantium/NV-Embed-v1",
|
3 |
+
"add_eos": true,
|
4 |
+
"add_pad_token": true,
|
5 |
+
"architectures": [
|
6 |
+
"NVEmbedModel"
|
7 |
+
],
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_nvembed.NVEmbedConfig",
|
10 |
+
"AutoModel": "modeling_nvembed.NVEmbedModel"
|
11 |
+
},
|
12 |
+
"is_mask_instruction": true,
|
13 |
+
"latent_attention_config": {
|
14 |
+
"model_type": "latent_attention"
|
15 |
+
},
|
16 |
+
"mask_type": "b",
|
17 |
+
"model_type": "nvembed",
|
18 |
+
"padding_side": "right",
|
19 |
+
"text_config": {
|
20 |
+
"_name_or_path": "bzantium/NV-Embed-v1",
|
21 |
+
"add_cross_attention": false,
|
22 |
+
"architectures": [
|
23 |
+
"MistralModel"
|
24 |
+
],
|
25 |
+
"attention_dropout": 0.0,
|
26 |
+
"bad_words_ids": null,
|
27 |
+
"begin_suppress_tokens": null,
|
28 |
+
"bos_token_id": 1,
|
29 |
+
"chunk_size_feed_forward": 0,
|
30 |
+
"cross_attention_hidden_size": null,
|
31 |
+
"decoder_start_token_id": null,
|
32 |
+
"diversity_penalty": 0.0,
|
33 |
+
"do_sample": false,
|
34 |
+
"early_stopping": false,
|
35 |
+
"encoder_no_repeat_ngram_size": 0,
|
36 |
+
"eos_token_id": 2,
|
37 |
+
"exponential_decay_length_penalty": null,
|
38 |
+
"finetuning_task": null,
|
39 |
+
"forced_bos_token_id": null,
|
40 |
+
"forced_eos_token_id": null,
|
41 |
+
"hidden_act": "silu",
|
42 |
+
"hidden_size": 4096,
|
43 |
+
"id2label": {
|
44 |
+
"0": "LABEL_0",
|
45 |
+
"1": "LABEL_1"
|
46 |
+
},
|
47 |
+
"initializer_range": 0.02,
|
48 |
+
"intermediate_size": 14336,
|
49 |
+
"is_decoder": false,
|
50 |
+
"is_encoder_decoder": false,
|
51 |
+
"label2id": {
|
52 |
+
"LABEL_0": 0,
|
53 |
+
"LABEL_1": 1
|
54 |
+
},
|
55 |
+
"length_penalty": 1.0,
|
56 |
+
"max_length": 20,
|
57 |
+
"max_position_embeddings": 32768,
|
58 |
+
"min_length": 0,
|
59 |
+
"model_type": "bidir_mistral",
|
60 |
+
"no_repeat_ngram_size": 0,
|
61 |
+
"num_attention_heads": 32,
|
62 |
+
"num_beam_groups": 1,
|
63 |
+
"num_beams": 1,
|
64 |
+
"num_hidden_layers": 32,
|
65 |
+
"num_key_value_heads": 8,
|
66 |
+
"num_return_sequences": 1,
|
67 |
+
"output_attentions": false,
|
68 |
+
"output_hidden_states": false,
|
69 |
+
"output_scores": false,
|
70 |
+
"pad_token_id": null,
|
71 |
+
"prefix": null,
|
72 |
+
"problem_type": null,
|
73 |
+
"pruned_heads": {},
|
74 |
+
"remove_invalid_values": false,
|
75 |
+
"repetition_penalty": 1.0,
|
76 |
+
"return_dict": true,
|
77 |
+
"return_dict_in_generate": false,
|
78 |
+
"rms_norm_eps": 1e-05,
|
79 |
+
"rope_theta": 10000.0,
|
80 |
+
"sep_token_id": null,
|
81 |
+
"sliding_window": 4096,
|
82 |
+
"suppress_tokens": null,
|
83 |
+
"task_specific_params": null,
|
84 |
+
"temperature": 1.0,
|
85 |
+
"tf_legacy_loss": false,
|
86 |
+
"tie_encoder_decoder": false,
|
87 |
+
"tie_word_embeddings": false,
|
88 |
+
"tokenizer_class": null,
|
89 |
+
"top_k": 50,
|
90 |
+
"top_p": 1.0,
|
91 |
+
"torch_dtype": "float32",
|
92 |
+
"torchscript": false,
|
93 |
+
"typical_p": 1.0,
|
94 |
+
"use_bfloat16": false,
|
95 |
+
"use_cache": true,
|
96 |
+
"vocab_size": 32000
|
97 |
+
},
|
98 |
+
"torch_dtype": "float16",
|
99 |
+
"transformers_version": "4.37.2"
|
100 |
+
}
|
config_sentence_transformers.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"__version__": {
|
3 |
+
"sentence_transformers": "2.7.0",
|
4 |
+
"transformers": "4.37.2",
|
5 |
+
"pytorch": "2.2.0+cu121"
|
6 |
+
},
|
7 |
+
"prompts": {},
|
8 |
+
"default_prompt_name": null
|
9 |
+
}
|
configuration_nvembed.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from typing import Literal
|
3 |
+
from transformers import AutoConfig
|
4 |
+
from transformers.configuration_utils import PretrainedConfig
|
5 |
+
from transformers.models.auto import CONFIG_MAPPING
|
6 |
+
from transformers.models.mistral import MistralConfig
|
7 |
+
|
8 |
+
NVEMBED_TYPE = "nvembed"
|
9 |
+
LATENT_ATTENTION_TYPE = "latent_attention"
|
10 |
+
BIDIR_MISTRAL_TYPE = "bidir_mistral"
|
11 |
+
|
12 |
+
class NVEmbedConfig(PretrainedConfig):
|
13 |
+
model_type = "nvembed"
|
14 |
+
is_composition = False
|
15 |
+
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
hidden_size=4096,
|
19 |
+
latent_attention_config=None,
|
20 |
+
text_config=None,
|
21 |
+
padding_side: Literal["right", "left"]="right",
|
22 |
+
add_pad_token: bool=True,
|
23 |
+
is_mask_instruction: bool = True,
|
24 |
+
add_eos: bool=True,
|
25 |
+
mask_type: str="b",
|
26 |
+
**kwargs,
|
27 |
+
):
|
28 |
+
if isinstance(latent_attention_config, dict):
|
29 |
+
latent_attention_config["model_type"] = (
|
30 |
+
latent_attention_config["model_type"] if "model_type" in latent_attention_config else LATENT_ATTENTION_TYPE
|
31 |
+
)
|
32 |
+
latent_attention_config = CONFIG_MAPPING[latent_attention_config["model_type"]](**latent_attention_config)
|
33 |
+
elif latent_attention_config is None:
|
34 |
+
latent_attention_config = CONFIG_MAPPING[LATENT_ATTENTION_TYPE]()
|
35 |
+
|
36 |
+
self.latent_attention_config = latent_attention_config
|
37 |
+
|
38 |
+
if isinstance(text_config, dict):
|
39 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
40 |
+
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
41 |
+
elif text_config is None:
|
42 |
+
text_config = None
|
43 |
+
|
44 |
+
self.hidden_size = hidden_size
|
45 |
+
self.text_config = text_config
|
46 |
+
self.padding_side = padding_side
|
47 |
+
self.is_mask_instruction = is_mask_instruction
|
48 |
+
self.add_pad_token = add_pad_token
|
49 |
+
self.add_eos = add_eos
|
50 |
+
self.mask_type = mask_type
|
51 |
+
|
52 |
+
super().__init__(**kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
class LatentAttentionConfig(PretrainedConfig):
|
56 |
+
model_type = LATENT_ATTENTION_TYPE
|
57 |
+
is_composition = False
|
58 |
+
_name_or_path = "latent_attention"
|
59 |
+
|
60 |
+
def __init__(
|
61 |
+
self,
|
62 |
+
num_latents_value: int=512,
|
63 |
+
num_cross_heads: int=8,
|
64 |
+
output_normalize: bool=True,
|
65 |
+
hidden_dim: int=4096,
|
66 |
+
latent_dim: int=4096,
|
67 |
+
cross_dim_head: int=4096,
|
68 |
+
**kwargs,
|
69 |
+
):
|
70 |
+
self.num_latents_value = num_latents_value
|
71 |
+
self.num_cross_heads = num_cross_heads
|
72 |
+
self.output_normalize = output_normalize
|
73 |
+
self.hidden_dim = hidden_dim
|
74 |
+
self.latent_dim = latent_dim
|
75 |
+
self.cross_dim_head = cross_dim_head
|
76 |
+
|
77 |
+
|
78 |
+
class BidirectionalMistralConfig(MistralConfig):
|
79 |
+
model_type = BIDIR_MISTRAL_TYPE
|
80 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
81 |
+
|
82 |
+
AutoConfig.register(NVEMBED_TYPE, NVEmbedConfig)
|
83 |
+
AutoConfig.register(LATENT_ATTENTION_TYPE, LatentAttentionConfig)
|
84 |
+
AutoConfig.register(BIDIR_MISTRAL_TYPE, BidirectionalMistralConfig)
|
85 |
+
|
86 |
+
NVEmbedConfig.register_for_auto_class()
|
87 |
+
LatentAttentionConfig.register_for_auto_class()
|
88 |
+
BidirectionalMistralConfig.register_for_auto_class()
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a3cead86cbeba2eae0b622600b6dd85481d77d4b446f9bba218e6ce281ef771b
|
3 |
+
size 4997761248
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:13075d43d46244ceda81aa21c7ddfc7f15c3b07b4f8e0aeed3d3fcfad6131490
|
3 |
+
size 4915917048
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:46f2a9e87089992b5539e5d18cf17247e0b1ed5244c18810d084e509fcd45391
|
3 |
+
size 4999820296
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8441b80846588e19669c8f908c8c23f72950730f329a3384ca84243636d94829
|
3 |
+
size 788571960
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 15702032384
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"embedding_model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
7 |
+
"embedding_model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
8 |
+
"embedding_model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
9 |
+
"embedding_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
10 |
+
"embedding_model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
11 |
+
"embedding_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
12 |
+
"embedding_model.layers.0.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"embedding_model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
14 |
+
"embedding_model.layers.0.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
15 |
+
"embedding_model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
16 |
+
"embedding_model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"embedding_model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
18 |
+
"embedding_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"embedding_model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
20 |
+
"embedding_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"embedding_model.layers.1.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
22 |
+
"embedding_model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"embedding_model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"embedding_model.layers.1.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
25 |
+
"embedding_model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
26 |
+
"embedding_model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
27 |
+
"embedding_model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
28 |
+
"embedding_model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
29 |
+
"embedding_model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
30 |
+
"embedding_model.layers.10.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
31 |
+
"embedding_model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
32 |
+
"embedding_model.layers.10.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
33 |
+
"embedding_model.layers.10.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
34 |
+
"embedding_model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
35 |
+
"embedding_model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
36 |
+
"embedding_model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
37 |
+
"embedding_model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
38 |
+
"embedding_model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"embedding_model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
40 |
+
"embedding_model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
41 |
+
"embedding_model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
42 |
+
"embedding_model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"embedding_model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
44 |
+
"embedding_model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"embedding_model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
46 |
+
"embedding_model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"embedding_model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
48 |
+
"embedding_model.layers.12.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
49 |
+
"embedding_model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
50 |
+
"embedding_model.layers.12.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
51 |
+
"embedding_model.layers.12.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
52 |
+
"embedding_model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
53 |
+
"embedding_model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
54 |
+
"embedding_model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
55 |
+
"embedding_model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
56 |
+
"embedding_model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
57 |
+
"embedding_model.layers.13.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
58 |
+
"embedding_model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
59 |
+
"embedding_model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
60 |
+
"embedding_model.layers.13.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
61 |
+
"embedding_model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
62 |
+
"embedding_model.layers.14.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
63 |
+
"embedding_model.layers.14.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
64 |
+
"embedding_model.layers.14.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
65 |
+
"embedding_model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
66 |
+
"embedding_model.layers.14.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
67 |
+
"embedding_model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
68 |
+
"embedding_model.layers.14.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
69 |
+
"embedding_model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
70 |
+
"embedding_model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
71 |
+
"embedding_model.layers.15.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
72 |
+
"embedding_model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
73 |
+
"embedding_model.layers.15.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
74 |
+
"embedding_model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
75 |
+
"embedding_model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
76 |
+
"embedding_model.layers.15.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
77 |
+
"embedding_model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
78 |
+
"embedding_model.layers.15.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
79 |
+
"embedding_model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
80 |
+
"embedding_model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
81 |
+
"embedding_model.layers.16.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
82 |
+
"embedding_model.layers.16.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
83 |
+
"embedding_model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
84 |
+
"embedding_model.layers.16.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
85 |
+
"embedding_model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
86 |
+
"embedding_model.layers.16.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
87 |
+
"embedding_model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
88 |
+
"embedding_model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
89 |
+
"embedding_model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
90 |
+
"embedding_model.layers.17.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
91 |
+
"embedding_model.layers.17.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
92 |
+
"embedding_model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
93 |
+
"embedding_model.layers.17.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
94 |
+
"embedding_model.layers.17.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
95 |
+
"embedding_model.layers.17.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
96 |
+
"embedding_model.layers.17.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
97 |
+
"embedding_model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
98 |
+
"embedding_model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
99 |
+
"embedding_model.layers.18.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
100 |
+
"embedding_model.layers.18.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
101 |
+
"embedding_model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
102 |
+
"embedding_model.layers.18.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
103 |
+
"embedding_model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
104 |
+
"embedding_model.layers.18.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
105 |
+
"embedding_model.layers.18.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
106 |
+
"embedding_model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
107 |
+
"embedding_model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
108 |
+
"embedding_model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
109 |
+
"embedding_model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
110 |
+
"embedding_model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
111 |
+
"embedding_model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
112 |
+
"embedding_model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
113 |
+
"embedding_model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
114 |
+
"embedding_model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
115 |
+
"embedding_model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
116 |
+
"embedding_model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
117 |
+
"embedding_model.layers.2.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
118 |
+
"embedding_model.layers.2.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
119 |
+
"embedding_model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
120 |
+
"embedding_model.layers.2.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
121 |
+
"embedding_model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
122 |
+
"embedding_model.layers.2.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
123 |
+
"embedding_model.layers.2.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
124 |
+
"embedding_model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
125 |
+
"embedding_model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
126 |
+
"embedding_model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
127 |
+
"embedding_model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
128 |
+
"embedding_model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
129 |
+
"embedding_model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
130 |
+
"embedding_model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
131 |
+
"embedding_model.layers.20.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
132 |
+
"embedding_model.layers.20.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
133 |
+
"embedding_model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
134 |
+
"embedding_model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
135 |
+
"embedding_model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
136 |
+
"embedding_model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
137 |
+
"embedding_model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
138 |
+
"embedding_model.layers.21.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
139 |
+
"embedding_model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
140 |
+
"embedding_model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
141 |
+
"embedding_model.layers.21.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
142 |
+
"embedding_model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
143 |
+
"embedding_model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
144 |
+
"embedding_model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
145 |
+
"embedding_model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
146 |
+
"embedding_model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
147 |
+
"embedding_model.layers.22.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
148 |
+
"embedding_model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
149 |
+
"embedding_model.layers.22.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
150 |
+
"embedding_model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
151 |
+
"embedding_model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
152 |
+
"embedding_model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
153 |
+
"embedding_model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
154 |
+
"embedding_model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
155 |
+
"embedding_model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
156 |
+
"embedding_model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
157 |
+
"embedding_model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
158 |
+
"embedding_model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
159 |
+
"embedding_model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
160 |
+
"embedding_model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
161 |
+
"embedding_model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
162 |
+
"embedding_model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
163 |
+
"embedding_model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
164 |
+
"embedding_model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
165 |
+
"embedding_model.layers.24.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
166 |
+
"embedding_model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
167 |
+
"embedding_model.layers.24.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
168 |
+
"embedding_model.layers.24.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
169 |
+
"embedding_model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
170 |
+
"embedding_model.layers.25.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
171 |
+
"embedding_model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
172 |
+
"embedding_model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
173 |
+
"embedding_model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
174 |
+
"embedding_model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
175 |
+
"embedding_model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
176 |
+
"embedding_model.layers.25.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
177 |
+
"embedding_model.layers.25.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
178 |
+
"embedding_model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
179 |
+
"embedding_model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
180 |
+
"embedding_model.layers.26.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
181 |
+
"embedding_model.layers.26.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
182 |
+
"embedding_model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
183 |
+
"embedding_model.layers.26.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
184 |
+
"embedding_model.layers.26.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
185 |
+
"embedding_model.layers.26.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
186 |
+
"embedding_model.layers.26.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
187 |
+
"embedding_model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
188 |
+
"embedding_model.layers.27.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
189 |
+
"embedding_model.layers.27.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
190 |
+
"embedding_model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
191 |
+
"embedding_model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
192 |
+
"embedding_model.layers.27.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
193 |
+
"embedding_model.layers.27.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
194 |
+
"embedding_model.layers.27.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
195 |
+
"embedding_model.layers.27.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
196 |
+
"embedding_model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
197 |
+
"embedding_model.layers.28.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
198 |
+
"embedding_model.layers.28.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
199 |
+
"embedding_model.layers.28.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
200 |
+
"embedding_model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
201 |
+
"embedding_model.layers.28.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
202 |
+
"embedding_model.layers.28.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
203 |
+
"embedding_model.layers.28.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
204 |
+
"embedding_model.layers.28.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
205 |
+
"embedding_model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
206 |
+
"embedding_model.layers.29.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
207 |
+
"embedding_model.layers.29.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
208 |
+
"embedding_model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
209 |
+
"embedding_model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
210 |
+
"embedding_model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
211 |
+
"embedding_model.layers.29.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
212 |
+
"embedding_model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
213 |
+
"embedding_model.layers.29.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
214 |
+
"embedding_model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
215 |
+
"embedding_model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
216 |
+
"embedding_model.layers.3.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
217 |
+
"embedding_model.layers.3.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
218 |
+
"embedding_model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
219 |
+
"embedding_model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
220 |
+
"embedding_model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
221 |
+
"embedding_model.layers.3.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
222 |
+
"embedding_model.layers.3.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
223 |
+
"embedding_model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
224 |
+
"embedding_model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
225 |
+
"embedding_model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
226 |
+
"embedding_model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
227 |
+
"embedding_model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
228 |
+
"embedding_model.layers.30.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
229 |
+
"embedding_model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
230 |
+
"embedding_model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
231 |
+
"embedding_model.layers.30.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
232 |
+
"embedding_model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
233 |
+
"embedding_model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
234 |
+
"embedding_model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
235 |
+
"embedding_model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
236 |
+
"embedding_model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
237 |
+
"embedding_model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
238 |
+
"embedding_model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
239 |
+
"embedding_model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
240 |
+
"embedding_model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
241 |
+
"embedding_model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
242 |
+
"embedding_model.layers.4.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
243 |
+
"embedding_model.layers.4.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
244 |
+
"embedding_model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
245 |
+
"embedding_model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
246 |
+
"embedding_model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
247 |
+
"embedding_model.layers.4.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
248 |
+
"embedding_model.layers.4.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
249 |
+
"embedding_model.layers.4.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
250 |
+
"embedding_model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
251 |
+
"embedding_model.layers.5.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
252 |
+
"embedding_model.layers.5.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
253 |
+
"embedding_model.layers.5.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
254 |
+
"embedding_model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
255 |
+
"embedding_model.layers.5.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
256 |
+
"embedding_model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
257 |
+
"embedding_model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
258 |
+
"embedding_model.layers.5.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
259 |
+
"embedding_model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
260 |
+
"embedding_model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
261 |
+
"embedding_model.layers.6.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
262 |
+
"embedding_model.layers.6.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
263 |
+
"embedding_model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
264 |
+
"embedding_model.layers.6.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
265 |
+
"embedding_model.layers.6.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
266 |
+
"embedding_model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
267 |
+
"embedding_model.layers.6.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
268 |
+
"embedding_model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
269 |
+
"embedding_model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
270 |
+
"embedding_model.layers.7.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
271 |
+
"embedding_model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
272 |
+
"embedding_model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
273 |
+
"embedding_model.layers.7.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
274 |
+
"embedding_model.layers.7.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
275 |
+
"embedding_model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
276 |
+
"embedding_model.layers.7.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
277 |
+
"embedding_model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
278 |
+
"embedding_model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
279 |
+
"embedding_model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
280 |
+
"embedding_model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
281 |
+
"embedding_model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
282 |
+
"embedding_model.layers.8.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
283 |
+
"embedding_model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
284 |
+
"embedding_model.layers.8.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
285 |
+
"embedding_model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
286 |
+
"embedding_model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
287 |
+
"embedding_model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
288 |
+
"embedding_model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
289 |
+
"embedding_model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
290 |
+
"embedding_model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
291 |
+
"embedding_model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
292 |
+
"embedding_model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
293 |
+
"embedding_model.layers.9.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
294 |
+
"embedding_model.layers.9.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
295 |
+
"embedding_model.norm.weight": "model-00004-of-00004.safetensors",
|
296 |
+
"latent_attention_model.cross_attend_blocks.0.fn.to_kv.weight": "model-00001-of-00004.safetensors",
|
297 |
+
"latent_attention_model.cross_attend_blocks.0.fn.to_out.weight": "model-00001-of-00004.safetensors",
|
298 |
+
"latent_attention_model.cross_attend_blocks.0.fn.to_q.weight": "model-00001-of-00004.safetensors",
|
299 |
+
"latent_attention_model.cross_attend_blocks.0.norm.bias": "model-00001-of-00004.safetensors",
|
300 |
+
"latent_attention_model.cross_attend_blocks.0.norm.weight": "model-00001-of-00004.safetensors",
|
301 |
+
"latent_attention_model.cross_attend_blocks.0.norm_context.bias": "model-00001-of-00004.safetensors",
|
302 |
+
"latent_attention_model.cross_attend_blocks.0.norm_context.weight": "model-00001-of-00004.safetensors",
|
303 |
+
"latent_attention_model.cross_attend_blocks.1.fn.net.0.bias": "model-00001-of-00004.safetensors",
|
304 |
+
"latent_attention_model.cross_attend_blocks.1.fn.net.0.weight": "model-00001-of-00004.safetensors",
|
305 |
+
"latent_attention_model.cross_attend_blocks.1.fn.net.2.bias": "model-00001-of-00004.safetensors",
|
306 |
+
"latent_attention_model.cross_attend_blocks.1.fn.net.2.weight": "model-00001-of-00004.safetensors",
|
307 |
+
"latent_attention_model.cross_attend_blocks.1.norm.bias": "model-00001-of-00004.safetensors",
|
308 |
+
"latent_attention_model.cross_attend_blocks.1.norm.weight": "model-00001-of-00004.safetensors",
|
309 |
+
"latent_attention_model.latents": "model-00001-of-00004.safetensors"
|
310 |
+
}
|
311 |
+
}
|
modeling_nvembed.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union, Dict, Mapping, Optional, Tuple, TypedDict
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
from functools import partial
|
7 |
+
from contextlib import nullcontext
|
8 |
+
from transformers import AutoModel, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding
|
9 |
+
from transformers.modeling_utils import PreTrainedModel
|
10 |
+
from transformers.models.auto import AutoTokenizer
|
11 |
+
from transformers.models.mistral.modeling_mistral import MISTRAL_INPUTS_DOCSTRING
|
12 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
13 |
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
|
14 |
+
from transformers import MistralModel, MistralConfig
|
15 |
+
from transformers.cache_utils import Cache, DynamicCache
|
16 |
+
from transformers.utils import (
|
17 |
+
add_start_docstrings_to_model_forward,
|
18 |
+
logging,
|
19 |
+
)
|
20 |
+
from einops import rearrange, repeat
|
21 |
+
from tqdm.auto import tqdm
|
22 |
+
from datasets import Dataset
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
from .configuration_nvembed import NVEmbedConfig, LatentAttentionConfig, BidirectionalMistralConfig
|
25 |
+
|
26 |
+
logger = logging.get_logger(__name__)
|
27 |
+
|
28 |
+
class NVEmbedFeatures(TypedDict):
|
29 |
+
input_dict: torch.Tensor
|
30 |
+
attention_mask: torch.Tensor
|
31 |
+
pool_mask: torch.Tensor
|
32 |
+
|
33 |
+
class BidirectionalMistralModel(MistralModel):
|
34 |
+
config_class = BidirectionalMistralConfig
|
35 |
+
|
36 |
+
def __init__(self, config: MistralConfig):
|
37 |
+
super().__init__(config)
|
38 |
+
for layer in self.layers:
|
39 |
+
layer.self_attn.is_causal = False
|
40 |
+
self._attn_implementation = "eager"
|
41 |
+
|
42 |
+
@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
|
43 |
+
def forward(
|
44 |
+
self,
|
45 |
+
input_ids: torch.LongTensor = None,
|
46 |
+
attention_mask: Optional[torch.Tensor] = None,
|
47 |
+
position_ids: Optional[torch.LongTensor] = None,
|
48 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
49 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
50 |
+
use_cache: Optional[bool] = None,
|
51 |
+
output_attentions: Optional[bool] = None,
|
52 |
+
output_hidden_states: Optional[bool] = None,
|
53 |
+
return_dict: Optional[bool] = None,
|
54 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
55 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
56 |
+
output_hidden_states = (
|
57 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
58 |
+
)
|
59 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
60 |
+
|
61 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
62 |
+
|
63 |
+
# retrieve input_ids and inputs_embeds
|
64 |
+
if input_ids is not None and inputs_embeds is not None:
|
65 |
+
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
|
66 |
+
elif input_ids is not None:
|
67 |
+
batch_size, seq_length = input_ids.shape
|
68 |
+
elif inputs_embeds is not None:
|
69 |
+
batch_size, seq_length, _ = inputs_embeds.shape
|
70 |
+
else:
|
71 |
+
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
72 |
+
|
73 |
+
if self.gradient_checkpointing and self.training:
|
74 |
+
if use_cache:
|
75 |
+
logger.warning_once(
|
76 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
77 |
+
)
|
78 |
+
use_cache = False
|
79 |
+
|
80 |
+
past_key_values_length = 0
|
81 |
+
|
82 |
+
if use_cache:
|
83 |
+
use_legacy_cache = not isinstance(past_key_values, Cache)
|
84 |
+
if use_legacy_cache:
|
85 |
+
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
86 |
+
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
87 |
+
|
88 |
+
if position_ids is None:
|
89 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
90 |
+
position_ids = torch.arange(
|
91 |
+
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
92 |
+
)
|
93 |
+
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
94 |
+
else:
|
95 |
+
position_ids = position_ids.view(-1, seq_length).long()
|
96 |
+
|
97 |
+
if inputs_embeds is None:
|
98 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
99 |
+
|
100 |
+
if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
|
101 |
+
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
102 |
+
if is_padding_right:
|
103 |
+
raise ValueError(
|
104 |
+
"You are attempting to perform batched generation with padding_side='right'"
|
105 |
+
" this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
|
106 |
+
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
107 |
+
)
|
108 |
+
|
109 |
+
if self._attn_implementation == "flash_attention_2":
|
110 |
+
# 2d mask is passed through the layers
|
111 |
+
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
112 |
+
elif self._attn_implementation == "sdpa" and not output_attentions:
|
113 |
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
114 |
+
# the manual implementation that requires a 4D causal mask in all cases.
|
115 |
+
attention_mask = _prepare_4d_attention_mask_for_sdpa(
|
116 |
+
attention_mask, inputs_embeds.dtype
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
# 4d mask is passed through the layers
|
120 |
+
attention_mask = _prepare_4d_attention_mask(
|
121 |
+
attention_mask, inputs_embeds.dtype,
|
122 |
+
)
|
123 |
+
|
124 |
+
hidden_states = inputs_embeds
|
125 |
+
|
126 |
+
# decoder layers
|
127 |
+
all_hidden_states = () if output_hidden_states else None
|
128 |
+
all_self_attns = () if output_attentions else None
|
129 |
+
next_decoder_cache = None
|
130 |
+
|
131 |
+
for decoder_layer in self.layers:
|
132 |
+
if output_hidden_states:
|
133 |
+
all_hidden_states += (hidden_states,)
|
134 |
+
|
135 |
+
if self.gradient_checkpointing and self.training:
|
136 |
+
layer_outputs = self._gradient_checkpointing_func(
|
137 |
+
decoder_layer.__call__,
|
138 |
+
hidden_states,
|
139 |
+
attention_mask,
|
140 |
+
position_ids,
|
141 |
+
past_key_values,
|
142 |
+
output_attentions,
|
143 |
+
use_cache,
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
layer_outputs = decoder_layer(
|
147 |
+
hidden_states,
|
148 |
+
attention_mask=attention_mask,
|
149 |
+
position_ids=position_ids,
|
150 |
+
past_key_value=past_key_values,
|
151 |
+
output_attentions=output_attentions,
|
152 |
+
use_cache=use_cache,
|
153 |
+
)
|
154 |
+
|
155 |
+
hidden_states = layer_outputs[0]
|
156 |
+
|
157 |
+
if use_cache:
|
158 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
159 |
+
|
160 |
+
if output_attentions:
|
161 |
+
all_self_attns += (layer_outputs[1],)
|
162 |
+
|
163 |
+
hidden_states = self.norm(hidden_states)
|
164 |
+
|
165 |
+
# add hidden states from the last decoder layer
|
166 |
+
if output_hidden_states:
|
167 |
+
all_hidden_states += (hidden_states,)
|
168 |
+
|
169 |
+
next_cache = None
|
170 |
+
if use_cache:
|
171 |
+
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
|
172 |
+
|
173 |
+
if not return_dict:
|
174 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
175 |
+
return BaseModelOutputWithPast(
|
176 |
+
last_hidden_state=hidden_states,
|
177 |
+
past_key_values=next_cache,
|
178 |
+
hidden_states=all_hidden_states,
|
179 |
+
attentions=all_self_attns,
|
180 |
+
)
|
181 |
+
|
182 |
+
def _move_to_device(maybe_tensor, device: torch.device):
|
183 |
+
if torch.is_tensor(maybe_tensor):
|
184 |
+
return maybe_tensor.to(device, non_blocking=device.type == "cuda")
|
185 |
+
elif isinstance(maybe_tensor, dict):
|
186 |
+
return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()}
|
187 |
+
elif isinstance(maybe_tensor, list):
|
188 |
+
return [_move_to_device(x, device) for x in maybe_tensor]
|
189 |
+
elif isinstance(maybe_tensor, tuple):
|
190 |
+
return tuple([_move_to_device(x, device) for x in maybe_tensor])
|
191 |
+
elif isinstance(maybe_tensor, Mapping):
|
192 |
+
return type(maybe_tensor)({k: _move_to_device(v, device) for k, v in maybe_tensor.items()})
|
193 |
+
else:
|
194 |
+
return maybe_tensor
|
195 |
+
|
196 |
+
def move_to_device(sample, device: torch.device):
|
197 |
+
if device.type == "cpu":
|
198 |
+
return sample
|
199 |
+
|
200 |
+
if len(sample) == 0:
|
201 |
+
return {}
|
202 |
+
return _move_to_device(sample, device)
|
203 |
+
|
204 |
+
|
205 |
+
def input_transform_func(
|
206 |
+
tokenizer: PreTrainedTokenizerFast,
|
207 |
+
examples: Dict[str, List],
|
208 |
+
always_add_eos: bool,
|
209 |
+
max_length: int,
|
210 |
+
instruction: str,
|
211 |
+
) -> BatchEncoding:
|
212 |
+
if always_add_eos:
|
213 |
+
examples['input_texts'] = [instruction + input_example + tokenizer.eos_token for input_example in examples['input_texts']]
|
214 |
+
batch_dict = tokenizer(
|
215 |
+
examples['input_texts'],
|
216 |
+
max_length=max_length,
|
217 |
+
padding=True,
|
218 |
+
return_token_type_ids=False,
|
219 |
+
return_tensors="pt",
|
220 |
+
truncation=True)
|
221 |
+
return batch_dict
|
222 |
+
|
223 |
+
|
224 |
+
class PreNorm(torch.nn.Module):
|
225 |
+
def __init__(self, dim, fn, context_dim = None):
|
226 |
+
super().__init__()
|
227 |
+
self.fn = fn
|
228 |
+
self.norm = torch.nn.LayerNorm(dim)
|
229 |
+
self.norm_context = torch.nn.LayerNorm(context_dim) if exists(context_dim) else None
|
230 |
+
|
231 |
+
def forward(self, x, **kwargs):
|
232 |
+
x = self.norm(x)
|
233 |
+
if exists(self.norm_context):
|
234 |
+
context = kwargs['context']
|
235 |
+
normed_context = self.norm_context(context)
|
236 |
+
kwargs.update(context = normed_context)
|
237 |
+
return self.fn(x, **kwargs)
|
238 |
+
|
239 |
+
class GEGLU(torch.nn.Module):
|
240 |
+
def forward(self, x):
|
241 |
+
x, gates = x.chunk(2, dim = -1)
|
242 |
+
return x * torch.nn.functional.gelu(gates)
|
243 |
+
|
244 |
+
class FeedForward(torch.nn.Module):
|
245 |
+
def __init__(self, dim, mult = 4):
|
246 |
+
super().__init__()
|
247 |
+
self.net = torch.nn.Sequential(torch.nn.Linear(dim, dim * mult * 2),
|
248 |
+
GEGLU(),
|
249 |
+
torch.nn.Linear(dim * mult, dim))
|
250 |
+
|
251 |
+
def forward(self, x):
|
252 |
+
return self.net(x)
|
253 |
+
|
254 |
+
def exists(val):
|
255 |
+
return val is not None
|
256 |
+
|
257 |
+
def default(val, d):
|
258 |
+
return val if exists(val) else d
|
259 |
+
|
260 |
+
|
261 |
+
class Attention(torch.nn.Module):
|
262 |
+
def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64):
|
263 |
+
super().__init__()
|
264 |
+
inner_dim = dim_head * heads
|
265 |
+
context_dim = default(context_dim, query_dim)
|
266 |
+
self.scale = dim_head ** -0.5
|
267 |
+
self.heads = heads
|
268 |
+
|
269 |
+
self.to_q = torch.nn.Linear(query_dim, inner_dim, bias = False)
|
270 |
+
self.to_kv = torch.nn.Linear(context_dim, inner_dim * 2, bias = False)
|
271 |
+
self.to_out = torch.nn.Linear(inner_dim, query_dim, bias = False)
|
272 |
+
|
273 |
+
def forward(self, x, context = None, mask = None):
|
274 |
+
h = self.heads
|
275 |
+
q = self.to_q(x)
|
276 |
+
context = default(context, x)
|
277 |
+
k, v = self.to_kv(context).chunk(2, dim = -1)
|
278 |
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
|
279 |
+
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
|
280 |
+
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
281 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
|
282 |
+
return self.to_out(out)
|
283 |
+
|
284 |
+
|
285 |
+
class LatentAttentionModel(PreTrainedModel):
|
286 |
+
config_class = LatentAttentionConfig
|
287 |
+
|
288 |
+
def __init__(self, config: LatentAttentionConfig):
|
289 |
+
super().__init__(config)
|
290 |
+
## cross-attention block
|
291 |
+
num_latents, latent_dim, cross_heads, cross_dim_head = config.num_latents_value, config.latent_dim, config.num_cross_heads, config.cross_dim_head
|
292 |
+
dim = config.hidden_dim
|
293 |
+
# init latent_attention and latents
|
294 |
+
self.cross_attend_blocks = torch.nn.ModuleList([
|
295 |
+
PreNorm(latent_dim, Attention(latent_dim, dim, heads = cross_heads, dim_head = cross_dim_head),
|
296 |
+
context_dim = dim),
|
297 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
298 |
+
])
|
299 |
+
self.output_normalize = config.output_normalize
|
300 |
+
self.register_parameter("latents", torch.nn.Parameter(torch.randn(num_latents, latent_dim)))
|
301 |
+
|
302 |
+
def forward(self, hiddens, attention_mask: torch.Tensor=None):
|
303 |
+
## cross-attention block
|
304 |
+
cross_attn, cross_ff = self.cross_attend_blocks
|
305 |
+
b, *_, device = *hiddens.shape, hiddens.device
|
306 |
+
x = repeat(self.latents, 'n d -> b n d', b = b)
|
307 |
+
hiddens = cross_attn(hiddens, context = x, mask = None) + hiddens
|
308 |
+
hiddens = cross_ff(hiddens) + hiddens
|
309 |
+
if attention_mask !=None:
|
310 |
+
s = torch.sum(hiddens * attention_mask.unsqueeze(-1).float(), dim=1)
|
311 |
+
d = attention_mask.sum(dim=1, keepdim=True).float()
|
312 |
+
hiddens = s / d
|
313 |
+
if self.output_normalize:
|
314 |
+
hiddens = torch.nn.functional.normalize(hiddens, p=2, dim=-1)
|
315 |
+
return hiddens
|
316 |
+
|
317 |
+
class NVEmbedModel(PreTrainedModel):
|
318 |
+
config_class = NVEmbedConfig
|
319 |
+
|
320 |
+
def __init__(self, config: NVEmbedConfig):
|
321 |
+
super().__init__(config)
|
322 |
+
self.latent_attention_model = AutoModel.from_config(config.latent_attention_config)
|
323 |
+
self.embedding_model = AutoModel.from_config(
|
324 |
+
config.text_config,
|
325 |
+
) if config.text_config is not None else None
|
326 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path, trust_remote_code=True) if config.text_config is not None else None
|
327 |
+
self.padding_side = config.padding_side
|
328 |
+
self.is_mask_instruction = config.is_mask_instruction
|
329 |
+
self.add_eos = config.add_eos
|
330 |
+
self.mask_type = config.mask_type
|
331 |
+
if config.add_pad_token and self.tokenizer is not None:
|
332 |
+
self.add_pad_token()
|
333 |
+
|
334 |
+
def add_pad_token(self):
|
335 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
336 |
+
self.tokenizer.padding_side = self.padding_side
|
337 |
+
|
338 |
+
def prepare_kwargs_from_batch(self, batch_dict: dict, instruction_lens: int, device: torch.device):
|
339 |
+
batch_dict = move_to_device(batch_dict, device)
|
340 |
+
attention_mask = batch_dict['attention_mask'].clone() if 'attention_mask' in batch_dict else None
|
341 |
+
if (attention_mask is not None and
|
342 |
+
self.padding_side == "right" and
|
343 |
+
self.is_mask_instruction == True and
|
344 |
+
instruction_lens > 0):
|
345 |
+
# Mask out the instruction tokens for mean-pooling
|
346 |
+
attention_mask[:, :instruction_lens] = 0
|
347 |
+
features: NVEmbedFeatures = {
|
348 |
+
'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
|
349 |
+
'attention_mask': batch_dict['attention_mask'],
|
350 |
+
'pool_mask': attention_mask,
|
351 |
+
}
|
352 |
+
return features
|
353 |
+
|
354 |
+
@torch.no_grad()
|
355 |
+
def _do_encode(self,
|
356 |
+
prompts: List[str],
|
357 |
+
batch_size,
|
358 |
+
instruction: str,
|
359 |
+
max_length: int=4096,
|
360 |
+
num_workers: int=32,
|
361 |
+
) -> np.ndarray:
|
362 |
+
dataset: Dataset = Dataset.from_dict({'input_texts': prompts})
|
363 |
+
dataset.set_transform(partial(input_transform_func,
|
364 |
+
self.tokenizer,
|
365 |
+
always_add_eos=True,
|
366 |
+
max_length=max_length,
|
367 |
+
instruction=instruction))
|
368 |
+
|
369 |
+
data_collator = DataCollatorWithPadding(self.tokenizer)
|
370 |
+
data_loader = DataLoader(
|
371 |
+
dataset,
|
372 |
+
batch_size=batch_size,
|
373 |
+
shuffle=False,
|
374 |
+
drop_last=False,
|
375 |
+
num_workers=num_workers,
|
376 |
+
collate_fn=data_collator,
|
377 |
+
pin_memory=True)
|
378 |
+
|
379 |
+
if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
|
380 |
+
instruction_lens = len(self.tokenizer.tokenize(instruction))
|
381 |
+
else:
|
382 |
+
instruction_lens = 0
|
383 |
+
|
384 |
+
encoded_embeds = []
|
385 |
+
device = next(self.embedding_model.parameters()).device
|
386 |
+
for batch_dict in tqdm(data_loader, desc='encoding', mininterval=10):
|
387 |
+
features = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
|
388 |
+
embeds=self(**features)["sentence_embeddings"].squeeze(1)
|
389 |
+
encoded_embeds.append(embeds.cpu().numpy())
|
390 |
+
return np.concatenate(encoded_embeds, axis=0)
|
391 |
+
|
392 |
+
def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, pool_mask: Optional[torch.Tensor]=None, return_dict: bool=True):
|
393 |
+
autocast_ctx = torch.cuda.amp.autocast if torch.cuda.is_available() else nullcontext
|
394 |
+
with autocast_ctx():
|
395 |
+
## decoder only layer
|
396 |
+
outputs = self.embedding_model(
|
397 |
+
input_ids=input_ids,
|
398 |
+
attention_mask=attention_mask,
|
399 |
+
)
|
400 |
+
## latent attention layer
|
401 |
+
embeds = self.latent_attention_model(
|
402 |
+
outputs.last_hidden_state,
|
403 |
+
pool_mask,
|
404 |
+
)
|
405 |
+
if not return_dict:
|
406 |
+
return (embeds,)
|
407 |
+
return {"sentence_embeddings": embeds}
|
408 |
+
|
409 |
+
|
410 |
+
@torch.no_grad()
|
411 |
+
def encode(self, prompts: List[str], instruction: str="", max_length: int=4096):
|
412 |
+
if self.padding_side == "right" and self.is_mask_instruction == True and len(instruction) > 0:
|
413 |
+
instruction_lens = len(self.tokenizer.tokenize(instruction))
|
414 |
+
else:
|
415 |
+
instruction_lens = 0
|
416 |
+
|
417 |
+
device = next(self.embedding_model.parameters()).device
|
418 |
+
batch_dict = input_transform_func(self.tokenizer,
|
419 |
+
{"input_texts": [prompt for prompt in prompts]},
|
420 |
+
always_add_eos=True,
|
421 |
+
max_length=max_length,
|
422 |
+
instruction=instruction)
|
423 |
+
|
424 |
+
features: NVEmbedFeatures = self.prepare_kwargs_from_batch(batch_dict, instruction_lens, device=device)
|
425 |
+
return self(**features)["sentence_embeddings"].squeeze(1)
|
426 |
+
|
427 |
+
|
428 |
+
## AutoModel Register
|
429 |
+
AutoModel.register(NVEmbedConfig, NVEmbedModel)
|
430 |
+
AutoModel.register(LatentAttentionConfig, LatentAttentionModel)
|
431 |
+
AutoModel.register(BidirectionalMistralConfig, BidirectionalMistralModel)
|
432 |
+
|
433 |
+
## Register for auto class
|
434 |
+
NVEmbedModel.register_for_auto_class("AutoModel")
|
435 |
+
LatentAttentionModel.register_for_auto_class("AutoModel")
|
436 |
+
BidirectionalMistralModel.register_for_auto_class("AutoModel")
|
modules.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"idx": 0,
|
4 |
+
"name": "0",
|
5 |
+
"path": "",
|
6 |
+
"type": "sentence_transformers.models.Transformer"
|
7 |
+
},
|
8 |
+
{
|
9 |
+
"idx": 1,
|
10 |
+
"name": "1",
|
11 |
+
"path": "1_Pooling",
|
12 |
+
"type": "sentence_transformers.models.Pooling"
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"idx": 2,
|
16 |
+
"name": "2",
|
17 |
+
"path": "2_Normalize",
|
18 |
+
"type": "sentence_transformers.models.Normalize"
|
19 |
+
}
|
20 |
+
]
|
sentence_bert_config.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"max_seq_length": 4096,
|
3 |
+
"do_lower_case": false
|
4 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "</s>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<unk>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
|
3 |
+
size 493443
|
tokenizer_config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"additional_special_tokens": [],
|
31 |
+
"bos_token": "<s>",
|
32 |
+
"clean_up_tokenization_spaces": false,
|
33 |
+
"eos_token": "</s>",
|
34 |
+
"legacy": true,
|
35 |
+
"model_max_length": 1000000000000000019884624838656,
|
36 |
+
"pad_token": "</s>",
|
37 |
+
"sp_model_kwargs": {},
|
38 |
+
"spaces_between_special_tokens": false,
|
39 |
+
"tokenizer_class": "LlamaTokenizer",
|
40 |
+
"unk_token": "<unk>",
|
41 |
+
"use_default_system_prompt": false
|
42 |
+
}
|