base model release
Browse files- .gitignore +1 -0
- README.md +165 -6
- added_tokens.json +14 -0
- config.json +26 -0
- configuration_xgenmm.py +155 -0
- demo.ipynb +0 -0
- generation_config.json +7 -0
- image_processing_xgenmm.py +409 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +669 -0
- modeling_xgenmm.py +105 -0
- preprocessor_config.json +23 -0
- special_tokens_map.json +30 -0
- test_samples/few_shots.json +22 -0
- test_samples/images/000adfe5b817011c.jpg +0 -0
- test_samples/images/COCO_val2014_000000176466.jpg +0 -0
- test_samples/images/COCO_val2014_000000267408.jpg +0 -0
- test_samples/images/COCO_val2014_000000392640.jpg +0 -0
- test_samples/images/COCO_val2014_000000486568.jpg +0 -0
- test_samples/zero_shot.json +4 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +137 -0
- utils.py +383 -0
- vlm.py +1308 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
**/__pycache__/**
|
README.md
CHANGED
@@ -7,26 +7,185 @@ pipeline_tag: image-text-to-text
|
|
7 |
|
8 |
|
9 |
# Model description
|
|
|
10 |
|
11 |
-
|
|
|
12 |
|
13 |
-
|
|
|
|
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
# Bias, Risks, Limitations, and Ethical Considerations
|
17 |
|
18 |
# How to use
|
19 |
|
20 |
-
> We require use the
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
|
23 |
# License
|
24 |
|
25 |
-
Our code and weights are released under the Creative Commons Attribution Non Commercial 4.0 [LICENSE](LICENSE.txt).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
# Troubleshoot
|
28 |
|
29 |
-
1. If you
|
30 |
|
31 |
```
|
32 |
pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
|
|
|
7 |
|
8 |
|
9 |
# Model description
|
10 |
+
We are excited to announce the continuation and rebranding of our **BLIP series** into **XGen-MM**, to be better aligned with Salesforce's unified XGen initiative for large foundation models! This rebranding marks a significant step in our ongoing development of cutting-edge multimodal technologies.
|
11 |
|
12 |
+
`XGen-MM` is a series of the latest foundational Large Multimodal Models (LMMs) developed by Salesforce AI Research. This series advances upon the successful designs of the `BLIP` series, incorporating fundamental enhancements that ensure a more robust and superior foundation. \
|
13 |
+
These models have been trained at scale on high-quality image caption datasets and interleaved image-text data. XGen-MM highlights a few features below,
|
14 |
|
15 |
+
* The **pretrained** foundation model, `xgen-mm-phi3-mini-base-r-v1`, achieves state-of-the-art performance under 5b parameters and demonstrates strong in-context learning capabilities.
|
16 |
+
* The **instruct** fine-tuned model, `xgen-mm-phi3-mini-instruct-r-v1`, achieves state-of-the-art performance among open-source and closed-source VLMs under 5b parameters.
|
17 |
+
* `xgen-mm-phi3-mini-instruct-r-v1` supports flexible high-resolution image encoding with efficient visual token sampling.
|
18 |
|
19 |
+
More technical details will come with a technical report soon.
|
20 |
+
|
21 |
+
|
22 |
+
# Datasets
|
23 |
+
|
24 |
+
| Dataset Type| Dataset(s) Used |
|
25 |
+
|--------|------------------------------------------|
|
26 |
+
| Pretrain | caption data: (datacomp, cc12m, cc3m, SBU, vg) && interleaved data: obelics |
|
27 |
+
| Instruction Tuning | LLaVA-Instruct-150K, ShareGPT4V captions, a mixture of academic VQA data including OCR/Document/Chart-focused tasks, publicly available text-only instruction data |
|
28 |
+
|
29 |
+
# Results
|
30 |
+
|
31 |
+
### Pretrain (base model without instruction tuning)
|
32 |
+
| Model | Shot | COCO (val) | NoCaps (val) | TextCaps (val) | OKVQA (val) | TextVQA (val) | VizWiz (testdev) | VQAv2 (testdev) |
|
33 |
+
|-------------|------|------------|--------------|----------------|--------------|---------------|------------------|-----------------|
|
34 |
+
| Flamingo-3B | 4 | 85.0 | - | - | 43.3 | 32.7 | 34 | 53.2 |
|
35 |
+
| | 8 | 90.6 | - | - | 44.6 | 32.4 | 38.4 | 55.4 |
|
36 |
+
| MM1-3B | 0 | 73.5 | 55.6 | 63.3 | 26.1 | 29.4 | 15.6 | 46.2 |
|
37 |
+
| | 4 | 112.3 | 99.7 | 84.1 | 48.6 | 45.3 | 38.0 | 57.9 |
|
38 |
+
| | 8 | 114.6 | 104.7 | 88.8 | 48.4 | 44.6 | 46.4 | 63.6 |
|
39 |
+
| **xgen-mm-phi3-mini-base-r-v1 (Ours)**| 0 | **81.7** | **80.2** | 60.7 | **26.5** | **36.0** | **21.2** | **48.1** |
|
40 |
+
| | 4 | 110.5 | **101.7** | **84.6** | **49.2** | **46.1** | **38.4** | **63.9** |
|
41 |
+
| | 8 | 112.1 | 104.4 | 87.7 | **49.1** | **46.4** | 44.3 | **63.8** |
|
42 |
+
|
43 |
+
### Instruct (after instruction tuning)
|
44 |
+
| Model | SEED-IMG | MMBench(dev) | MME-total | MME-P | MME-C | MMStar | MMMU (val) | MMVet | MathVista (mini) | ScienceQA (test) | POPE | AI2D | |
|
45 |
+
|----------------------------|----------|--------------|-----------|----------|---------|----------|------------|----------|------------------|------------------|----------|----------|---|
|
46 |
+
| MM1-3B-Chat | 68.8 | 67.8 | 1761 | **1482** | 279 | - | 33.9 | 43.7 | - | - | **87.4** | - | |
|
47 |
+
| openbmb/MiniCPM-V-2 | 67.1 | 69.6 | 1808 | - | - | - | 38.2 | - | 38.7 | - | - | - | |
|
48 |
+
| VILA1.5-3B | 67.9 | 63.4 | - | 1442 | - | - | 33.3 | 35.4 | - | 69.0 | 85.9 | - | |
|
49 |
+
| xtuner/llava-phi-3-mini-hf | 70.0 | 69.2 | 1790 | 1477 | 313 | 43.7 | **41.4** | - | - | 73.7 | 87.3 | 69.3 | |
|
50 |
+
| **xgen-mm-phi3-mini-instruct-r-v1 (Ours)** | **72.1** | **74.1** | **1827** | 1467 | **360** | **44.6** | 39.8 | **45.1** | **39.3** | **74.2** | 87.2 | **75.8** | |
|
51 |
|
|
|
52 |
|
53 |
# How to use
|
54 |
|
55 |
+
> We require the use of the development version (`"4.41.0.dev0"`) of the `transformers` library. To get it, as of 05/07/2024, one can use `pip uninstall -y transformers && pip install git+https://github.com/huggingface/transformers.`
|
56 |
+
|
57 |
+
```python
|
58 |
+
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
|
59 |
+
import json
|
60 |
+
import PIL
|
61 |
+
import IPython.display as display
|
62 |
+
import torch
|
63 |
+
model = AutoModelForVision2Seq.from_pretrained("./", trust_remote_code=True)
|
64 |
+
tokenizer = AutoTokenizer.from_pretrained("./", trust_remote_code=True, use_fast=True, legacy=False)
|
65 |
+
image_processor = AutoImageProcessor.from_pretrained("./", trust_remote_code=True)
|
66 |
+
tokenizer = model.update_special_tokens(tokenizer)
|
67 |
+
|
68 |
+
model = model.to('cuda')
|
69 |
+
tokenizer.padding_side = "left"
|
70 |
+
|
71 |
+
def apply_prompt_template(prompt, num_images=1, num_tokens_per_vis = 128, in_context=False, output=None):
|
72 |
+
"""
|
73 |
+
num_tokens_per_vis: model.vlm.num_tokens_per_vis
|
74 |
+
"""
|
75 |
+
placeholder_image_tokens = "<image placeholder>" * (num_tokens_per_vis - 1)
|
76 |
+
if in_context:
|
77 |
+
formatted_prompt = f"<image>{placeholder_image_tokens}" + f"{prompt}" + f"{output}" + "<|endofchunk|>"
|
78 |
+
else:
|
79 |
+
formatted_prompt = f"<image>{placeholder_image_tokens}"*num_images + f"{prompt}"
|
80 |
+
return formatted_prompt
|
81 |
+
|
82 |
+
############ Zero shot inference ##########
|
83 |
+
with open('./test_samples/zero_shot.json') as f:
|
84 |
+
sample = json.load(f)
|
85 |
+
instruction = sample['instruction']
|
86 |
+
img = PIL.Image.open(sample['image_path'])
|
87 |
+
print("==> Instruction: ", instruction)
|
88 |
+
print("==> Image: ")
|
89 |
+
display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
|
90 |
+
inputs = image_processor([img], return_tensors="pt")
|
91 |
+
prompt = apply_prompt_template(instruction)
|
92 |
+
language_inputs = tokenizer([prompt], return_tensors="pt")
|
93 |
+
inputs.update(language_inputs)
|
94 |
+
inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
|
95 |
+
|
96 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
97 |
+
generated_text = model.generate(**inputs,
|
98 |
+
pad_token_id=tokenizer.pad_token_id,
|
99 |
+
do_sample=False, max_new_tokens=256, top_p=None, num_beams=1,
|
100 |
+
length_penalty=1.0, repetition_penalty=2.0)
|
101 |
+
prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True)
|
102 |
+
print("==> prediciton: ", prediction)
|
103 |
+
print("-"*120)
|
104 |
+
# ==> prediciton: A man sits on a bench in front of the Red Corner Cafe.
|
105 |
+
|
106 |
+
############ Few shots inference ##########
|
107 |
+
# prepare in-context examples
|
108 |
+
with open('./test_samples/few_shots.json') as f:
|
109 |
+
incontext_data = json.load(f)
|
110 |
+
print(f'In-context learning with {len(incontext_data)} examples.')
|
111 |
+
context_images, context_text = [], ""
|
112 |
+
for example in incontext_data:
|
113 |
+
print("-"*40 + f" {example} " + "-"*40)
|
114 |
+
img = PIL.Image.open(incontext_data[example]['image_path'])
|
115 |
+
instruction = incontext_data[example]['instruction']
|
116 |
+
example_text = apply_prompt_template(prompt=instruction, in_context=True, output=incontext_data[example]['output'])
|
117 |
+
context_images.append(img)
|
118 |
+
context_text += (example_text)
|
119 |
+
print("==> Instruction: ", instruction)
|
120 |
+
print("==> Image: ")
|
121 |
+
display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
|
122 |
+
print("==> Output: ", incontext_data[example]['output'])
|
123 |
+
# prepare test example
|
124 |
+
with open('./test_samples/zero_shot.json') as f:
|
125 |
+
sample = json.load(f)
|
126 |
+
instruction = "A short description of this image in one sentence:"
|
127 |
+
print("-"*40 + " Prediction " + "-"*40)
|
128 |
+
img = PIL.Image.open(sample['image_path'])
|
129 |
+
print("==> Instruction: ", instruction)
|
130 |
+
print("==> Image: ")
|
131 |
+
display.display(img.resize((int(img.width*0.3), int(img.height*0.3))))
|
132 |
+
prompt = apply_prompt_template(instruction)
|
133 |
+
batch_images = context_images + [img]
|
134 |
+
batch_text = context_text + prompt
|
135 |
+
# prepare inputs
|
136 |
+
inputs = image_processor(batch_images, return_tensors="pt")
|
137 |
+
language_inputs = tokenizer([batch_text], return_tensors="pt")
|
138 |
+
inputs.update(language_inputs)
|
139 |
+
inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
|
140 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
141 |
+
generated_text = model.generate(**inputs,
|
142 |
+
pad_token_id=tokenizer.pad_token_id,
|
143 |
+
do_sample=False, max_new_tokens=256, top_p=None, num_beams=1,
|
144 |
+
length_penalty=1.0)
|
145 |
+
prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True)
|
146 |
+
print("==> prediciton: ", prediction)
|
147 |
+
print("-"*120)
|
148 |
+
```
|
149 |
+
|
150 |
+
More comprehensive examples can be found in the [notebook](demo.ipynb).
|
151 |
+
|
152 |
+
# Reproducibility:
|
153 |
+
|
154 |
+
Our SFT evaluation is based on the VLMEvalKit, in which we fixed some inconsistencies with the official benchmarks (e.g., LLM judge API). During our development, we noticed that the raw resolution of the input image would noticeably affect the model output in some cases.
|
155 |
+
|
156 |
+
|
157 |
+
# Bias, Risks, Limitations, and Ethical Considerations
|
158 |
+
The main data sources are from the internet, including webpages,
|
159 |
+
image stock sites, and curated datasets released by the research community. We have excluded certain data, such as LAION, due to known CSAM concerns.
|
160 |
+
The model may be subject to bias from the original data source, as well as bias from LLMs and commercial APIs.
|
161 |
+
We strongly recommend users assess safety and fairness before applying to downstream applications.
|
162 |
|
163 |
|
164 |
# License
|
165 |
|
166 |
+
Our code and weights are released under the Creative Commons Attribution Non Commercial 4.0 [LICENSE](LICENSE.txt). Please fill out a form at [here](https://forms.gle/ffPc9oZC2ZGeJ1N68) to consult the commercial use of model weights.
|
167 |
+
|
168 |
+
# Code acknowledgement
|
169 |
+
|
170 |
+
[LAVIS](https://github.com/salesforce/LAVIS) \
|
171 |
+
[openflamingo](https://github.com/mlfoundations/open_flamingo) \
|
172 |
+
[VLMEvalKit](https://github.com/open-compass/VLMEvalKit/tree/main)
|
173 |
+
|
174 |
+
|
175 |
+
# Citation
|
176 |
+
```
|
177 |
+
@misc{xgen_mm_phi3_mini,
|
178 |
+
title={xgen-mm-phi3-mini-base Model Card},
|
179 |
+
url={https://huggingface.co/Salesforce/xgen-mm-phi3-mini-base-r-v1},
|
180 |
+
author={Salesforce AI Research},
|
181 |
+
month={May},
|
182 |
+
year={2024}
|
183 |
+
}
|
184 |
+
```
|
185 |
|
186 |
# Troubleshoot
|
187 |
|
188 |
+
1. If you missed any packages, please consider the following
|
189 |
|
190 |
```
|
191 |
pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
|
added_tokens.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<pad>": 32011,
|
3 |
+
"<|assistant|>": 32001,
|
4 |
+
"<|endoftext|>": 32000,
|
5 |
+
"<|end|>": 32007,
|
6 |
+
"<|placeholder1|>": 32002,
|
7 |
+
"<|placeholder2|>": 32003,
|
8 |
+
"<|placeholder3|>": 32004,
|
9 |
+
"<|placeholder4|>": 32005,
|
10 |
+
"<|placeholder5|>": 32008,
|
11 |
+
"<|placeholder6|>": 32009,
|
12 |
+
"<|system|>": 32006,
|
13 |
+
"<|user|>": 32010
|
14 |
+
}
|
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"XGenMMModelForConditionalGeneration"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_xgenmm.XGenMMConfig",
|
7 |
+
"AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
|
8 |
+
},
|
9 |
+
"model_type": "xgenmm",
|
10 |
+
"text_config": {
|
11 |
+
"initial_tokenizer_len": 32012,
|
12 |
+
"model_type": "phi3",
|
13 |
+
"sliding_window": 2047,
|
14 |
+
"torch_dtype": "bfloat16"
|
15 |
+
},
|
16 |
+
"torch_dtype": "float32",
|
17 |
+
"transformers_version": "4.41.0.dev0",
|
18 |
+
"vision_encoder_config": {
|
19 |
+
"anyres_patch_sampling": false,
|
20 |
+
"image_aspect_ratio": "pad",
|
21 |
+
"model_type": "xgenmm_vision_encoder"
|
22 |
+
},
|
23 |
+
"vision_tokenizer_config": {
|
24 |
+
"model_type": "xgenmm_vision_tokenizer"
|
25 |
+
}
|
26 |
+
}
|
configuration_xgenmm.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from transformers import logging
|
3 |
+
from transformers import CONFIG_MAPPING
|
4 |
+
|
5 |
+
logger = logging.get_logger(__name__)
|
6 |
+
|
7 |
+
class XGenMMVisionEncoderConfig(PretrainedConfig):
|
8 |
+
model_type = "xgenmm_vision_encoder"
|
9 |
+
|
10 |
+
def __init__(self,
|
11 |
+
model_name: str = 'ViT-H-14-378-quickgelu',
|
12 |
+
force_image_size: int = 378,
|
13 |
+
**kwargs):
|
14 |
+
self.model_name = model_name
|
15 |
+
self.force_image_size = force_image_size
|
16 |
+
super().__init__(**kwargs)
|
17 |
+
|
18 |
+
|
19 |
+
class XGenMMVisionTokenizerConfig(PretrainedConfig):
|
20 |
+
model_type = "xgenmm_vision_tokenizer"
|
21 |
+
|
22 |
+
def __init__(self,
|
23 |
+
vis_feature_dim: int = 1280,
|
24 |
+
lang_embedding_dim: int = 3072,
|
25 |
+
**kwargs):
|
26 |
+
self.vis_feature_dim = vis_feature_dim
|
27 |
+
self.lang_embedding_dim = lang_embedding_dim
|
28 |
+
super().__init__(**kwargs)
|
29 |
+
|
30 |
+
|
31 |
+
class XGenMMConfig(PretrainedConfig):
|
32 |
+
model_type = "xgenmm"
|
33 |
+
|
34 |
+
def __init__(self,
|
35 |
+
vision_encoder_config: dict = None,
|
36 |
+
vision_tokenizer_config: dict = None,
|
37 |
+
text_config: dict = None,
|
38 |
+
**kwargs):
|
39 |
+
|
40 |
+
if vision_encoder_config is None:
|
41 |
+
vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
|
42 |
+
logger.info("vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values.")
|
43 |
+
|
44 |
+
if vision_tokenizer_config is None:
|
45 |
+
vision_tokenizer_config = {}
|
46 |
+
logger.info("vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values.")
|
47 |
+
|
48 |
+
if text_config is None:
|
49 |
+
text_config = {
|
50 |
+
'initial_tokenizer_len':32012,
|
51 |
+
'pad_token_id':32011,
|
52 |
+
'bos_token_id':1,
|
53 |
+
'eos_token_id':32000,
|
54 |
+
'vocab_size': 32064,
|
55 |
+
'hidden_size': 3072,
|
56 |
+
'intermediate_size': 8192,
|
57 |
+
'num_hidden_layers': 32,
|
58 |
+
'num_attention_heads': 32,
|
59 |
+
'num_key_value_heads': 32,
|
60 |
+
'resid_pdrop': 0.0,
|
61 |
+
'embd_pdrop': 0.0,
|
62 |
+
'attention_dropout': 0.0,
|
63 |
+
'hidden_act': 'silu',
|
64 |
+
'max_position_embeddings': 4096,
|
65 |
+
'original_max_position_embeddings': 4096,
|
66 |
+
'initializer_range': 0.02,
|
67 |
+
'rms_norm_eps': 1e-05,
|
68 |
+
'use_cache': True,
|
69 |
+
'rope_theta': 10000.0,
|
70 |
+
'rope_scaling': None,
|
71 |
+
'sliding_window': 2047,
|
72 |
+
'return_dict': True,
|
73 |
+
'output_hidden_states': False,
|
74 |
+
'output_attentions': False,
|
75 |
+
'torchscript': False,
|
76 |
+
'torch_dtype': 'bfloat16',
|
77 |
+
'use_bfloat16': False,
|
78 |
+
'tf_legacy_loss': False,
|
79 |
+
'pruned_heads': {},
|
80 |
+
'tie_word_embeddings': False,
|
81 |
+
'chunk_size_feed_forward': 0,
|
82 |
+
'is_encoder_decoder': False,
|
83 |
+
'is_decoder': False,
|
84 |
+
'cross_attention_hidden_size': None,
|
85 |
+
'add_cross_attention': False,
|
86 |
+
'tie_encoder_decoder': False,
|
87 |
+
'max_length': 20,
|
88 |
+
'min_length': 0,
|
89 |
+
'do_sample': False,
|
90 |
+
'early_stopping': False,
|
91 |
+
'num_beams': 1,
|
92 |
+
'num_beam_groups': 1,
|
93 |
+
'diversity_penalty': 0.0,
|
94 |
+
'temperature': 1.0,
|
95 |
+
'top_k': 50,
|
96 |
+
'top_p': 1.0,
|
97 |
+
'typical_p': 1.0,
|
98 |
+
'repetition_penalty': 1.0,
|
99 |
+
'length_penalty': 1.0,
|
100 |
+
'no_repeat_ngram_size': 0,
|
101 |
+
'encoder_no_repeat_ngram_size': 0,
|
102 |
+
'bad_words_ids': None,
|
103 |
+
'num_return_sequences': 1,
|
104 |
+
'output_scores': False,
|
105 |
+
'return_dict_in_generate': False,
|
106 |
+
'forced_bos_token_id': None,
|
107 |
+
'forced_eos_token_id': None,
|
108 |
+
'remove_invalid_values': False,
|
109 |
+
'exponential_decay_length_penalty': None,
|
110 |
+
'suppress_tokens': None,
|
111 |
+
'begin_suppress_tokens': None,
|
112 |
+
'finetuning_task': None,
|
113 |
+
'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
|
114 |
+
'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
|
115 |
+
'tokenizer_class': None,
|
116 |
+
'prefix': None,
|
117 |
+
'bos_token_id': 1,
|
118 |
+
'pad_token_id': 32000,
|
119 |
+
'eos_token_id': 32000,
|
120 |
+
'sep_token_id': None,
|
121 |
+
'decoder_start_token_id': None,
|
122 |
+
'task_specific_params': None,
|
123 |
+
'problem_type': None,
|
124 |
+
'model_type': 'phi3'
|
125 |
+
}
|
126 |
+
logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
|
127 |
+
|
128 |
+
self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)
|
129 |
+
|
130 |
+
self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(**vision_tokenizer_config)
|
131 |
+
|
132 |
+
text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
|
133 |
+
self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
|
134 |
+
|
135 |
+
for key in ['initial_tokenizer_len', 'pad_token_id']:
|
136 |
+
if key not in self.text_config.to_dict():
|
137 |
+
raise ValueError(f"The key `{key}` is missing in the text_config.")
|
138 |
+
|
139 |
+
super().__init__(**kwargs)
|
140 |
+
|
141 |
+
@classmethod
|
142 |
+
def from_vision_encoder_vision_tokenizer_text_configs(
|
143 |
+
cls,
|
144 |
+
vision_encoder_config: XGenMMVisionEncoderConfig,
|
145 |
+
vision_tokenizer_config: XGenMMVisionTokenizerConfig,
|
146 |
+
text_config: PretrainedConfig,
|
147 |
+
**kwargs):
|
148 |
+
|
149 |
+
return cls(
|
150 |
+
vision_encoder_config=vision_encoder_config.to_dict(),
|
151 |
+
vision_tokenizer_config=vision_tokenizer_config.to_dict(),
|
152 |
+
text_config=text_config.to_dict(),
|
153 |
+
**kwargs,
|
154 |
+
)
|
155 |
+
|
demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
generation_config.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"bos_token_id": 1,
|
4 |
+
"eos_token_id": 32000,
|
5 |
+
"pad_token_id": 32000,
|
6 |
+
"transformers_version": "4.41.0.dev0"
|
7 |
+
}
|
image_processing_xgenmm.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
3 |
+
import torchvision.transforms.functional as F
|
4 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
5 |
+
CenterCrop, ColorJitter, Grayscale
|
6 |
+
import numbers
|
7 |
+
import torch
|
8 |
+
import ast
|
9 |
+
import math
|
10 |
+
from PIL import Image
|
11 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
12 |
+
from transformers.image_utils import ImageInput
|
13 |
+
from transformers.utils import TensorType
|
14 |
+
|
15 |
+
|
16 |
+
class XGenMMImageProcessor(BaseImageProcessor):
|
17 |
+
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
do_resize: bool = True,
|
21 |
+
resize_mode: str = "squash",
|
22 |
+
interpolation_mode: str = "bicubic",
|
23 |
+
size: Union[Tuple[int, int], List[int]] = None,
|
24 |
+
image_mean: Optional[Union[float, List[float]]] = None,
|
25 |
+
image_std: Optional[Union[float, List[float]]] = None,
|
26 |
+
**kwargs,
|
27 |
+
) -> None:
|
28 |
+
super().__init__(**kwargs)
|
29 |
+
self.do_resize = do_resize
|
30 |
+
self.resize_mode = resize_mode
|
31 |
+
self.interpolation_mode = interpolation_mode
|
32 |
+
self.size = size if size is not None else (378, 378)
|
33 |
+
self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
|
34 |
+
self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
|
35 |
+
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
|
39 |
+
interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
|
40 |
+
if resize_mode == 'longest':
|
41 |
+
transforms = [
|
42 |
+
ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
|
43 |
+
CenterCropOrPad(image_size, fill=fill_color)
|
44 |
+
]
|
45 |
+
elif resize_mode == 'squash':
|
46 |
+
if isinstance(image_size, int):
|
47 |
+
image_size = (image_size, image_size)
|
48 |
+
transforms = [
|
49 |
+
Resize(image_size, interpolation=interpolation_mode),
|
50 |
+
]
|
51 |
+
else:
|
52 |
+
assert resize_mode == 'shortest'
|
53 |
+
if not isinstance(image_size, (tuple, list)):
|
54 |
+
image_size = (image_size, image_size)
|
55 |
+
if image_size[0] == image_size[1]:
|
56 |
+
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
|
57 |
+
transforms = [
|
58 |
+
Resize(image_size[0], interpolation=interpolation_mode)
|
59 |
+
]
|
60 |
+
else:
|
61 |
+
# resize shortest edge to matching target dim for non-square target
|
62 |
+
transforms = [ResizeKeepRatio(image_size)]
|
63 |
+
transforms += [CenterCrop(image_size)]
|
64 |
+
return transforms
|
65 |
+
|
66 |
+
@classmethod
|
67 |
+
def convert_rgb(cls, image):
|
68 |
+
return image.convert("RGB")
|
69 |
+
|
70 |
+
|
71 |
+
def _preprocess(self,
|
72 |
+
images: ImageInput
|
73 |
+
) -> torch.Tensor:
|
74 |
+
transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
|
75 |
+
transforms.extend([
|
76 |
+
self.convert_rgb,
|
77 |
+
ToTensor(),
|
78 |
+
Normalize(mean=self.image_mean, std=self.image_std)
|
79 |
+
])
|
80 |
+
composed_transforms = Compose(transforms)
|
81 |
+
images_tensor = composed_transforms(images)
|
82 |
+
return images_tensor
|
83 |
+
|
84 |
+
def preprocess(self,
|
85 |
+
images: ImageInput,
|
86 |
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
87 |
+
**kwargs) -> BatchFeature:
|
88 |
+
if 'image_aspect_ratio' in kwargs:
|
89 |
+
image_aspect_ratio = kwargs['image_aspect_ratio']
|
90 |
+
else:
|
91 |
+
image_aspect_ratio = 'pad'
|
92 |
+
new_images = []
|
93 |
+
if image_aspect_ratio == 'pad':
|
94 |
+
for image in images:
|
95 |
+
image = self._preprocess(image)
|
96 |
+
new_images.append(image)
|
97 |
+
else:
|
98 |
+
if isinstance(self.size, (tuple, list)):
|
99 |
+
base_img_size = self.size[0]
|
100 |
+
else:
|
101 |
+
raise ValueError("size should be list or tuple")
|
102 |
+
for image in images:
|
103 |
+
image = process_anyres_image(image, self._preprocess, self.size,
|
104 |
+
[
|
105 |
+
[base_img_size,base_img_size*2],
|
106 |
+
[base_img_size*2,base_img_size],
|
107 |
+
[base_img_size*2,base_img_size*2],
|
108 |
+
[base_img_size*3,base_img_size],
|
109 |
+
[base_img_size,base_img_size*3]
|
110 |
+
])
|
111 |
+
new_images.append(image)
|
112 |
+
|
113 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
114 |
+
new_images = torch.stack(new_images, dim=0)
|
115 |
+
if image_aspect_ratio == 'pad':
|
116 |
+
new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(1).unsqueeze(0)}, tensor_type=return_tensors)
|
117 |
+
else:
|
118 |
+
new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0)}, tensor_type=return_tensors)
|
119 |
+
return new_images
|
120 |
+
# def preprocess(self,
|
121 |
+
# images: ImageInput,
|
122 |
+
# return_tensors: Optional[Union[str, TensorType]] = None,
|
123 |
+
# **kwargs) -> BatchFeature:
|
124 |
+
# transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
|
125 |
+
# transforms.extend([
|
126 |
+
# self.convert_rgb,
|
127 |
+
# ToTensor(),
|
128 |
+
# Normalize(mean=self.image_mean, std=self.image_std)
|
129 |
+
# ])
|
130 |
+
# composed_transforms = Compose(transforms)
|
131 |
+
# images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
|
132 |
+
# encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
|
133 |
+
# return encoded_outputs
|
134 |
+
|
135 |
+
|
136 |
+
class ResizeKeepRatio:
|
137 |
+
""" Resize and Keep Ratio
|
138 |
+
|
139 |
+
Copy & paste from `timm`
|
140 |
+
"""
|
141 |
+
|
142 |
+
def __init__(
|
143 |
+
self,
|
144 |
+
size,
|
145 |
+
longest=0.,
|
146 |
+
interpolation=InterpolationMode.BICUBIC,
|
147 |
+
random_scale_prob=0.,
|
148 |
+
random_scale_range=(0.85, 1.05),
|
149 |
+
random_aspect_prob=0.,
|
150 |
+
random_aspect_range=(0.9, 1.11)
|
151 |
+
):
|
152 |
+
if isinstance(size, (list, tuple)):
|
153 |
+
self.size = tuple(size)
|
154 |
+
else:
|
155 |
+
self.size = (size, size)
|
156 |
+
self.interpolation = interpolation
|
157 |
+
self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
|
158 |
+
self.random_scale_prob = random_scale_prob
|
159 |
+
self.random_scale_range = random_scale_range
|
160 |
+
self.random_aspect_prob = random_aspect_prob
|
161 |
+
self.random_aspect_range = random_aspect_range
|
162 |
+
|
163 |
+
@staticmethod
|
164 |
+
def get_params(
|
165 |
+
img,
|
166 |
+
target_size,
|
167 |
+
longest,
|
168 |
+
random_scale_prob=0.,
|
169 |
+
random_scale_range=(0.85, 1.05),
|
170 |
+
random_aspect_prob=0.,
|
171 |
+
random_aspect_range=(0.9, 1.11)
|
172 |
+
):
|
173 |
+
"""Get parameters
|
174 |
+
"""
|
175 |
+
source_size = img.size[::-1] # h, w
|
176 |
+
h, w = source_size
|
177 |
+
target_h, target_w = target_size
|
178 |
+
ratio_h = h / target_h
|
179 |
+
ratio_w = w / target_w
|
180 |
+
ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
|
181 |
+
if random_scale_prob > 0 and random.random() < random_scale_prob:
|
182 |
+
ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
|
183 |
+
ratio_factor = (ratio_factor, ratio_factor)
|
184 |
+
else:
|
185 |
+
ratio_factor = (1., 1.)
|
186 |
+
if random_aspect_prob > 0 and random.random() < random_aspect_prob:
|
187 |
+
aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
|
188 |
+
ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
|
189 |
+
size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
|
190 |
+
return size
|
191 |
+
|
192 |
+
def __call__(self, img):
|
193 |
+
"""
|
194 |
+
Args:
|
195 |
+
img (PIL Image): Image to be cropped and resized.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
|
199 |
+
"""
|
200 |
+
size = self.get_params(
|
201 |
+
img, self.size, self.longest,
|
202 |
+
self.random_scale_prob, self.random_scale_range,
|
203 |
+
self.random_aspect_prob, self.random_aspect_range
|
204 |
+
)
|
205 |
+
img = F.resize(img, size, self.interpolation)
|
206 |
+
return img
|
207 |
+
|
208 |
+
def __repr__(self):
|
209 |
+
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
|
210 |
+
format_string += f', interpolation={self.interpolation})'
|
211 |
+
format_string += f', longest={self.longest:.3f})'
|
212 |
+
return format_string
|
213 |
+
|
214 |
+
def _setup_size(size, error_msg):
|
215 |
+
if isinstance(size, numbers.Number):
|
216 |
+
return int(size), int(size)
|
217 |
+
|
218 |
+
if isinstance(size, Sequence) and len(size) == 1:
|
219 |
+
return size[0], size[0]
|
220 |
+
|
221 |
+
if len(size) != 2:
|
222 |
+
raise ValueError(error_msg)
|
223 |
+
|
224 |
+
return size
|
225 |
+
|
226 |
+
def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
|
227 |
+
"""Center crops and/or pads the given image.
|
228 |
+
If the image is torch Tensor, it is expected
|
229 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
230 |
+
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
img (PIL Image or Tensor): Image to be cropped.
|
234 |
+
output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
|
235 |
+
it is used for both directions.
|
236 |
+
fill (int, Tuple[int]): Padding color
|
237 |
+
|
238 |
+
Returns:
|
239 |
+
PIL Image or Tensor: Cropped image.
|
240 |
+
"""
|
241 |
+
if isinstance(output_size, numbers.Number):
|
242 |
+
output_size = (int(output_size), int(output_size))
|
243 |
+
elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
|
244 |
+
output_size = (output_size[0], output_size[0])
|
245 |
+
|
246 |
+
_, image_height, image_width = F.get_dimensions(img)
|
247 |
+
crop_height, crop_width = output_size
|
248 |
+
|
249 |
+
if crop_width > image_width or crop_height > image_height:
|
250 |
+
padding_ltrb = [
|
251 |
+
(crop_width - image_width) // 2 if crop_width > image_width else 0,
|
252 |
+
(crop_height - image_height) // 2 if crop_height > image_height else 0,
|
253 |
+
(crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
|
254 |
+
(crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
|
255 |
+
]
|
256 |
+
img = F.pad(img, padding_ltrb, fill=fill)
|
257 |
+
_, image_height, image_width = F.get_dimensions(img)
|
258 |
+
if crop_width == image_width and crop_height == image_height:
|
259 |
+
return img
|
260 |
+
|
261 |
+
crop_top = int(round((image_height - crop_height) / 2.0))
|
262 |
+
crop_left = int(round((image_width - crop_width) / 2.0))
|
263 |
+
return F.crop(img, crop_top, crop_left, crop_height, crop_width)
|
264 |
+
|
265 |
+
class CenterCropOrPad(torch.nn.Module):
|
266 |
+
"""Crops the given image at the center.
|
267 |
+
If the image is torch Tensor, it is expected
|
268 |
+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
269 |
+
If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
size (sequence or int): Desired output size of the crop. If size is an
|
273 |
+
int instead of sequence like (h, w), a square crop (size, size) is
|
274 |
+
made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self, size, fill=0):
|
278 |
+
super().__init__()
|
279 |
+
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
|
280 |
+
self.fill = fill
|
281 |
+
|
282 |
+
def forward(self, img):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
img (PIL Image or Tensor): Image to be cropped.
|
286 |
+
|
287 |
+
Returns:
|
288 |
+
PIL Image or Tensor: Cropped image.
|
289 |
+
"""
|
290 |
+
return center_crop_or_pad(img, self.size, fill=self.fill)
|
291 |
+
|
292 |
+
def __repr__(self) -> str:
|
293 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
294 |
+
|
295 |
+
def process_anyres_image(image, processor, processor_size, grid_pinpoints):
|
296 |
+
"""
|
297 |
+
Process an image with variable resolutions.
|
298 |
+
|
299 |
+
Args:
|
300 |
+
image (PIL.Image.Image): The input image to be processed.
|
301 |
+
processor: The image processor object.
|
302 |
+
processor_size (tuple, list): The size of the image processor.
|
303 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
304 |
+
|
305 |
+
Returns:
|
306 |
+
torch.Tensor: A tensor containing the processed image patches.
|
307 |
+
"""
|
308 |
+
# FIXME: determine grid_pinpoints from image sizes.
|
309 |
+
if type(grid_pinpoints) is list:
|
310 |
+
possible_resolutions = grid_pinpoints
|
311 |
+
else:
|
312 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
313 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
314 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
315 |
+
|
316 |
+
# processor_size = processor.transforms[0].size
|
317 |
+
patches = divide_to_patches(image_padded, processor_size[0])
|
318 |
+
|
319 |
+
image_original_resize = image.resize((processor_size[0], processor_size[0]))
|
320 |
+
|
321 |
+
image_patches = [image_original_resize] + patches
|
322 |
+
image_patches = [processor(image_patch)
|
323 |
+
for image_patch in image_patches]
|
324 |
+
return torch.stack(image_patches, dim=0)
|
325 |
+
|
326 |
+
|
327 |
+
def select_best_resolution(original_size, possible_resolutions):
|
328 |
+
"""
|
329 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
330 |
+
|
331 |
+
Args:
|
332 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
333 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
tuple: The best fit resolution in the format (width, height).
|
337 |
+
"""
|
338 |
+
original_width, original_height = original_size
|
339 |
+
best_fit = None
|
340 |
+
max_effective_resolution = 0
|
341 |
+
min_wasted_resolution = float('inf')
|
342 |
+
|
343 |
+
for width, height in possible_resolutions:
|
344 |
+
scale = min(width / original_width, height / original_height)
|
345 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
346 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
347 |
+
wasted_resolution = (width * height) - effective_resolution
|
348 |
+
|
349 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
350 |
+
max_effective_resolution = effective_resolution
|
351 |
+
min_wasted_resolution = wasted_resolution
|
352 |
+
best_fit = (width, height)
|
353 |
+
|
354 |
+
return best_fit
|
355 |
+
|
356 |
+
def resize_and_pad_image(image, target_resolution):
|
357 |
+
"""
|
358 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
359 |
+
|
360 |
+
Args:
|
361 |
+
image (PIL.Image.Image): The input image.
|
362 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
363 |
+
|
364 |
+
Returns:
|
365 |
+
PIL.Image.Image: The resized and padded image.
|
366 |
+
"""
|
367 |
+
original_width, original_height = image.size
|
368 |
+
target_width, target_height = target_resolution
|
369 |
+
|
370 |
+
scale_w = target_width / original_width
|
371 |
+
scale_h = target_height / original_height
|
372 |
+
|
373 |
+
if scale_w < scale_h:
|
374 |
+
new_width = target_width
|
375 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
376 |
+
else:
|
377 |
+
new_height = target_height
|
378 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
379 |
+
|
380 |
+
# Resize the image
|
381 |
+
resized_image = image.resize((new_width, new_height))
|
382 |
+
|
383 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
384 |
+
paste_x = (target_width - new_width) // 2
|
385 |
+
paste_y = (target_height - new_height) // 2
|
386 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
387 |
+
|
388 |
+
return new_image
|
389 |
+
|
390 |
+
def divide_to_patches(image, patch_size):
|
391 |
+
"""
|
392 |
+
Divides an image into patches of a specified size.
|
393 |
+
|
394 |
+
Args:
|
395 |
+
image (PIL.Image.Image): The input image.
|
396 |
+
patch_size (int): The size of each patch.
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
400 |
+
"""
|
401 |
+
patches = []
|
402 |
+
width, height = image.size
|
403 |
+
for i in range(0, height, patch_size):
|
404 |
+
for j in range(0, width, patch_size):
|
405 |
+
box = (j, i, j + patch_size, i + patch_size)
|
406 |
+
patch = image.crop(box)
|
407 |
+
patches.append(patch)
|
408 |
+
|
409 |
+
return patches
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ca8bbede81e4a6265ab0d3fb2af2ed7b3bf64af352d2af5a81c8117723103569
|
3 |
+
size 4954761920
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:531b227276744050439141b2c6a6429c2b2d72b83717472dfce35194d63667d5
|
3 |
+
size 4983112128
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0067a3e761a18e08eab4a48142845652329a76c3b2a0fa6780cc10aabe1b89ec
|
3 |
+
size 4983112168
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5d34abbbba4bfd0fca5c99b17d7ff5d26d3e1a9d8fbb3cce8127e9e521a10dd4
|
3 |
+
size 3414256548
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 18335156492
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"vlm.lang_model.lm_head.additional_fc.bias": "model-00004-of-00004.safetensors",
|
7 |
+
"vlm.lang_model.lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
|
8 |
+
"vlm.lang_model.lm_head.bias": "model-00004-of-00004.safetensors",
|
9 |
+
"vlm.lang_model.lm_head.weight": "model-00004-of-00004.safetensors",
|
10 |
+
"vlm.lang_model.model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
|
11 |
+
"vlm.lang_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
|
12 |
+
"vlm.lang_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
13 |
+
"vlm.lang_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
14 |
+
"vlm.lang_model.model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
15 |
+
"vlm.lang_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
16 |
+
"vlm.lang_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
17 |
+
"vlm.lang_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
18 |
+
"vlm.lang_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
19 |
+
"vlm.lang_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
20 |
+
"vlm.lang_model.model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
21 |
+
"vlm.lang_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
22 |
+
"vlm.lang_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
23 |
+
"vlm.lang_model.model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
24 |
+
"vlm.lang_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
25 |
+
"vlm.lang_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
26 |
+
"vlm.lang_model.model.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
27 |
+
"vlm.lang_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
28 |
+
"vlm.lang_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
29 |
+
"vlm.lang_model.model.layers.10.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
30 |
+
"vlm.lang_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
31 |
+
"vlm.lang_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
32 |
+
"vlm.lang_model.model.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
33 |
+
"vlm.lang_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
34 |
+
"vlm.lang_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
35 |
+
"vlm.lang_model.model.layers.11.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
36 |
+
"vlm.lang_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
37 |
+
"vlm.lang_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
38 |
+
"vlm.lang_model.model.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
39 |
+
"vlm.lang_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
40 |
+
"vlm.lang_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
41 |
+
"vlm.lang_model.model.layers.12.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
42 |
+
"vlm.lang_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
43 |
+
"vlm.lang_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
44 |
+
"vlm.lang_model.model.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
45 |
+
"vlm.lang_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
46 |
+
"vlm.lang_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
47 |
+
"vlm.lang_model.model.layers.13.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
48 |
+
"vlm.lang_model.model.layers.14.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
49 |
+
"vlm.lang_model.model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
50 |
+
"vlm.lang_model.model.layers.14.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
51 |
+
"vlm.lang_model.model.layers.14.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
52 |
+
"vlm.lang_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
53 |
+
"vlm.lang_model.model.layers.14.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
54 |
+
"vlm.lang_model.model.layers.15.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
55 |
+
"vlm.lang_model.model.layers.15.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
56 |
+
"vlm.lang_model.model.layers.15.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
57 |
+
"vlm.lang_model.model.layers.15.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
58 |
+
"vlm.lang_model.model.layers.15.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
59 |
+
"vlm.lang_model.model.layers.15.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
60 |
+
"vlm.lang_model.model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
61 |
+
"vlm.lang_model.model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
62 |
+
"vlm.lang_model.model.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
63 |
+
"vlm.lang_model.model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
64 |
+
"vlm.lang_model.model.layers.16.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
65 |
+
"vlm.lang_model.model.layers.16.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
66 |
+
"vlm.lang_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
67 |
+
"vlm.lang_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
68 |
+
"vlm.lang_model.model.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
69 |
+
"vlm.lang_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
70 |
+
"vlm.lang_model.model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
71 |
+
"vlm.lang_model.model.layers.17.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
72 |
+
"vlm.lang_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
73 |
+
"vlm.lang_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
74 |
+
"vlm.lang_model.model.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
75 |
+
"vlm.lang_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
76 |
+
"vlm.lang_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
77 |
+
"vlm.lang_model.model.layers.18.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
78 |
+
"vlm.lang_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
79 |
+
"vlm.lang_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
80 |
+
"vlm.lang_model.model.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
81 |
+
"vlm.lang_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
82 |
+
"vlm.lang_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
83 |
+
"vlm.lang_model.model.layers.19.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
84 |
+
"vlm.lang_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
85 |
+
"vlm.lang_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
86 |
+
"vlm.lang_model.model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
|
87 |
+
"vlm.lang_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
88 |
+
"vlm.lang_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
89 |
+
"vlm.lang_model.model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
90 |
+
"vlm.lang_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
91 |
+
"vlm.lang_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
92 |
+
"vlm.lang_model.model.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
93 |
+
"vlm.lang_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
94 |
+
"vlm.lang_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
95 |
+
"vlm.lang_model.model.layers.20.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
96 |
+
"vlm.lang_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
97 |
+
"vlm.lang_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
98 |
+
"vlm.lang_model.model.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
99 |
+
"vlm.lang_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
100 |
+
"vlm.lang_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
101 |
+
"vlm.lang_model.model.layers.21.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
102 |
+
"vlm.lang_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
103 |
+
"vlm.lang_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
104 |
+
"vlm.lang_model.model.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
105 |
+
"vlm.lang_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
106 |
+
"vlm.lang_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
107 |
+
"vlm.lang_model.model.layers.22.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
108 |
+
"vlm.lang_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
109 |
+
"vlm.lang_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
110 |
+
"vlm.lang_model.model.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
111 |
+
"vlm.lang_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
112 |
+
"vlm.lang_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
113 |
+
"vlm.lang_model.model.layers.23.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
114 |
+
"vlm.lang_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
115 |
+
"vlm.lang_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
116 |
+
"vlm.lang_model.model.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
|
117 |
+
"vlm.lang_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
118 |
+
"vlm.lang_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
119 |
+
"vlm.lang_model.model.layers.24.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
120 |
+
"vlm.lang_model.model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
121 |
+
"vlm.lang_model.model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
122 |
+
"vlm.lang_model.model.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
123 |
+
"vlm.lang_model.model.layers.25.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
124 |
+
"vlm.lang_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
125 |
+
"vlm.lang_model.model.layers.25.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
|
126 |
+
"vlm.lang_model.model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
127 |
+
"vlm.lang_model.model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
128 |
+
"vlm.lang_model.model.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
129 |
+
"vlm.lang_model.model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
130 |
+
"vlm.lang_model.model.layers.26.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
131 |
+
"vlm.lang_model.model.layers.26.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
132 |
+
"vlm.lang_model.model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
133 |
+
"vlm.lang_model.model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
134 |
+
"vlm.lang_model.model.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
135 |
+
"vlm.lang_model.model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
136 |
+
"vlm.lang_model.model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
137 |
+
"vlm.lang_model.model.layers.27.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
138 |
+
"vlm.lang_model.model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
139 |
+
"vlm.lang_model.model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
140 |
+
"vlm.lang_model.model.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
141 |
+
"vlm.lang_model.model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
142 |
+
"vlm.lang_model.model.layers.28.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
143 |
+
"vlm.lang_model.model.layers.28.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
144 |
+
"vlm.lang_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
145 |
+
"vlm.lang_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
146 |
+
"vlm.lang_model.model.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
147 |
+
"vlm.lang_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
148 |
+
"vlm.lang_model.model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
149 |
+
"vlm.lang_model.model.layers.29.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
150 |
+
"vlm.lang_model.model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
151 |
+
"vlm.lang_model.model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
152 |
+
"vlm.lang_model.model.layers.3.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
153 |
+
"vlm.lang_model.model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
154 |
+
"vlm.lang_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
155 |
+
"vlm.lang_model.model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
|
156 |
+
"vlm.lang_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
157 |
+
"vlm.lang_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
158 |
+
"vlm.lang_model.model.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
159 |
+
"vlm.lang_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
160 |
+
"vlm.lang_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
161 |
+
"vlm.lang_model.model.layers.30.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
162 |
+
"vlm.lang_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
163 |
+
"vlm.lang_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
164 |
+
"vlm.lang_model.model.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
|
165 |
+
"vlm.lang_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
166 |
+
"vlm.lang_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
167 |
+
"vlm.lang_model.model.layers.31.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
|
168 |
+
"vlm.lang_model.model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
169 |
+
"vlm.lang_model.model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
170 |
+
"vlm.lang_model.model.layers.4.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
171 |
+
"vlm.lang_model.model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
172 |
+
"vlm.lang_model.model.layers.4.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
173 |
+
"vlm.lang_model.model.layers.4.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
174 |
+
"vlm.lang_model.model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
175 |
+
"vlm.lang_model.model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
176 |
+
"vlm.lang_model.model.layers.5.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
177 |
+
"vlm.lang_model.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
178 |
+
"vlm.lang_model.model.layers.5.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
179 |
+
"vlm.lang_model.model.layers.5.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
180 |
+
"vlm.lang_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
181 |
+
"vlm.lang_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
182 |
+
"vlm.lang_model.model.layers.6.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
183 |
+
"vlm.lang_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
184 |
+
"vlm.lang_model.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
185 |
+
"vlm.lang_model.model.layers.6.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
186 |
+
"vlm.lang_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
187 |
+
"vlm.lang_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
188 |
+
"vlm.lang_model.model.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
189 |
+
"vlm.lang_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
190 |
+
"vlm.lang_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
191 |
+
"vlm.lang_model.model.layers.7.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
192 |
+
"vlm.lang_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
193 |
+
"vlm.lang_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
194 |
+
"vlm.lang_model.model.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
195 |
+
"vlm.lang_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
196 |
+
"vlm.lang_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
197 |
+
"vlm.lang_model.model.layers.8.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
198 |
+
"vlm.lang_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
199 |
+
"vlm.lang_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
200 |
+
"vlm.lang_model.model.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
|
201 |
+
"vlm.lang_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
202 |
+
"vlm.lang_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
203 |
+
"vlm.lang_model.model.layers.9.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
|
204 |
+
"vlm.lang_model.model.norm.weight": "model-00004-of-00004.safetensors",
|
205 |
+
"vlm.vision_encoder.class_embedding": "model-00001-of-00004.safetensors",
|
206 |
+
"vlm.vision_encoder.conv1.weight": "model-00001-of-00004.safetensors",
|
207 |
+
"vlm.vision_encoder.ln_post.bias": "model-00001-of-00004.safetensors",
|
208 |
+
"vlm.vision_encoder.ln_post.weight": "model-00001-of-00004.safetensors",
|
209 |
+
"vlm.vision_encoder.ln_pre.bias": "model-00001-of-00004.safetensors",
|
210 |
+
"vlm.vision_encoder.ln_pre.weight": "model-00001-of-00004.safetensors",
|
211 |
+
"vlm.vision_encoder.positional_embedding": "model-00001-of-00004.safetensors",
|
212 |
+
"vlm.vision_encoder.proj": "model-00001-of-00004.safetensors",
|
213 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
214 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
215 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
216 |
+
"vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
217 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_1.bias": "model-00001-of-00004.safetensors",
|
218 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_1.weight": "model-00001-of-00004.safetensors",
|
219 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_2.bias": "model-00001-of-00004.safetensors",
|
220 |
+
"vlm.vision_encoder.transformer.resblocks.0.ln_2.weight": "model-00001-of-00004.safetensors",
|
221 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
222 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
223 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
224 |
+
"vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
225 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
226 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
227 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
228 |
+
"vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
229 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_1.bias": "model-00001-of-00004.safetensors",
|
230 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_1.weight": "model-00001-of-00004.safetensors",
|
231 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_2.bias": "model-00001-of-00004.safetensors",
|
232 |
+
"vlm.vision_encoder.transformer.resblocks.1.ln_2.weight": "model-00001-of-00004.safetensors",
|
233 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
234 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
235 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
236 |
+
"vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
237 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
238 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
239 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
240 |
+
"vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
241 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_1.bias": "model-00001-of-00004.safetensors",
|
242 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_1.weight": "model-00001-of-00004.safetensors",
|
243 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_2.bias": "model-00001-of-00004.safetensors",
|
244 |
+
"vlm.vision_encoder.transformer.resblocks.10.ln_2.weight": "model-00001-of-00004.safetensors",
|
245 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
246 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
247 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
248 |
+
"vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
249 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
250 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
251 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
252 |
+
"vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
253 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_1.bias": "model-00001-of-00004.safetensors",
|
254 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_1.weight": "model-00001-of-00004.safetensors",
|
255 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_2.bias": "model-00001-of-00004.safetensors",
|
256 |
+
"vlm.vision_encoder.transformer.resblocks.11.ln_2.weight": "model-00001-of-00004.safetensors",
|
257 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
258 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
259 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
260 |
+
"vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
261 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
262 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
263 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
264 |
+
"vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
265 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_1.bias": "model-00001-of-00004.safetensors",
|
266 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_1.weight": "model-00001-of-00004.safetensors",
|
267 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_2.bias": "model-00001-of-00004.safetensors",
|
268 |
+
"vlm.vision_encoder.transformer.resblocks.12.ln_2.weight": "model-00001-of-00004.safetensors",
|
269 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
270 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
271 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
272 |
+
"vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
273 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
274 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
275 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
276 |
+
"vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
277 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_1.bias": "model-00001-of-00004.safetensors",
|
278 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_1.weight": "model-00001-of-00004.safetensors",
|
279 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_2.bias": "model-00001-of-00004.safetensors",
|
280 |
+
"vlm.vision_encoder.transformer.resblocks.13.ln_2.weight": "model-00001-of-00004.safetensors",
|
281 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
282 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
283 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
284 |
+
"vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
285 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
286 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
287 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
288 |
+
"vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
289 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_1.bias": "model-00001-of-00004.safetensors",
|
290 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_1.weight": "model-00001-of-00004.safetensors",
|
291 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_2.bias": "model-00001-of-00004.safetensors",
|
292 |
+
"vlm.vision_encoder.transformer.resblocks.14.ln_2.weight": "model-00001-of-00004.safetensors",
|
293 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
294 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
295 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
296 |
+
"vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
297 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
298 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
299 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
300 |
+
"vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
301 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_1.bias": "model-00001-of-00004.safetensors",
|
302 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_1.weight": "model-00001-of-00004.safetensors",
|
303 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_2.bias": "model-00001-of-00004.safetensors",
|
304 |
+
"vlm.vision_encoder.transformer.resblocks.15.ln_2.weight": "model-00001-of-00004.safetensors",
|
305 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
306 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
307 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
308 |
+
"vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
309 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
310 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
311 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
312 |
+
"vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
313 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_1.bias": "model-00001-of-00004.safetensors",
|
314 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_1.weight": "model-00001-of-00004.safetensors",
|
315 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_2.bias": "model-00001-of-00004.safetensors",
|
316 |
+
"vlm.vision_encoder.transformer.resblocks.16.ln_2.weight": "model-00001-of-00004.safetensors",
|
317 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
318 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
319 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
320 |
+
"vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
321 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
322 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
323 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
324 |
+
"vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
325 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_1.bias": "model-00001-of-00004.safetensors",
|
326 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_1.weight": "model-00001-of-00004.safetensors",
|
327 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_2.bias": "model-00001-of-00004.safetensors",
|
328 |
+
"vlm.vision_encoder.transformer.resblocks.17.ln_2.weight": "model-00001-of-00004.safetensors",
|
329 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
330 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
331 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
332 |
+
"vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
333 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
334 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
335 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
336 |
+
"vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
337 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_1.bias": "model-00001-of-00004.safetensors",
|
338 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_1.weight": "model-00001-of-00004.safetensors",
|
339 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_2.bias": "model-00001-of-00004.safetensors",
|
340 |
+
"vlm.vision_encoder.transformer.resblocks.18.ln_2.weight": "model-00001-of-00004.safetensors",
|
341 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
342 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
343 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
344 |
+
"vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
345 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
346 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
347 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
348 |
+
"vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
349 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_1.bias": "model-00001-of-00004.safetensors",
|
350 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_1.weight": "model-00001-of-00004.safetensors",
|
351 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_2.bias": "model-00001-of-00004.safetensors",
|
352 |
+
"vlm.vision_encoder.transformer.resblocks.19.ln_2.weight": "model-00001-of-00004.safetensors",
|
353 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
354 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
355 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
356 |
+
"vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
357 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
358 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
359 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
360 |
+
"vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
361 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_1.bias": "model-00001-of-00004.safetensors",
|
362 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_1.weight": "model-00001-of-00004.safetensors",
|
363 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_2.bias": "model-00001-of-00004.safetensors",
|
364 |
+
"vlm.vision_encoder.transformer.resblocks.2.ln_2.weight": "model-00001-of-00004.safetensors",
|
365 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
366 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
367 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
368 |
+
"vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
369 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
370 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
371 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
372 |
+
"vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
373 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_1.bias": "model-00001-of-00004.safetensors",
|
374 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_1.weight": "model-00001-of-00004.safetensors",
|
375 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_2.bias": "model-00001-of-00004.safetensors",
|
376 |
+
"vlm.vision_encoder.transformer.resblocks.20.ln_2.weight": "model-00001-of-00004.safetensors",
|
377 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
378 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
379 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
380 |
+
"vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
381 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
382 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
383 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
384 |
+
"vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
385 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_1.bias": "model-00001-of-00004.safetensors",
|
386 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_1.weight": "model-00001-of-00004.safetensors",
|
387 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_2.bias": "model-00001-of-00004.safetensors",
|
388 |
+
"vlm.vision_encoder.transformer.resblocks.21.ln_2.weight": "model-00001-of-00004.safetensors",
|
389 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
390 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
391 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
392 |
+
"vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
393 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
394 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
395 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
396 |
+
"vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
397 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_1.bias": "model-00001-of-00004.safetensors",
|
398 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_1.weight": "model-00001-of-00004.safetensors",
|
399 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_2.bias": "model-00001-of-00004.safetensors",
|
400 |
+
"vlm.vision_encoder.transformer.resblocks.22.ln_2.weight": "model-00001-of-00004.safetensors",
|
401 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
402 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
403 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
404 |
+
"vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
405 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
406 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
407 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
408 |
+
"vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
409 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_1.bias": "model-00001-of-00004.safetensors",
|
410 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_1.weight": "model-00001-of-00004.safetensors",
|
411 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_2.bias": "model-00001-of-00004.safetensors",
|
412 |
+
"vlm.vision_encoder.transformer.resblocks.23.ln_2.weight": "model-00001-of-00004.safetensors",
|
413 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
414 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
415 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
416 |
+
"vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
417 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
418 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
419 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
420 |
+
"vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
421 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_1.bias": "model-00001-of-00004.safetensors",
|
422 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_1.weight": "model-00001-of-00004.safetensors",
|
423 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_2.bias": "model-00001-of-00004.safetensors",
|
424 |
+
"vlm.vision_encoder.transformer.resblocks.24.ln_2.weight": "model-00001-of-00004.safetensors",
|
425 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
426 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
427 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
428 |
+
"vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
429 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
430 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
431 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
432 |
+
"vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
433 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_1.bias": "model-00001-of-00004.safetensors",
|
434 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_1.weight": "model-00001-of-00004.safetensors",
|
435 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_2.bias": "model-00001-of-00004.safetensors",
|
436 |
+
"vlm.vision_encoder.transformer.resblocks.25.ln_2.weight": "model-00001-of-00004.safetensors",
|
437 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
438 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
439 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
440 |
+
"vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
441 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
442 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
443 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
444 |
+
"vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
445 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_1.bias": "model-00001-of-00004.safetensors",
|
446 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_1.weight": "model-00001-of-00004.safetensors",
|
447 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_2.bias": "model-00001-of-00004.safetensors",
|
448 |
+
"vlm.vision_encoder.transformer.resblocks.26.ln_2.weight": "model-00001-of-00004.safetensors",
|
449 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
450 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
451 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
452 |
+
"vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
453 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
454 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
455 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
456 |
+
"vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
457 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_1.bias": "model-00001-of-00004.safetensors",
|
458 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_1.weight": "model-00001-of-00004.safetensors",
|
459 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_2.bias": "model-00001-of-00004.safetensors",
|
460 |
+
"vlm.vision_encoder.transformer.resblocks.27.ln_2.weight": "model-00001-of-00004.safetensors",
|
461 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
462 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
463 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
464 |
+
"vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
465 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
466 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
467 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
468 |
+
"vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
469 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_1.bias": "model-00001-of-00004.safetensors",
|
470 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_1.weight": "model-00001-of-00004.safetensors",
|
471 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_2.bias": "model-00001-of-00004.safetensors",
|
472 |
+
"vlm.vision_encoder.transformer.resblocks.28.ln_2.weight": "model-00001-of-00004.safetensors",
|
473 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
474 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
475 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
476 |
+
"vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
477 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
478 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
479 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
480 |
+
"vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
481 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_1.bias": "model-00001-of-00004.safetensors",
|
482 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_1.weight": "model-00001-of-00004.safetensors",
|
483 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_2.bias": "model-00001-of-00004.safetensors",
|
484 |
+
"vlm.vision_encoder.transformer.resblocks.29.ln_2.weight": "model-00001-of-00004.safetensors",
|
485 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
486 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
487 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
488 |
+
"vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
489 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
490 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
491 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
492 |
+
"vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
493 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_1.bias": "model-00001-of-00004.safetensors",
|
494 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_1.weight": "model-00001-of-00004.safetensors",
|
495 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_2.bias": "model-00001-of-00004.safetensors",
|
496 |
+
"vlm.vision_encoder.transformer.resblocks.3.ln_2.weight": "model-00001-of-00004.safetensors",
|
497 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
498 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
499 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
500 |
+
"vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
501 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
502 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
503 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
504 |
+
"vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
505 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_1.bias": "model-00001-of-00004.safetensors",
|
506 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_1.weight": "model-00001-of-00004.safetensors",
|
507 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_2.bias": "model-00001-of-00004.safetensors",
|
508 |
+
"vlm.vision_encoder.transformer.resblocks.30.ln_2.weight": "model-00001-of-00004.safetensors",
|
509 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
510 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
511 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
512 |
+
"vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
513 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
514 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
515 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
516 |
+
"vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
517 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_1.bias": "model-00001-of-00004.safetensors",
|
518 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_1.weight": "model-00001-of-00004.safetensors",
|
519 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_2.bias": "model-00001-of-00004.safetensors",
|
520 |
+
"vlm.vision_encoder.transformer.resblocks.31.ln_2.weight": "model-00001-of-00004.safetensors",
|
521 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
522 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
523 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
524 |
+
"vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
525 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
526 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
527 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
528 |
+
"vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
529 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_1.bias": "model-00001-of-00004.safetensors",
|
530 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_1.weight": "model-00001-of-00004.safetensors",
|
531 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_2.bias": "model-00001-of-00004.safetensors",
|
532 |
+
"vlm.vision_encoder.transformer.resblocks.4.ln_2.weight": "model-00001-of-00004.safetensors",
|
533 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
534 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
535 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
536 |
+
"vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
537 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
538 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
539 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
540 |
+
"vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
541 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_1.bias": "model-00001-of-00004.safetensors",
|
542 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_1.weight": "model-00001-of-00004.safetensors",
|
543 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_2.bias": "model-00001-of-00004.safetensors",
|
544 |
+
"vlm.vision_encoder.transformer.resblocks.5.ln_2.weight": "model-00001-of-00004.safetensors",
|
545 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
546 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
547 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
548 |
+
"vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
549 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
550 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
551 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
552 |
+
"vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
553 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_1.bias": "model-00001-of-00004.safetensors",
|
554 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_1.weight": "model-00001-of-00004.safetensors",
|
555 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_2.bias": "model-00001-of-00004.safetensors",
|
556 |
+
"vlm.vision_encoder.transformer.resblocks.6.ln_2.weight": "model-00001-of-00004.safetensors",
|
557 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
558 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
559 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
560 |
+
"vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
561 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
562 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
563 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
564 |
+
"vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
565 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_1.bias": "model-00001-of-00004.safetensors",
|
566 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_1.weight": "model-00001-of-00004.safetensors",
|
567 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_2.bias": "model-00001-of-00004.safetensors",
|
568 |
+
"vlm.vision_encoder.transformer.resblocks.7.ln_2.weight": "model-00001-of-00004.safetensors",
|
569 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
570 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
571 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
572 |
+
"vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
573 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
574 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
575 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
576 |
+
"vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
577 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_1.bias": "model-00001-of-00004.safetensors",
|
578 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_1.weight": "model-00001-of-00004.safetensors",
|
579 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_2.bias": "model-00001-of-00004.safetensors",
|
580 |
+
"vlm.vision_encoder.transformer.resblocks.8.ln_2.weight": "model-00001-of-00004.safetensors",
|
581 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
582 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
583 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
584 |
+
"vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
585 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_bias": "model-00001-of-00004.safetensors",
|
586 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_weight": "model-00001-of-00004.safetensors",
|
587 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.bias": "model-00001-of-00004.safetensors",
|
588 |
+
"vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.weight": "model-00001-of-00004.safetensors",
|
589 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_1.bias": "model-00001-of-00004.safetensors",
|
590 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_1.weight": "model-00001-of-00004.safetensors",
|
591 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_2.bias": "model-00001-of-00004.safetensors",
|
592 |
+
"vlm.vision_encoder.transformer.resblocks.9.ln_2.weight": "model-00001-of-00004.safetensors",
|
593 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
|
594 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
|
595 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
|
596 |
+
"vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
|
597 |
+
"vlm.vision_tokenizer.latents": "model-00001-of-00004.safetensors",
|
598 |
+
"vlm.vision_tokenizer.layers.0.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
599 |
+
"vlm.vision_tokenizer.layers.0.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
600 |
+
"vlm.vision_tokenizer.layers.0.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
601 |
+
"vlm.vision_tokenizer.layers.0.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
602 |
+
"vlm.vision_tokenizer.layers.0.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
603 |
+
"vlm.vision_tokenizer.layers.0.0.to_out.weight": "model-00001-of-00004.safetensors",
|
604 |
+
"vlm.vision_tokenizer.layers.0.0.to_q.weight": "model-00001-of-00004.safetensors",
|
605 |
+
"vlm.vision_tokenizer.layers.0.1.0.bias": "model-00001-of-00004.safetensors",
|
606 |
+
"vlm.vision_tokenizer.layers.0.1.0.weight": "model-00001-of-00004.safetensors",
|
607 |
+
"vlm.vision_tokenizer.layers.0.1.1.weight": "model-00001-of-00004.safetensors",
|
608 |
+
"vlm.vision_tokenizer.layers.0.1.3.weight": "model-00001-of-00004.safetensors",
|
609 |
+
"vlm.vision_tokenizer.layers.1.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
610 |
+
"vlm.vision_tokenizer.layers.1.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
611 |
+
"vlm.vision_tokenizer.layers.1.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
612 |
+
"vlm.vision_tokenizer.layers.1.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
613 |
+
"vlm.vision_tokenizer.layers.1.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
614 |
+
"vlm.vision_tokenizer.layers.1.0.to_out.weight": "model-00001-of-00004.safetensors",
|
615 |
+
"vlm.vision_tokenizer.layers.1.0.to_q.weight": "model-00001-of-00004.safetensors",
|
616 |
+
"vlm.vision_tokenizer.layers.1.1.0.bias": "model-00001-of-00004.safetensors",
|
617 |
+
"vlm.vision_tokenizer.layers.1.1.0.weight": "model-00001-of-00004.safetensors",
|
618 |
+
"vlm.vision_tokenizer.layers.1.1.1.weight": "model-00001-of-00004.safetensors",
|
619 |
+
"vlm.vision_tokenizer.layers.1.1.3.weight": "model-00001-of-00004.safetensors",
|
620 |
+
"vlm.vision_tokenizer.layers.2.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
621 |
+
"vlm.vision_tokenizer.layers.2.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
622 |
+
"vlm.vision_tokenizer.layers.2.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
623 |
+
"vlm.vision_tokenizer.layers.2.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
624 |
+
"vlm.vision_tokenizer.layers.2.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
625 |
+
"vlm.vision_tokenizer.layers.2.0.to_out.weight": "model-00001-of-00004.safetensors",
|
626 |
+
"vlm.vision_tokenizer.layers.2.0.to_q.weight": "model-00001-of-00004.safetensors",
|
627 |
+
"vlm.vision_tokenizer.layers.2.1.0.bias": "model-00001-of-00004.safetensors",
|
628 |
+
"vlm.vision_tokenizer.layers.2.1.0.weight": "model-00001-of-00004.safetensors",
|
629 |
+
"vlm.vision_tokenizer.layers.2.1.1.weight": "model-00001-of-00004.safetensors",
|
630 |
+
"vlm.vision_tokenizer.layers.2.1.3.weight": "model-00001-of-00004.safetensors",
|
631 |
+
"vlm.vision_tokenizer.layers.3.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
632 |
+
"vlm.vision_tokenizer.layers.3.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
633 |
+
"vlm.vision_tokenizer.layers.3.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
634 |
+
"vlm.vision_tokenizer.layers.3.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
635 |
+
"vlm.vision_tokenizer.layers.3.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
636 |
+
"vlm.vision_tokenizer.layers.3.0.to_out.weight": "model-00001-of-00004.safetensors",
|
637 |
+
"vlm.vision_tokenizer.layers.3.0.to_q.weight": "model-00001-of-00004.safetensors",
|
638 |
+
"vlm.vision_tokenizer.layers.3.1.0.bias": "model-00001-of-00004.safetensors",
|
639 |
+
"vlm.vision_tokenizer.layers.3.1.0.weight": "model-00001-of-00004.safetensors",
|
640 |
+
"vlm.vision_tokenizer.layers.3.1.1.weight": "model-00001-of-00004.safetensors",
|
641 |
+
"vlm.vision_tokenizer.layers.3.1.3.weight": "model-00001-of-00004.safetensors",
|
642 |
+
"vlm.vision_tokenizer.layers.4.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
643 |
+
"vlm.vision_tokenizer.layers.4.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
644 |
+
"vlm.vision_tokenizer.layers.4.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
645 |
+
"vlm.vision_tokenizer.layers.4.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
646 |
+
"vlm.vision_tokenizer.layers.4.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
647 |
+
"vlm.vision_tokenizer.layers.4.0.to_out.weight": "model-00001-of-00004.safetensors",
|
648 |
+
"vlm.vision_tokenizer.layers.4.0.to_q.weight": "model-00001-of-00004.safetensors",
|
649 |
+
"vlm.vision_tokenizer.layers.4.1.0.bias": "model-00001-of-00004.safetensors",
|
650 |
+
"vlm.vision_tokenizer.layers.4.1.0.weight": "model-00001-of-00004.safetensors",
|
651 |
+
"vlm.vision_tokenizer.layers.4.1.1.weight": "model-00001-of-00004.safetensors",
|
652 |
+
"vlm.vision_tokenizer.layers.4.1.3.weight": "model-00001-of-00004.safetensors",
|
653 |
+
"vlm.vision_tokenizer.layers.5.0.norm_latents.bias": "model-00001-of-00004.safetensors",
|
654 |
+
"vlm.vision_tokenizer.layers.5.0.norm_latents.weight": "model-00001-of-00004.safetensors",
|
655 |
+
"vlm.vision_tokenizer.layers.5.0.norm_media.bias": "model-00001-of-00004.safetensors",
|
656 |
+
"vlm.vision_tokenizer.layers.5.0.norm_media.weight": "model-00001-of-00004.safetensors",
|
657 |
+
"vlm.vision_tokenizer.layers.5.0.to_kv.weight": "model-00001-of-00004.safetensors",
|
658 |
+
"vlm.vision_tokenizer.layers.5.0.to_out.weight": "model-00001-of-00004.safetensors",
|
659 |
+
"vlm.vision_tokenizer.layers.5.0.to_q.weight": "model-00001-of-00004.safetensors",
|
660 |
+
"vlm.vision_tokenizer.layers.5.1.0.bias": "model-00001-of-00004.safetensors",
|
661 |
+
"vlm.vision_tokenizer.layers.5.1.0.weight": "model-00001-of-00004.safetensors",
|
662 |
+
"vlm.vision_tokenizer.layers.5.1.1.weight": "model-00001-of-00004.safetensors",
|
663 |
+
"vlm.vision_tokenizer.layers.5.1.3.weight": "model-00001-of-00004.safetensors",
|
664 |
+
"vlm.vision_tokenizer.norm.bias": "model-00001-of-00004.safetensors",
|
665 |
+
"vlm.vision_tokenizer.norm.weight": "model-00001-of-00004.safetensors",
|
666 |
+
"vlm.vision_tokenizer.projection.bias": "model-00001-of-00004.safetensors",
|
667 |
+
"vlm.vision_tokenizer.projection.weight": "model-00001-of-00004.safetensors"
|
668 |
+
}
|
669 |
+
}
|
modeling_xgenmm.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel, AutoModelForCausalLM
|
2 |
+
import torch
|
3 |
+
import open_clip
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
from utils import check_embedding_fns
|
6 |
+
from vlm import PerceiverResampler, Kosmos
|
7 |
+
from configuration_xgenmm import XGenMMVisionEncoderConfig, XGenMMVisionTokenizerConfig, XGenMMConfig
|
8 |
+
|
9 |
+
class XGenMMVisionEncoder(PreTrainedModel):
|
10 |
+
main_input_name = "pixel_values"
|
11 |
+
config_class = XGenMMVisionEncoderConfig
|
12 |
+
|
13 |
+
def __init__(self, config: XGenMMVisionEncoderConfig):
|
14 |
+
super().__init__(config)
|
15 |
+
if config.model_name != 'ViT-H-14-378-quickgelu':
|
16 |
+
raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
|
17 |
+
self.model, _, _ = open_clip.create_model_and_transforms(
|
18 |
+
model_name = config.model_name,
|
19 |
+
force_image_size=config.force_image_size
|
20 |
+
)
|
21 |
+
|
22 |
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
23 |
+
# assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
|
24 |
+
return self.model.encode_image(pixel_values)
|
25 |
+
|
26 |
+
|
27 |
+
# vision tokenizer
|
28 |
+
class XGenMMVisionTokenizer(PreTrainedModel):
|
29 |
+
config_class = XGenMMVisionTokenizerConfig
|
30 |
+
def __init__(self, config: XGenMMVisionTokenizerConfig):
|
31 |
+
super().__init__(config)
|
32 |
+
self.model = PerceiverResampler(
|
33 |
+
dim=config.vis_feature_dim,
|
34 |
+
dim_inner=config.lang_embedding_dim,
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self,
|
38 |
+
vision_features: torch.Tensor,
|
39 |
+
vision_attn_masks: torch.Tensor):
|
40 |
+
return self.model(vision_features, vision_attn_masks)
|
41 |
+
|
42 |
+
# XGenMM model
|
43 |
+
class XGenMMModelForConditionalGeneration(PreTrainedModel):
|
44 |
+
config_class = XGenMMConfig
|
45 |
+
|
46 |
+
def __init__(self, config: XGenMMConfig):
|
47 |
+
super().__init__(config)
|
48 |
+
|
49 |
+
# vision encoder initialization
|
50 |
+
vision_encoder = XGenMMVisionEncoder(config.vision_encoder_config).model
|
51 |
+
vision_encoder.visual.output_tokens = True
|
52 |
+
vision_encoder = vision_encoder.visual
|
53 |
+
|
54 |
+
# language model initialization
|
55 |
+
language_model = AutoModelForCausalLM.from_config(config.text_config)
|
56 |
+
check_embedding_fns(language_model)
|
57 |
+
# Update _tied_weights_keys using the base model used.
|
58 |
+
if language_model._tied_weights_keys is not None:
|
59 |
+
self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
|
60 |
+
|
61 |
+
# vision tokenizer initialization
|
62 |
+
if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
|
63 |
+
overwrite = language_model.get_input_embeddings().weight.shape[1]
|
64 |
+
config.vision_tokenizer_config.lang_embedding_dim = overwrite
|
65 |
+
print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
|
66 |
+
|
67 |
+
vision_tokenizer = XGenMMVisionTokenizer(config.vision_tokenizer_config).model
|
68 |
+
|
69 |
+
self.vlm = Kosmos(
|
70 |
+
vision_encoder=vision_encoder,
|
71 |
+
vision_tokenizer=vision_tokenizer,
|
72 |
+
lang_model=language_model,
|
73 |
+
initial_tokenizer_len = config.text_config.initial_tokenizer_len,
|
74 |
+
pad_token_id = config.text_config.pad_token_id,
|
75 |
+
)
|
76 |
+
# Initialize weights and apply final processing
|
77 |
+
self.post_init()
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def generate(
|
81 |
+
self,
|
82 |
+
pixel_values: torch.FloatTensor,
|
83 |
+
input_ids: Optional[torch.LongTensor] = None,
|
84 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
85 |
+
**generate_kwargs,
|
86 |
+
) -> torch.LongTensor:
|
87 |
+
self.vlm = self.vlm.eval()
|
88 |
+
return self.vlm.generate(
|
89 |
+
vision_x = pixel_values,
|
90 |
+
lang_x = input_ids,
|
91 |
+
attention_mask = attention_mask,
|
92 |
+
**generate_kwargs)
|
93 |
+
|
94 |
+
def update_special_tokens(self, tokenizer):
|
95 |
+
tokenizer.add_special_tokens(
|
96 |
+
{"additional_special_tokens": list(self.vlm.special_tokens.values())}
|
97 |
+
)
|
98 |
+
self.vlm.lang_model.config.vocab_size = len(tokenizer)
|
99 |
+
self.vlm.set_special_token_ids(
|
100 |
+
{
|
101 |
+
v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
|
102 |
+
}
|
103 |
+
)
|
104 |
+
return tokenizer
|
105 |
+
|
preprocessor_config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoImageProcessor": "image_processing_xgenmm.XGenMMImageProcessor"
|
4 |
+
},
|
5 |
+
"do_resize": true,
|
6 |
+
"image_mean": [
|
7 |
+
0.48145466,
|
8 |
+
0.4578275,
|
9 |
+
0.40821073
|
10 |
+
],
|
11 |
+
"image_processor_type": "XGenMMImageProcessor",
|
12 |
+
"image_std": [
|
13 |
+
0.26862954,
|
14 |
+
0.26130258,
|
15 |
+
0.27577711
|
16 |
+
],
|
17 |
+
"interpolation_mode": "bicubic",
|
18 |
+
"resize_mode": "squash",
|
19 |
+
"size": [
|
20 |
+
378,
|
21 |
+
378
|
22 |
+
]
|
23 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "<|endoftext|>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"pad_token": {
|
17 |
+
"content": "<pad>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
},
|
23 |
+
"unk_token": {
|
24 |
+
"content": "<unk>",
|
25 |
+
"lstrip": false,
|
26 |
+
"normalized": false,
|
27 |
+
"rstrip": false,
|
28 |
+
"single_word": false
|
29 |
+
}
|
30 |
+
}
|
test_samples/few_shots.json
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"example_1": {
|
3 |
+
"image_path": "./test_samples/images/COCO_val2014_000000486568.jpg",
|
4 |
+
"instruction": "A short description of this image in one sentence:",
|
5 |
+
"output": "A man in a suit holding something in his office."
|
6 |
+
},
|
7 |
+
"example_2": {
|
8 |
+
"image_path": "./test_samples/images/COCO_val2014_000000176466.jpg",
|
9 |
+
"instruction": "A short description of this image in one sentence:",
|
10 |
+
"output": "The young girl is standing by the fire hydrant in curlers."
|
11 |
+
},
|
12 |
+
"example_3": {
|
13 |
+
"image_path": "./test_samples/images/COCO_val2014_000000392640.jpg",
|
14 |
+
"instruction": "A short description of this image in one sentence:",
|
15 |
+
"output": "A man with a skateboard that is jumping in the air."
|
16 |
+
},
|
17 |
+
"example_4": {
|
18 |
+
"image_path": "./test_samples/images/COCO_val2014_000000267408.jpg",
|
19 |
+
"instruction": "A short description of this image in one sentence:",
|
20 |
+
"output": "A few people looking at a television that's next to a laptop."
|
21 |
+
}
|
22 |
+
}
|
test_samples/images/000adfe5b817011c.jpg
ADDED
test_samples/images/COCO_val2014_000000176466.jpg
ADDED
test_samples/images/COCO_val2014_000000267408.jpg
ADDED
test_samples/images/COCO_val2014_000000392640.jpg
ADDED
test_samples/images/COCO_val2014_000000486568.jpg
ADDED
test_samples/zero_shot.json
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"image_path": "./test_samples/images/000adfe5b817011c.jpg",
|
3 |
+
"instruction": "Please provide a short description of this image:"
|
4 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
3 |
+
size 499723
|
tokenizer_config.json
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": true,
|
26 |
+
"single_word": false,
|
27 |
+
"special": false
|
28 |
+
},
|
29 |
+
"32000": {
|
30 |
+
"content": "<|endoftext|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"32001": {
|
38 |
+
"content": "<|assistant|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": true,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"32002": {
|
46 |
+
"content": "<|placeholder1|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": true,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"32003": {
|
54 |
+
"content": "<|placeholder2|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": true,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"32004": {
|
62 |
+
"content": "<|placeholder3|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": true,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"32005": {
|
70 |
+
"content": "<|placeholder4|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": true,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"32006": {
|
78 |
+
"content": "<|system|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": true,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"32007": {
|
86 |
+
"content": "<|end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": true,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"32008": {
|
94 |
+
"content": "<|placeholder5|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": true,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"32009": {
|
102 |
+
"content": "<|placeholder6|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": true,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"32010": {
|
110 |
+
"content": "<|user|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": true,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"32011": {
|
118 |
+
"content": "<pad>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": true
|
124 |
+
}
|
125 |
+
},
|
126 |
+
"bos_token": "<s>",
|
127 |
+
"chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
128 |
+
"clean_up_tokenization_spaces": false,
|
129 |
+
"eos_token": "<|endoftext|>",
|
130 |
+
"model_max_length": 4096,
|
131 |
+
"pad_token": "<pad>",
|
132 |
+
"padding_side": "left",
|
133 |
+
"sp_model_kwargs": {},
|
134 |
+
"tokenizer_class": "LlamaTokenizer",
|
135 |
+
"unk_token": "<unk>",
|
136 |
+
"use_default_system_prompt": false
|
137 |
+
}
|
utils.py
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import ast
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
|
7 |
+
def has_fn(model, fn_name):
|
8 |
+
"""Check if model has a function fn_name"""
|
9 |
+
return callable(getattr(model, fn_name, None))
|
10 |
+
|
11 |
+
def exists(val):
|
12 |
+
return val is not None
|
13 |
+
|
14 |
+
def num_params(module, filter_to_trainable=False):
|
15 |
+
"""Returns the number of parameters in the module, or optionally only the trainable parameters"""
|
16 |
+
if filter_to_trainable:
|
17 |
+
return sum(p.numel() for p in module.parameters() if p.requires_grad)
|
18 |
+
else:
|
19 |
+
return sum(p.numel() for p in module.parameters())
|
20 |
+
|
21 |
+
def hasattr_recursive(obj, att):
|
22 |
+
"""
|
23 |
+
Check if obj has nested attribute
|
24 |
+
Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
|
25 |
+
"""
|
26 |
+
if att == "":
|
27 |
+
return True
|
28 |
+
i = att.find(".")
|
29 |
+
if i < 0:
|
30 |
+
return hasattr(obj, att)
|
31 |
+
else:
|
32 |
+
try:
|
33 |
+
return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
34 |
+
except:
|
35 |
+
return False
|
36 |
+
|
37 |
+
def getattr_recursive(obj, att):
|
38 |
+
"""
|
39 |
+
Return nested attribute of obj
|
40 |
+
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
|
41 |
+
"""
|
42 |
+
if att == "":
|
43 |
+
return obj
|
44 |
+
i = att.find(".")
|
45 |
+
if i < 0:
|
46 |
+
return getattr(obj, att)
|
47 |
+
else:
|
48 |
+
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
|
49 |
+
|
50 |
+
|
51 |
+
def setattr_recursive(obj, att, val):
|
52 |
+
"""
|
53 |
+
Set nested attribute of obj
|
54 |
+
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
|
55 |
+
"""
|
56 |
+
if "." in att:
|
57 |
+
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
|
58 |
+
setattr(obj, att.split(".")[-1], val)
|
59 |
+
|
60 |
+
|
61 |
+
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
|
62 |
+
"""
|
63 |
+
Stack a list of tensors with padding on one side
|
64 |
+
Args:
|
65 |
+
list_of_tensors (list[torch.Tensor]): List of tensors to stack
|
66 |
+
padding_value (int, optional): Value to pad with. Defaults to 0.
|
67 |
+
padding_side (str, optional): Side to pad on. Defaults to "right".
|
68 |
+
Returns:
|
69 |
+
torch.Tensor: Stacked tensors
|
70 |
+
"""
|
71 |
+
max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
|
72 |
+
padded_tensors = []
|
73 |
+
for tensor in list_of_tensors:
|
74 |
+
num_tokens = tensor.size(0)
|
75 |
+
if len(tensor.size()) == 1:
|
76 |
+
padding = torch.full(
|
77 |
+
(max_tokens - num_tokens,),
|
78 |
+
padding_value,
|
79 |
+
dtype=tensor.dtype,
|
80 |
+
device=tensor.device,
|
81 |
+
)
|
82 |
+
else:
|
83 |
+
padding = torch.full(
|
84 |
+
(max_tokens - num_tokens, tensor.size(1)),
|
85 |
+
padding_value,
|
86 |
+
dtype=tensor.dtype,
|
87 |
+
device=tensor.device,
|
88 |
+
)
|
89 |
+
padded_tensor = (
|
90 |
+
torch.cat((tensor, padding), dim=0)
|
91 |
+
if padding_side == "right"
|
92 |
+
else torch.cat((padding, tensor), dim=0)
|
93 |
+
)
|
94 |
+
padded_tensors.append(padded_tensor)
|
95 |
+
return torch.stack(padded_tensors)
|
96 |
+
|
97 |
+
|
98 |
+
def check_embedding_fns(lang_model):
|
99 |
+
"""Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
|
100 |
+
if not has_fn(lang_model, "get_input_embeddings"):
|
101 |
+
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
|
102 |
+
lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
|
103 |
+
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
|
104 |
+
lang_model.get_input_embeddings = lambda: lang_model.decoder.embed_tokens
|
105 |
+
else:
|
106 |
+
raise ValueError(
|
107 |
+
"We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
|
108 |
+
)
|
109 |
+
|
110 |
+
if not has_fn(lang_model, "set_input_embeddings"):
|
111 |
+
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
|
112 |
+
lang_model.set_input_embeddings = lambda x: setattr_recursive(
|
113 |
+
lang_model, "transformer.wte", x
|
114 |
+
)
|
115 |
+
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
|
116 |
+
lang_model.set_input_embeddings = lambda x: setattr_recursive(
|
117 |
+
lang_model, "model.decoder.embed_tokens", x
|
118 |
+
)
|
119 |
+
else:
|
120 |
+
raise ValueError(
|
121 |
+
"We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
|
122 |
+
)
|
123 |
+
|
124 |
+
if not has_fn(lang_model, "get_output_embeddings"):
|
125 |
+
if hasattr_recursive(lang_model, "lm_head"):
|
126 |
+
lang_model.get_output_embeddings = lambda: lang_model.lm_head
|
127 |
+
else:
|
128 |
+
raise ValueError(
|
129 |
+
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
|
130 |
+
)
|
131 |
+
|
132 |
+
if not has_fn(lang_model, "set_output_embeddings"):
|
133 |
+
if hasattr_recursive(lang_model, "lm_head"):
|
134 |
+
lang_model.set_output_embeddings = lambda x: setattr_recursive(
|
135 |
+
lang_model, "lm_head", x
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
raise ValueError(
|
139 |
+
"We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def has_fn(model, fn_name):
|
144 |
+
"""Check if model has a function fn_name"""
|
145 |
+
return callable(getattr(model, fn_name, None))
|
146 |
+
|
147 |
+
|
148 |
+
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
|
149 |
+
#
|
150 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
151 |
+
# you may not use this file except in compliance with the License.
|
152 |
+
# You may obtain a copy of the License at
|
153 |
+
#
|
154 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
155 |
+
#
|
156 |
+
# Unless required by applicable law or agreed to in writing, software
|
157 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
158 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
159 |
+
# See the License for the specific language governing permissions and
|
160 |
+
# limitations under the License.
|
161 |
+
|
162 |
+
def unpad_image(tensor, original_size, keep_original_shape=False):
|
163 |
+
"""
|
164 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
168 |
+
original_size (tuple): The original size of the image (height, width).
|
169 |
+
|
170 |
+
Returns:
|
171 |
+
torch.Tensor: The unpadded image tensor.
|
172 |
+
"""
|
173 |
+
original_width, original_height = original_size
|
174 |
+
current_height, current_width = tensor.shape[1:]
|
175 |
+
|
176 |
+
original_aspect_ratio = original_width / original_height
|
177 |
+
current_aspect_ratio = current_width / current_height
|
178 |
+
|
179 |
+
if original_aspect_ratio > current_aspect_ratio:
|
180 |
+
scale_factor = current_width / original_width
|
181 |
+
new_height = int(original_height * scale_factor)
|
182 |
+
padding = (current_height - new_height) // 2
|
183 |
+
if keep_original_shape:
|
184 |
+
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
|
185 |
+
attention_mask[:padding, :] = 0
|
186 |
+
attention_mask[current_height - padding:, :] = 0
|
187 |
+
return tensor, attention_mask
|
188 |
+
else:
|
189 |
+
unpadded_tensor = tensor[:, padding:current_height - padding, :]
|
190 |
+
return unpadded_tensor, None
|
191 |
+
else:
|
192 |
+
scale_factor = current_height / original_height
|
193 |
+
new_width = int(original_width * scale_factor)
|
194 |
+
padding = (current_width - new_width) // 2
|
195 |
+
if keep_original_shape:
|
196 |
+
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
|
197 |
+
attention_mask[:, :padding] = 0
|
198 |
+
attention_mask[:, current_width - padding:] = 0
|
199 |
+
return tensor, attention_mask
|
200 |
+
else:
|
201 |
+
unpadded_tensor = tensor[:, :, padding:current_width - padding]
|
202 |
+
return unpadded_tensor, None
|
203 |
+
|
204 |
+
|
205 |
+
def select_best_resolution(original_size, possible_resolutions):
|
206 |
+
"""
|
207 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
211 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
tuple: The best fit resolution in the format (width, height).
|
215 |
+
"""
|
216 |
+
original_width, original_height = original_size
|
217 |
+
best_fit = None
|
218 |
+
max_effective_resolution = 0
|
219 |
+
min_wasted_resolution = float('inf')
|
220 |
+
|
221 |
+
for width, height in possible_resolutions:
|
222 |
+
scale = min(width / original_width, height / original_height)
|
223 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
224 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
225 |
+
wasted_resolution = (width * height) - effective_resolution
|
226 |
+
|
227 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
228 |
+
max_effective_resolution = effective_resolution
|
229 |
+
min_wasted_resolution = wasted_resolution
|
230 |
+
best_fit = (width, height)
|
231 |
+
|
232 |
+
return best_fit
|
233 |
+
|
234 |
+
|
235 |
+
def resize_and_pad_image(image, target_resolution):
|
236 |
+
"""
|
237 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
238 |
+
|
239 |
+
Args:
|
240 |
+
image (PIL.Image.Image): The input image.
|
241 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
PIL.Image.Image: The resized and padded image.
|
245 |
+
"""
|
246 |
+
original_width, original_height = image.size
|
247 |
+
target_width, target_height = target_resolution
|
248 |
+
|
249 |
+
scale_w = target_width / original_width
|
250 |
+
scale_h = target_height / original_height
|
251 |
+
|
252 |
+
if scale_w < scale_h:
|
253 |
+
new_width = target_width
|
254 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
255 |
+
else:
|
256 |
+
new_height = target_height
|
257 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
258 |
+
|
259 |
+
# Resize the image
|
260 |
+
resized_image = image.resize((new_width, new_height))
|
261 |
+
|
262 |
+
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
|
263 |
+
paste_x = (target_width - new_width) // 2
|
264 |
+
paste_y = (target_height - new_height) // 2
|
265 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
266 |
+
|
267 |
+
return new_image
|
268 |
+
|
269 |
+
|
270 |
+
def divide_to_patches(image, patch_size):
|
271 |
+
"""
|
272 |
+
Divides an image into patches of a specified size.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
image (PIL.Image.Image): The input image.
|
276 |
+
patch_size (int): The size of each patch.
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
280 |
+
"""
|
281 |
+
patches = []
|
282 |
+
width, height = image.size
|
283 |
+
for i in range(0, height, patch_size):
|
284 |
+
for j in range(0, width, patch_size):
|
285 |
+
box = (j, i, j + patch_size, i + patch_size)
|
286 |
+
patch = image.crop(box)
|
287 |
+
patches.append(patch)
|
288 |
+
|
289 |
+
return patches
|
290 |
+
|
291 |
+
|
292 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
293 |
+
"""
|
294 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
295 |
+
|
296 |
+
Args:
|
297 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
298 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
299 |
+
patch_size (int): The size of each image patch.
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
303 |
+
"""
|
304 |
+
if type(grid_pinpoints) is list:
|
305 |
+
possible_resolutions = grid_pinpoints
|
306 |
+
else:
|
307 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
308 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
309 |
+
return width // patch_size, height // patch_size
|
310 |
+
|
311 |
+
|
312 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
313 |
+
"""
|
314 |
+
Process an image with variable resolutions.
|
315 |
+
|
316 |
+
Args:
|
317 |
+
image (PIL.Image.Image): The input image to be processed.
|
318 |
+
processor: The image processor object.
|
319 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
torch.Tensor: A tensor containing the processed image patches.
|
323 |
+
"""
|
324 |
+
# FIXME: determine grid_pinpoints from image sizes.
|
325 |
+
if type(grid_pinpoints) is list:
|
326 |
+
possible_resolutions = grid_pinpoints
|
327 |
+
else:
|
328 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
329 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
330 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
331 |
+
|
332 |
+
processor_size = processor.transforms[0].size
|
333 |
+
patches = divide_to_patches(image_padded, processor_size[0])
|
334 |
+
|
335 |
+
image_original_resize = image.resize((processor_size[0], processor_size[0]))
|
336 |
+
|
337 |
+
image_patches = [image_original_resize] + patches
|
338 |
+
image_patches = [processor(image_patch)
|
339 |
+
for image_patch in image_patches]
|
340 |
+
return torch.stack(image_patches, dim=0)
|
341 |
+
|
342 |
+
|
343 |
+
def expand2square(pil_img, background_color):
|
344 |
+
width, height = pil_img.size
|
345 |
+
if width == height:
|
346 |
+
return pil_img
|
347 |
+
elif width > height:
|
348 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
349 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
350 |
+
return result
|
351 |
+
else:
|
352 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
353 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
354 |
+
return result
|
355 |
+
|
356 |
+
|
357 |
+
def process_images(images, image_processor, model_cfg):
|
358 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
359 |
+
new_images = []
|
360 |
+
if image_aspect_ratio == 'pad':
|
361 |
+
for image in images:
|
362 |
+
image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
|
363 |
+
image = image_processor(image)
|
364 |
+
new_images.append(image)
|
365 |
+
elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
|
366 |
+
base_img_size = image_processor.transforms[0].size[0]
|
367 |
+
for image in images:
|
368 |
+
image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
|
369 |
+
[base_img_size*2,base_img_size],
|
370 |
+
[base_img_size*2,base_img_size*2],
|
371 |
+
[base_img_size*3,base_img_size],
|
372 |
+
[base_img_size,base_img_size*3]])
|
373 |
+
|
374 |
+
# Debug any res inference by only using 672x672.
|
375 |
+
# image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
|
376 |
+
new_images.append(image)
|
377 |
+
else:
|
378 |
+
return image_processor(images)
|
379 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
380 |
+
new_images = torch.stack(new_images, dim=0)
|
381 |
+
return new_images
|
382 |
+
|
383 |
+
|
vlm.py
ADDED
@@ -0,0 +1,1308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from torch import einsum, nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
from einops_exts import rearrange_many
|
6 |
+
from einops import rearrange
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10 |
+
from dataclasses import dataclass
|
11 |
+
from transformers import CLIPVisionModel
|
12 |
+
import transformers
|
13 |
+
|
14 |
+
from utils import num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
|
15 |
+
|
16 |
+
|
17 |
+
class VisionTokenizer(nn.Module):
|
18 |
+
def __init__(self, dim_media, num_tokens_per_media):
|
19 |
+
super().__init__()
|
20 |
+
self.dim_media = dim_media
|
21 |
+
self.num_tokens_per_media = num_tokens_per_media
|
22 |
+
|
23 |
+
class PerceiverAttention(nn.Module):
|
24 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
25 |
+
super().__init__()
|
26 |
+
self.scale = dim_head**-0.5
|
27 |
+
self.heads = heads
|
28 |
+
inner_dim = dim_head * heads
|
29 |
+
|
30 |
+
self.norm_media = nn.LayerNorm(dim)
|
31 |
+
self.norm_latents = nn.LayerNorm(dim)
|
32 |
+
|
33 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
34 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
35 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
36 |
+
|
37 |
+
def forward(self, x, latents, vision_attn_masks=None):
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
x (torch.Tensor): image features
|
41 |
+
shape (b, T, n1, D)
|
42 |
+
latent (torch.Tensor): latent features
|
43 |
+
shape (b, T, n2, D)
|
44 |
+
"""
|
45 |
+
x = self.norm_media(x)
|
46 |
+
latents = self.norm_latents(latents)
|
47 |
+
|
48 |
+
h = self.heads
|
49 |
+
|
50 |
+
q = self.to_q(latents)
|
51 |
+
kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
|
52 |
+
if vision_attn_masks is not None:
|
53 |
+
vision_attn_masks = torch.cat((vision_attn_masks,
|
54 |
+
torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
|
55 |
+
dim=-1)
|
56 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
57 |
+
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
|
58 |
+
q = q * self.scale
|
59 |
+
|
60 |
+
# attention
|
61 |
+
sim = einsum("... i d, ... j d -> ... i j", q, k)
|
62 |
+
# Apply vision attention mask here.
|
63 |
+
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
|
64 |
+
if vision_attn_masks is not None:
|
65 |
+
attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
|
66 |
+
vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
|
67 |
+
attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
|
68 |
+
sim += attn_bias
|
69 |
+
|
70 |
+
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
71 |
+
attn = sim.softmax(dim=-1)
|
72 |
+
|
73 |
+
|
74 |
+
out = einsum("... i j, ... j d -> ... i d", attn, v)
|
75 |
+
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
|
76 |
+
return self.to_out(out)
|
77 |
+
|
78 |
+
|
79 |
+
def FeedForward(dim, mult=4):
|
80 |
+
inner_dim = int(dim * mult)
|
81 |
+
return nn.Sequential(
|
82 |
+
nn.LayerNorm(dim),
|
83 |
+
nn.Linear(dim, inner_dim, bias=False),
|
84 |
+
nn.GELU(),
|
85 |
+
nn.Linear(inner_dim, dim, bias=False),
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
class PerceiverResampler(VisionTokenizer):
|
90 |
+
def __init__(
|
91 |
+
self,
|
92 |
+
*,
|
93 |
+
dim,
|
94 |
+
dim_inner=None,
|
95 |
+
depth=6,
|
96 |
+
dim_head=96,
|
97 |
+
heads=16,
|
98 |
+
num_latents=128,
|
99 |
+
max_num_media=None,
|
100 |
+
max_num_frames=None,
|
101 |
+
ff_mult=4,
|
102 |
+
):
|
103 |
+
"""
|
104 |
+
Perceiver module which takes in image features and outputs image tokens.
|
105 |
+
Args:
|
106 |
+
dim (int): dimension of the incoming image features
|
107 |
+
dim_inner (int, optional): final dimension to project the incoming image features to;
|
108 |
+
also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
|
109 |
+
depth (int, optional): number of layers. Defaults to 6.
|
110 |
+
dim_head (int, optional): dimension of each head. Defaults to 64.
|
111 |
+
heads (int, optional): number of heads. Defaults to 8.
|
112 |
+
num_latents (int, optional): number of latent tokens to use in the Perceiver;
|
113 |
+
also corresponds to number of tokens per sequence to output. Defaults to 64.
|
114 |
+
max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
|
115 |
+
and keep positional embeddings for. If None, no positional embeddings are used.
|
116 |
+
max_num_frames (int, optional): maximum number of frames to input into the Perceiver
|
117 |
+
and keep positional embeddings for. If None, no positional embeddings are used.
|
118 |
+
ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
|
119 |
+
"""
|
120 |
+
if dim_inner is not None:
|
121 |
+
projection = nn.Linear(dim, dim_inner)
|
122 |
+
else:
|
123 |
+
projection = None
|
124 |
+
dim_inner = dim
|
125 |
+
super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
|
126 |
+
self.projection = projection
|
127 |
+
self.latents = nn.Parameter(torch.randn(num_latents, dim))
|
128 |
+
# positional embeddings
|
129 |
+
self.frame_embs = (
|
130 |
+
nn.Parameter(torch.randn(max_num_frames, dim))
|
131 |
+
if exists(max_num_frames)
|
132 |
+
else None
|
133 |
+
)
|
134 |
+
self.media_time_embs = (
|
135 |
+
nn.Parameter(torch.randn(max_num_media, 1, dim))
|
136 |
+
if exists(max_num_media)
|
137 |
+
else None
|
138 |
+
)
|
139 |
+
|
140 |
+
self.layers = nn.ModuleList([])
|
141 |
+
for _ in range(depth):
|
142 |
+
self.layers.append(
|
143 |
+
nn.ModuleList(
|
144 |
+
[
|
145 |
+
PerceiverAttention(
|
146 |
+
dim=dim, dim_head=dim_head, heads=heads
|
147 |
+
),
|
148 |
+
FeedForward(dim=dim, mult=ff_mult),
|
149 |
+
]
|
150 |
+
)
|
151 |
+
)
|
152 |
+
|
153 |
+
self.norm = nn.LayerNorm(dim)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
"""
|
157 |
+
Args:
|
158 |
+
x (torch.Tensor): image features
|
159 |
+
shape (b, T, F, v, D)
|
160 |
+
Returns:
|
161 |
+
shape (b, T, n, D) where n is self.num_latents
|
162 |
+
"""
|
163 |
+
b, T, F, v = x.shape[:4]
|
164 |
+
|
165 |
+
# frame and media time embeddings
|
166 |
+
if exists(self.frame_embs):
|
167 |
+
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
|
168 |
+
x = x + frame_embs
|
169 |
+
x = rearrange(
|
170 |
+
x, "b T F v d -> b T (F v) d"
|
171 |
+
) # flatten the frame and spatial dimensions
|
172 |
+
if exists(self.media_time_embs):
|
173 |
+
x = x + self.media_time_embs[:T]
|
174 |
+
|
175 |
+
# blocks
|
176 |
+
latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
|
177 |
+
for attn, ff in self.layers:
|
178 |
+
latents = attn(x, latents) + latents
|
179 |
+
latents = ff(latents) + latents
|
180 |
+
|
181 |
+
if exists(self.projection):
|
182 |
+
return self.projection(self.norm(latents))
|
183 |
+
else:
|
184 |
+
return self.norm(latents)
|
185 |
+
|
186 |
+
class DecoupledEmbedding(nn.Embedding):
|
187 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
|
188 |
+
"""
|
189 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
|
190 |
+
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
|
191 |
+
then it will create `num_additional_embeddings` additional parameters that are always trained. If
|
192 |
+
`num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(
|
196 |
+
self,
|
197 |
+
max_original_id: int,
|
198 |
+
num_additional_embeddings: int = 0,
|
199 |
+
_weight: torch.Tensor = None,
|
200 |
+
num_original_embeddings: int = None,
|
201 |
+
embedding_dim: int = None,
|
202 |
+
partially_freeze=True,
|
203 |
+
device=None,
|
204 |
+
dtype=None,
|
205 |
+
pad_token_id=None,
|
206 |
+
) -> None:
|
207 |
+
"""
|
208 |
+
Args:
|
209 |
+
max_original_id (`int`):
|
210 |
+
The largest token id that should be embedded using the regular embedding (regular `weight`).
|
211 |
+
This is usually len(tokenizer) - 1 before additional tokens are added.
|
212 |
+
Note that this may not equal self.weight.shape[0]
|
213 |
+
num_additional_embeddings (`int`):
|
214 |
+
Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
|
215 |
+
_weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
|
216 |
+
If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
|
217 |
+
num_original_embeddings (`int`):
|
218 |
+
self.weight.shape[0]
|
219 |
+
embedding_dim (`int`):
|
220 |
+
The size of each embedding vector
|
221 |
+
partially_freeze: (`bool`, *optional*, defaults to `True`):
|
222 |
+
If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
|
223 |
+
padding_idx (`int`, *optional*):
|
224 |
+
The padding index (needs to be less than num_embeddings)
|
225 |
+
|
226 |
+
Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
|
227 |
+
`max_norm` or `norm_type`. We are not supporting these.
|
228 |
+
"""
|
229 |
+
# validate args
|
230 |
+
if pad_token_id is not None and pad_token_id > max_original_id:
|
231 |
+
raise ValueError(
|
232 |
+
f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
|
233 |
+
+ "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
|
234 |
+
)
|
235 |
+
if _weight is not None:
|
236 |
+
assert (num_original_embeddings is None) or (
|
237 |
+
_weight.shape[0] == num_original_embeddings
|
238 |
+
), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
|
239 |
+
assert (embedding_dim is None) or (
|
240 |
+
_weight.shape[1] == embedding_dim
|
241 |
+
), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
|
242 |
+
num_original_embeddings = _weight.shape[0]
|
243 |
+
embedding_dim = _weight.shape[1]
|
244 |
+
else:
|
245 |
+
assert (
|
246 |
+
num_original_embeddings is not None
|
247 |
+
), "num_original_embeddings must be provided if _weight is not provided"
|
248 |
+
assert (
|
249 |
+
embedding_dim is not None
|
250 |
+
), "embedding_dim must be provided if _weight is not provided"
|
251 |
+
|
252 |
+
super().__init__(
|
253 |
+
num_embeddings=num_original_embeddings,
|
254 |
+
embedding_dim=embedding_dim,
|
255 |
+
device=device,
|
256 |
+
dtype=dtype,
|
257 |
+
padding_idx=pad_token_id,
|
258 |
+
_weight=_weight,
|
259 |
+
)
|
260 |
+
self.max_original_id = max_original_id
|
261 |
+
self.padding_idx = pad_token_id
|
262 |
+
self.num_additional_embeddings = num_additional_embeddings
|
263 |
+
if self.num_additional_embeddings > 0:
|
264 |
+
self.additional_embedding = nn.Embedding(
|
265 |
+
num_embeddings=self.num_additional_embeddings,
|
266 |
+
embedding_dim=embedding_dim,
|
267 |
+
device=device,
|
268 |
+
dtype=dtype,
|
269 |
+
)
|
270 |
+
self.set_requires_grad(
|
271 |
+
require_regular_grad=not partially_freeze, require_additional_grad=True
|
272 |
+
)
|
273 |
+
|
274 |
+
def set_requires_grad(self, require_regular_grad, require_additional_grad):
|
275 |
+
"""
|
276 |
+
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
|
277 |
+
"""
|
278 |
+
self.weight.requires_grad_(require_regular_grad)
|
279 |
+
self.additional_embedding.requires_grad_(require_additional_grad)
|
280 |
+
|
281 |
+
def forward(self, input_ids):
|
282 |
+
"""
|
283 |
+
we have 2 embeddings, with different indices - one pretrained self.weight and another
|
284 |
+
self.additional_embedding.weight that is being trained.
|
285 |
+
|
286 |
+
in order to make a lookup of the input ids, we:
|
287 |
+
1. find out the indices of the entries belonging to the 2nd embedding
|
288 |
+
2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
|
289 |
+
embedding starts from 0 and not num_embeddings
|
290 |
+
3. perform the 2nd embedding lookup
|
291 |
+
4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
|
292 |
+
5. perform the 1st embedding lookup
|
293 |
+
6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
|
294 |
+
|
295 |
+
note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
|
296 |
+
then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
|
297 |
+
i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
|
298 |
+
usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
|
299 |
+
measure.
|
300 |
+
|
301 |
+
"""
|
302 |
+
if self.num_additional_embeddings == 0:
|
303 |
+
return F.embedding(input_ids, self.weight)
|
304 |
+
|
305 |
+
# Clone so that we don't modify the original input_ids later on
|
306 |
+
input_ids = input_ids.clone()
|
307 |
+
additional_vocab_indices = torch.where(input_ids > self.max_original_id)
|
308 |
+
input_ids_additional_vocab = input_ids[additional_vocab_indices]
|
309 |
+
additional_embeddings = self.additional_embedding(
|
310 |
+
input_ids_additional_vocab - self.max_original_id - 1
|
311 |
+
)
|
312 |
+
|
313 |
+
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
|
314 |
+
input_ids[additional_vocab_indices] = 0
|
315 |
+
full_vector = F.embedding(input_ids, self.weight)
|
316 |
+
|
317 |
+
# overwrite the records with high indices
|
318 |
+
full_vector[additional_vocab_indices] = additional_embeddings
|
319 |
+
|
320 |
+
return full_vector
|
321 |
+
|
322 |
+
def extra_repr(self) -> str:
|
323 |
+
return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
|
324 |
+
self.max_original_id + 1,
|
325 |
+
self.num_additional_embeddings,
|
326 |
+
self.embedding_dim,
|
327 |
+
(not self.weight.requires_grad),
|
328 |
+
)
|
329 |
+
|
330 |
+
|
331 |
+
class DecoupledLinear(nn.Linear):
|
332 |
+
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
|
333 |
+
"""
|
334 |
+
Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
|
335 |
+
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
|
336 |
+
then it will create `additional_out_features * in_features` additional parameters that are always trained. If
|
337 |
+
`additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
|
338 |
+
"""
|
339 |
+
|
340 |
+
def __init__(
|
341 |
+
self,
|
342 |
+
max_original_id: int,
|
343 |
+
additional_out_features: int = 0,
|
344 |
+
_weight: torch.Tensor = None,
|
345 |
+
_bias: torch.Tensor = None,
|
346 |
+
in_features: int = None,
|
347 |
+
original_out_features: int = None,
|
348 |
+
bias: bool = True,
|
349 |
+
partially_freeze: bool = True,
|
350 |
+
device=None,
|
351 |
+
dtype=None,
|
352 |
+
) -> None:
|
353 |
+
"""
|
354 |
+
Args:
|
355 |
+
max_original_id (`int`): The largest token id that should be extracted from the regular weight.
|
356 |
+
This is usually len(tokenizer) - 1 before additional tokens are added.
|
357 |
+
Note that this may not equal original_out_features - 1
|
358 |
+
_weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
|
359 |
+
If provided, this sets the `in_features` and `original_out_features` parameters.
|
360 |
+
_bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
|
361 |
+
in_features: int. Input hidden size.
|
362 |
+
original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
|
363 |
+
additional_out_features: int. Number of additional trainable dimensions.
|
364 |
+
bias: bool. Whether to include a bias term.
|
365 |
+
partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
|
366 |
+
"""
|
367 |
+
# argument validation
|
368 |
+
if _weight is not None:
|
369 |
+
assert (_weight.shape[0] == original_out_features) or (
|
370 |
+
original_out_features is None
|
371 |
+
), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
|
372 |
+
assert (_weight.shape[1] == in_features) or (
|
373 |
+
in_features is None
|
374 |
+
), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
|
375 |
+
in_features = _weight.shape[1]
|
376 |
+
original_out_features = _weight.shape[0]
|
377 |
+
else:
|
378 |
+
assert (
|
379 |
+
in_features is not None
|
380 |
+
), "in_features must be provided if _weight is not provided"
|
381 |
+
assert (
|
382 |
+
original_out_features is not None
|
383 |
+
), "original_out_features must be provided if _weight is not provided"
|
384 |
+
|
385 |
+
if _bias is not None:
|
386 |
+
assert bias is True, "bias must be True if _bias is provided"
|
387 |
+
|
388 |
+
# initialize original linear
|
389 |
+
super().__init__(
|
390 |
+
in_features,
|
391 |
+
original_out_features,
|
392 |
+
bias,
|
393 |
+
device,
|
394 |
+
dtype)
|
395 |
+
|
396 |
+
# set weight and bias manually
|
397 |
+
if _weight is not None:
|
398 |
+
self.weight = nn.Parameter(_weight)
|
399 |
+
if _bias is not None:
|
400 |
+
self.bias = nn.Parameter(_bias)
|
401 |
+
|
402 |
+
self.in_features = in_features
|
403 |
+
self.original_out_features = original_out_features
|
404 |
+
self.max_original_id = max_original_id
|
405 |
+
|
406 |
+
# initialize additional linear
|
407 |
+
self.additional_out_features = additional_out_features
|
408 |
+
self.has_bias = bias
|
409 |
+
if additional_out_features > 0:
|
410 |
+
self.additional_fc = nn.Linear(
|
411 |
+
in_features=in_features,
|
412 |
+
out_features=additional_out_features,
|
413 |
+
bias=self.has_bias,
|
414 |
+
device=device,
|
415 |
+
dtype=dtype,
|
416 |
+
)
|
417 |
+
self.set_requires_grad(
|
418 |
+
require_regular_grad=not partially_freeze, require_additional_grad=True
|
419 |
+
)
|
420 |
+
|
421 |
+
def set_requires_grad(self, require_regular_grad, require_additional_grad):
|
422 |
+
"""
|
423 |
+
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
|
424 |
+
"""
|
425 |
+
self.weight.requires_grad_(require_regular_grad)
|
426 |
+
if self.has_bias:
|
427 |
+
self.bias.requires_grad_(require_regular_grad)
|
428 |
+
self.additional_fc.requires_grad_(require_additional_grad)
|
429 |
+
|
430 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
431 |
+
output = F.linear(input, self.weight, self.bias)
|
432 |
+
output = output[..., : self.max_original_id + 1]
|
433 |
+
|
434 |
+
if self.additional_out_features > 0:
|
435 |
+
additional_features = F.linear(
|
436 |
+
input, self.additional_fc.weight, self.additional_fc.bias
|
437 |
+
)
|
438 |
+
output = torch.cat((output, additional_features), -1)
|
439 |
+
return output
|
440 |
+
|
441 |
+
def extra_repr(self) -> str:
|
442 |
+
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
|
443 |
+
return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
|
444 |
+
self.in_features,
|
445 |
+
self.max_original_id + 1,
|
446 |
+
self.additional_out_features,
|
447 |
+
self.bias is not None,
|
448 |
+
(not self.weight.requires_grad or not self.bias.requires_grad),
|
449 |
+
)
|
450 |
+
|
451 |
+
class VLM(nn.Module):
|
452 |
+
"""
|
453 |
+
Generic vision-language model (VLM) class.
|
454 |
+
A VLM consists of four components:
|
455 |
+
1. A vision encoder that extracts features from pixels, e.g. CLIP
|
456 |
+
input: (B, T_img, F, C, H, W)
|
457 |
+
output: (B, T_img, F, v, d)
|
458 |
+
2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
|
459 |
+
input: (B, T_img, F, v, d)
|
460 |
+
output: (B, T_img, n, d)
|
461 |
+
3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
|
462 |
+
4. A language model
|
463 |
+
"""
|
464 |
+
|
465 |
+
def __init__(
|
466 |
+
self,
|
467 |
+
vision_encoder: nn.Module,
|
468 |
+
vision_tokenizer: nn.Module,
|
469 |
+
lang_model: nn.Module,
|
470 |
+
initial_tokenizer_len: int,
|
471 |
+
pad_token_id: int,
|
472 |
+
gradient_checkpointing: bool = False,
|
473 |
+
):
|
474 |
+
"""
|
475 |
+
Args:
|
476 |
+
vision_encoder (nn.Module): e.g. CLIP
|
477 |
+
vision_tokenizer (nn.Module): e.g. PerceiverResampler
|
478 |
+
lang_model (nn.Module): e.g. MPT
|
479 |
+
initial_tokenizer_len (int): size of the original tokenizer vocab
|
480 |
+
pad_token_id (int): id of the pad token
|
481 |
+
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
|
482 |
+
"""
|
483 |
+
super().__init__()
|
484 |
+
|
485 |
+
# save dimension information
|
486 |
+
self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
|
487 |
+
if hasattr(lang_model.config, "d_model"):
|
488 |
+
self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
|
489 |
+
else:
|
490 |
+
self.lang_hidden_dim = lang_model.config.hidden_size
|
491 |
+
self.vis_embedding_dim = vision_tokenizer.dim_media
|
492 |
+
self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
|
493 |
+
|
494 |
+
# core components
|
495 |
+
self.vision_encoder = vision_encoder
|
496 |
+
self.vision_tokenizer = vision_tokenizer
|
497 |
+
self.lang_model = lang_model
|
498 |
+
|
499 |
+
# lm embeddings
|
500 |
+
self.pad_token_id = pad_token_id
|
501 |
+
self.initial_tokenizer_len = initial_tokenizer_len
|
502 |
+
input_embeds = DecoupledEmbedding(
|
503 |
+
max_original_id=initial_tokenizer_len - 1,
|
504 |
+
num_additional_embeddings=len(self.special_tokens),
|
505 |
+
_weight=self.lang_model.get_input_embeddings().weight,
|
506 |
+
pad_token_id=self.pad_token_id,
|
507 |
+
)
|
508 |
+
if hasattr(input_embeds, "additional_embedding"):
|
509 |
+
input_embeds.additional_embedding.weight.data.normal_(
|
510 |
+
mean=0.0,
|
511 |
+
std=self.lang_model.config.initializer_range
|
512 |
+
if hasattr(self.lang_model.config, "initializer_range")
|
513 |
+
else 0.02,
|
514 |
+
)
|
515 |
+
self.lang_model.set_input_embeddings(input_embeds)
|
516 |
+
|
517 |
+
out_embeds = DecoupledLinear(
|
518 |
+
max_original_id=initial_tokenizer_len - 1,
|
519 |
+
additional_out_features=len(self.special_tokens),
|
520 |
+
_weight=self.lang_model.get_output_embeddings().weight,
|
521 |
+
_bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
|
522 |
+
)
|
523 |
+
if hasattr(out_embeds, "additional_fc"):
|
524 |
+
out_embeds.additional_fc.weight.data.normal_(
|
525 |
+
mean=0.0,
|
526 |
+
std=self.lang_model.config.initializer_range
|
527 |
+
if hasattr(self.lang_model.config, "initializer_range")
|
528 |
+
else 0.02,
|
529 |
+
)
|
530 |
+
self.lang_model.set_output_embeddings(out_embeds)
|
531 |
+
|
532 |
+
# gradient checkpointing
|
533 |
+
self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
|
534 |
+
|
535 |
+
def forward(
|
536 |
+
self,
|
537 |
+
vision_x: Optional[torch.Tensor],
|
538 |
+
lang_x: torch.Tensor,
|
539 |
+
attention_mask: Optional[torch.Tensor] = None,
|
540 |
+
labels: Optional[torch.Tensor] = None,
|
541 |
+
past_key_values: Optional[
|
542 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
543 |
+
] = None,
|
544 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
545 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
546 |
+
use_cache: Optional[bool] = False,
|
547 |
+
**kwargs,
|
548 |
+
):
|
549 |
+
"""
|
550 |
+
Args:
|
551 |
+
vision_x: Vision input
|
552 |
+
shape (B, T_img, F, C, H, W) with F=1
|
553 |
+
only F = 1 is supported (single-frame videos)
|
554 |
+
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
|
555 |
+
only the first number of media tokens in lang_x are used
|
556 |
+
lang_x: Language input ids, with media tokens denoting where
|
557 |
+
visual media should be inserted.
|
558 |
+
shape (B, T_txt)
|
559 |
+
attention_mask: Attention mask. Defaults to None.
|
560 |
+
labels: Labels. Defaults to None.
|
561 |
+
shape (B, T_txt)
|
562 |
+
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
|
563 |
+
list of length = number of decoder layers in the LM
|
564 |
+
exact implementation depends on LM, see Hugging Face docs
|
565 |
+
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
|
566 |
+
shape (B, T_txt)
|
567 |
+
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
|
568 |
+
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
|
569 |
+
If True, includes key_values, media_locations, and vision_tokens in the output.
|
570 |
+
"""
|
571 |
+
assert not (past_vision_tokens is None) ^ (
|
572 |
+
past_media_locations is None
|
573 |
+
), "past_vision_tokens and past_media_locations must both be None or both be not None"
|
574 |
+
|
575 |
+
# convert pixels to vision tokens
|
576 |
+
if vision_x is not None:
|
577 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
578 |
+
vision_tokens = self.vision_tokenizer(vision_features)
|
579 |
+
else:
|
580 |
+
vision_tokens = None
|
581 |
+
|
582 |
+
# fuse the vision and language tokens
|
583 |
+
new_inputs = self._prepare_inputs_for_forward(
|
584 |
+
vision_tokens=vision_tokens,
|
585 |
+
lang_x=lang_x,
|
586 |
+
attention_mask=attention_mask,
|
587 |
+
labels=labels,
|
588 |
+
past_key_values=past_key_values,
|
589 |
+
past_media_locations=past_media_locations,
|
590 |
+
padding_side="right",
|
591 |
+
past_vision_tokens=past_vision_tokens,
|
592 |
+
)
|
593 |
+
output = self.lang_model(
|
594 |
+
**new_inputs,
|
595 |
+
use_cache=use_cache,
|
596 |
+
past_key_values=past_key_values,
|
597 |
+
**kwargs,
|
598 |
+
)
|
599 |
+
|
600 |
+
# postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
|
601 |
+
# or to add the past_vision_tokens and past_media_locations to the output
|
602 |
+
output = self._postprocess_outputs_from_forward(
|
603 |
+
output=output,
|
604 |
+
lang_x=lang_x,
|
605 |
+
vision_tokens=vision_tokens,
|
606 |
+
use_cache=use_cache,
|
607 |
+
past_vision_tokens=past_vision_tokens,
|
608 |
+
past_media_locations=past_media_locations,
|
609 |
+
)
|
610 |
+
|
611 |
+
# postforward hooks
|
612 |
+
self._post_forward_hook()
|
613 |
+
return output
|
614 |
+
|
615 |
+
def _encode_vision_x_anyres(self, samples, device):
|
616 |
+
image_raw = samples["image"] # list of patch list in of shape [1, N_patch, C, H, W]
|
617 |
+
image_sizes = samples["image_size"]
|
618 |
+
|
619 |
+
# concate list of patches into one big patch for any res encoding.
|
620 |
+
images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
|
621 |
+
image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
|
622 |
+
image = image.to(device)
|
623 |
+
|
624 |
+
with torch.no_grad():
|
625 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
626 |
+
image_embeds = self.vision_encoder.trunk.forward_features(image)
|
627 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
628 |
+
image_embeds = self.vision_encoder(image).last_hidden_state
|
629 |
+
else:
|
630 |
+
image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
|
631 |
+
|
632 |
+
if isinstance(self.vision_encoder, CLIPVisionModel):
|
633 |
+
base_img_size = self.vision_encoder.config.image_size
|
634 |
+
else:
|
635 |
+
base_img_size = self.vision_encoder.image_size[0]
|
636 |
+
|
637 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
638 |
+
grid_size = self.vision_encoder.trunk.patch_embed.grid_size
|
639 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
640 |
+
grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
|
641 |
+
grid_size = (grid_size_base, grid_size_base)
|
642 |
+
else:
|
643 |
+
grid_size = self.vision_encoder.grid_size
|
644 |
+
height, width = grid_size
|
645 |
+
|
646 |
+
if not image_embeds.shape[1] == height * width:
|
647 |
+
assert image_embeds.shape[1] == height * width + 1 # For vision encoders that has [CLS] token.
|
648 |
+
image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
|
649 |
+
n_vis_token_per_patch = image_embeds.shape[1]
|
650 |
+
|
651 |
+
# Split encoded patches and merge patch features
|
652 |
+
# 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
|
653 |
+
split_sizes = [image.shape[0] for image in images]
|
654 |
+
image_embeds = torch.split(image_embeds, split_sizes, dim=0)
|
655 |
+
# 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
|
656 |
+
new_image_embeds = []
|
657 |
+
patch_attn_masks = []
|
658 |
+
max_n_img_token = -1
|
659 |
+
for idx, patch_embeds in enumerate(image_embeds):
|
660 |
+
if patch_embeds.shape[0] > 1:
|
661 |
+
# 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
|
662 |
+
base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
|
663 |
+
patch_embeds = patch_embeds[1:]
|
664 |
+
|
665 |
+
assert height * width == base_patch_embeds.shape[0]
|
666 |
+
|
667 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
|
668 |
+
[[base_img_size,base_img_size*2],
|
669 |
+
[base_img_size*2,base_img_size],
|
670 |
+
[base_img_size*2,base_img_size*2],
|
671 |
+
[base_img_size*3,base_img_size],
|
672 |
+
[base_img_size,base_img_size*3]],
|
673 |
+
base_img_size) # Hardcoded grid_pinpoints.
|
674 |
+
patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
|
675 |
+
|
676 |
+
patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
|
677 |
+
patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
|
678 |
+
# TODO: add an option that return masked patch_embeds instead of trimmed.
|
679 |
+
patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
|
680 |
+
if hasattr(self, 'image_newline'):
|
681 |
+
patch_embeds = torch.cat((
|
682 |
+
patch_embeds,
|
683 |
+
self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
|
684 |
+
), dim=-1)
|
685 |
+
if self.anyres_patch_sampling:
|
686 |
+
patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
|
687 |
+
patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
|
688 |
+
assert patch_attn_mask is not None
|
689 |
+
patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
|
690 |
+
patch_attn_mask = patch_attn_mask.flatten(0, 1)
|
691 |
+
patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
|
692 |
+
patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
|
693 |
+
else:
|
694 |
+
patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
|
695 |
+
patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
|
696 |
+
else:
|
697 |
+
patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
|
698 |
+
patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
|
699 |
+
if hasattr(self, 'image_newline'):
|
700 |
+
patch_embeds = torch.cat((
|
701 |
+
patch_embeds,
|
702 |
+
self.image_newline[None]
|
703 |
+
), dim=0)
|
704 |
+
if not self.anyres_patch_sampling:
|
705 |
+
max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
|
706 |
+
|
707 |
+
new_image_embeds.append(patch_embeds)
|
708 |
+
patch_attn_masks.append(patch_attn_mask)
|
709 |
+
|
710 |
+
if self.anyres_patch_sampling:
|
711 |
+
# Return individual patches for independent token downsampling.
|
712 |
+
return new_image_embeds, patch_attn_masks
|
713 |
+
|
714 |
+
# 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
|
715 |
+
image_embeds = []
|
716 |
+
image_atts = []
|
717 |
+
for image_embed in new_image_embeds:
|
718 |
+
n_img_token = image_embed.shape[0]
|
719 |
+
img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
|
720 |
+
if n_img_token < max_n_img_token:
|
721 |
+
padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
|
722 |
+
padded_embed[:n_img_token, :] = image_embed
|
723 |
+
img_attn[n_img_token:] = 0 # Mask out the padded entries.
|
724 |
+
else:
|
725 |
+
padded_embed = image_embed
|
726 |
+
image_embeds.append(padded_embed)
|
727 |
+
image_atts.append(img_attn)
|
728 |
+
image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
|
729 |
+
image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest, C_dim]
|
730 |
+
# TODO: reshape image_embeds and image_atts to "b T F v d"
|
731 |
+
image_embeds = image_embeds[:, None, None, :, :]
|
732 |
+
# image_atts = image_atts[:, None, None, :, :]
|
733 |
+
|
734 |
+
return image_embeds, image_atts
|
735 |
+
|
736 |
+
def _encode_vision_x(self, vision_x: torch.Tensor):
|
737 |
+
"""
|
738 |
+
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
|
739 |
+
Args:
|
740 |
+
vision_x: Vision input
|
741 |
+
shape (B, T_img, F, C, H, W)
|
742 |
+
Images in the same chunk are collated along T_img, and frames are collated along F
|
743 |
+
Currently only F=1 is supported (single-frame videos)
|
744 |
+
|
745 |
+
rearrange code based on https://github.com/dhansmair/flamingo-mini
|
746 |
+
"""
|
747 |
+
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
|
748 |
+
b, T, F = vision_x.shape[:3]
|
749 |
+
|
750 |
+
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
|
751 |
+
with torch.no_grad():
|
752 |
+
if self.vision_encoder.__class__.__name__ == "TimmModel":
|
753 |
+
vision_x = self.vision_encoder.trunk.forward_features(vision_x)
|
754 |
+
elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
|
755 |
+
vision_x = self.vision_encoder(vision_x).last_hidden_state
|
756 |
+
else:
|
757 |
+
vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
|
758 |
+
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
|
759 |
+
return vision_x
|
760 |
+
|
761 |
+
def _concat_vision_cache(
|
762 |
+
self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
|
763 |
+
):
|
764 |
+
"""
|
765 |
+
Helper function to include the past vision tokens and past media locations in the output.
|
766 |
+
"""
|
767 |
+
if use_cache:
|
768 |
+
if past_media_locations is not None and past_vision_tokens is not None:
|
769 |
+
if vision_tokens is not None:
|
770 |
+
updated_vision_tokens = torch.cat(
|
771 |
+
[
|
772 |
+
past_vision_tokens,
|
773 |
+
vision_tokens,
|
774 |
+
],
|
775 |
+
dim=1,
|
776 |
+
)
|
777 |
+
else:
|
778 |
+
updated_vision_tokens = past_vision_tokens
|
779 |
+
updated_media_locations = torch.cat(
|
780 |
+
[
|
781 |
+
past_media_locations,
|
782 |
+
lang_x == self.media_token_id,
|
783 |
+
],
|
784 |
+
dim=1,
|
785 |
+
)
|
786 |
+
else:
|
787 |
+
updated_vision_tokens = vision_tokens
|
788 |
+
updated_media_locations = lang_x == self.media_token_id
|
789 |
+
|
790 |
+
else:
|
791 |
+
updated_vision_tokens = None
|
792 |
+
updated_media_locations = None
|
793 |
+
|
794 |
+
return updated_vision_tokens, updated_media_locations
|
795 |
+
|
796 |
+
def generate(
|
797 |
+
self,
|
798 |
+
vision_x: torch.Tensor,
|
799 |
+
lang_x: torch.Tensor,
|
800 |
+
attention_mask: torch.Tensor = None,
|
801 |
+
past_key_values: Optional[
|
802 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
803 |
+
] = None,
|
804 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
805 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
806 |
+
**kwargs,
|
807 |
+
):
|
808 |
+
"""
|
809 |
+
Generate text conditioned on vision and language inputs.
|
810 |
+
Args:
|
811 |
+
vision_x (torch.Tensor): Vision input
|
812 |
+
shape (B, T_img, F, C, H, W)
|
813 |
+
see documentation for forward
|
814 |
+
lang_x (torch.Tensor): Language input
|
815 |
+
shape (B, T_txt)
|
816 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
817 |
+
**kwargs: see generate documentation in Hugging Face CausalLM models.
|
818 |
+
Returns:
|
819 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
820 |
+
"""
|
821 |
+
num_beams = kwargs.pop("num_beams", 1)
|
822 |
+
|
823 |
+
# convert pixels to vision tokens
|
824 |
+
if vision_x is not None:
|
825 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
826 |
+
vision_tokens = self.vision_tokenizer(vision_features)
|
827 |
+
else:
|
828 |
+
vision_tokens = None
|
829 |
+
|
830 |
+
# fuse the vision and language tokens
|
831 |
+
# for xattn, vision_x and media_location are repeat_interleaved s.t.
|
832 |
+
# the total batch size is B * num_beams
|
833 |
+
new_inputs = self._prepare_inputs_for_forward(
|
834 |
+
vision_tokens=vision_tokens,
|
835 |
+
lang_x=lang_x,
|
836 |
+
attention_mask=attention_mask,
|
837 |
+
past_key_values=past_key_values,
|
838 |
+
past_media_locations=past_media_locations,
|
839 |
+
past_vision_tokens=past_vision_tokens,
|
840 |
+
padding_side="left",
|
841 |
+
num_beams=num_beams,
|
842 |
+
)
|
843 |
+
output = self.lang_model.generate(
|
844 |
+
**new_inputs,
|
845 |
+
past_key_values=past_key_values,
|
846 |
+
num_beams=num_beams,
|
847 |
+
use_cache=True,
|
848 |
+
**kwargs,
|
849 |
+
)
|
850 |
+
self._post_forward_hook()
|
851 |
+
return output
|
852 |
+
|
853 |
+
@property
|
854 |
+
def num_trainable_params(self):
|
855 |
+
"""Print the number of trainable parameters"""
|
856 |
+
return num_params(self, filter_to_trainable=True)
|
857 |
+
|
858 |
+
def set_trainable(self):
|
859 |
+
"""
|
860 |
+
Freeze appropriate parameters in the model.
|
861 |
+
"""
|
862 |
+
raise NotImplementedError
|
863 |
+
|
864 |
+
def group_params_by_weight_decay(self):
|
865 |
+
"""
|
866 |
+
Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
|
867 |
+
"""
|
868 |
+
params_with_wd, params_without_wd = [], []
|
869 |
+
for n, p in self.named_parameters():
|
870 |
+
if p.requires_grad:
|
871 |
+
if self._should_apply_weight_decay(n):
|
872 |
+
params_with_wd.append(p)
|
873 |
+
else:
|
874 |
+
params_without_wd.append(p)
|
875 |
+
return params_with_wd, params_without_wd
|
876 |
+
|
877 |
+
def _should_apply_weight_decay(self, parameter_name):
|
878 |
+
"""
|
879 |
+
Return whether weight decay should be applied to a parameter.
|
880 |
+
"""
|
881 |
+
raise NotImplementedError
|
882 |
+
|
883 |
+
@property
|
884 |
+
def special_tokens(self):
|
885 |
+
"""
|
886 |
+
Returns a dict mapping from the attribute name of a special token to its string format,
|
887 |
+
e.g. "media_token": "<image>"
|
888 |
+
"""
|
889 |
+
assert (
|
890 |
+
"media_token" in self._special_tokens
|
891 |
+
), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
|
892 |
+
return self._special_tokens
|
893 |
+
|
894 |
+
@property
|
895 |
+
def special_token_ids(self):
|
896 |
+
"""
|
897 |
+
Returns a list of the special token ids
|
898 |
+
"""
|
899 |
+
return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
|
900 |
+
|
901 |
+
def set_special_token_ids(self, string_to_ids):
|
902 |
+
"""
|
903 |
+
Args:
|
904 |
+
string_to_ids (dict): mapping from token string to id
|
905 |
+
"""
|
906 |
+
assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
|
907 |
+
for att_name, token_str in self.special_tokens.items():
|
908 |
+
token_id = string_to_ids[token_str]
|
909 |
+
setattr(self, f"{att_name}_id", token_id)
|
910 |
+
setattr(self.lang_model, f"{att_name}_id", token_id)
|
911 |
+
|
912 |
+
def init_gradient_checkpointing(self):
|
913 |
+
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
914 |
+
checkpoint_wrapper,
|
915 |
+
CheckpointWrapper,
|
916 |
+
CheckpointImpl,
|
917 |
+
apply_activation_checkpointing,
|
918 |
+
)
|
919 |
+
from functools import partial
|
920 |
+
|
921 |
+
non_reentrant_wrapper = partial(
|
922 |
+
checkpoint_wrapper,
|
923 |
+
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
|
924 |
+
)
|
925 |
+
apply_activation_checkpointing(
|
926 |
+
self,
|
927 |
+
checkpoint_wrapper_fn=non_reentrant_wrapper,
|
928 |
+
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
|
929 |
+
and not isinstance(m, CheckpointWrapper),
|
930 |
+
)
|
931 |
+
|
932 |
+
@dataclass
|
933 |
+
class VLMOutputWithPast(CausalLMOutputWithPast):
|
934 |
+
"""
|
935 |
+
VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
|
936 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
937 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
938 |
+
"""
|
939 |
+
|
940 |
+
past_media_locations: Optional[torch.Tensor] = None
|
941 |
+
past_vision_tokens: Optional[torch.Tensor] = None
|
942 |
+
|
943 |
+
|
944 |
+
def exists(val):
|
945 |
+
return val is not None
|
946 |
+
|
947 |
+
|
948 |
+
def FeedForward(dim, mult=4):
|
949 |
+
inner_dim = int(dim * mult)
|
950 |
+
return nn.Sequential(
|
951 |
+
nn.LayerNorm(dim),
|
952 |
+
nn.Linear(dim, inner_dim, bias=False),
|
953 |
+
nn.GELU(),
|
954 |
+
nn.Linear(inner_dim, dim, bias=False),
|
955 |
+
)
|
956 |
+
|
957 |
+
class VLMWithLanguageStream(VLM):
|
958 |
+
"""
|
959 |
+
VLM that fuses modalities by inserting vision tokens directly into the language stream.
|
960 |
+
"""
|
961 |
+
|
962 |
+
def __init__(
|
963 |
+
self,
|
964 |
+
vision_encoder: nn.Module,
|
965 |
+
vision_tokenizer: nn.Module,
|
966 |
+
lang_model: nn.Module,
|
967 |
+
initial_tokenizer_len: int,
|
968 |
+
pad_token_id: int,
|
969 |
+
decoder_layers_attr_name: str = None,
|
970 |
+
gradient_checkpointing: bool = False,
|
971 |
+
):
|
972 |
+
super().__init__(
|
973 |
+
vision_encoder=vision_encoder,
|
974 |
+
vision_tokenizer=vision_tokenizer,
|
975 |
+
lang_model=lang_model,
|
976 |
+
initial_tokenizer_len=initial_tokenizer_len,
|
977 |
+
pad_token_id=pad_token_id,
|
978 |
+
gradient_checkpointing=gradient_checkpointing,
|
979 |
+
)
|
980 |
+
self.decoder_layers_attr_name = decoder_layers_attr_name
|
981 |
+
if decoder_layers_attr_name is not None:
|
982 |
+
for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
|
983 |
+
block._use_gradient_checkpointing = gradient_checkpointing
|
984 |
+
|
985 |
+
def _prepare_inputs_for_forward(
|
986 |
+
self,
|
987 |
+
vision_tokens: torch.Tensor,
|
988 |
+
lang_x: torch.Tensor,
|
989 |
+
attention_mask: torch.Tensor,
|
990 |
+
labels: torch.Tensor = None,
|
991 |
+
past_key_values=None,
|
992 |
+
past_media_locations: torch.Tensor = None,
|
993 |
+
past_vision_tokens: torch.Tensor = None,
|
994 |
+
padding_side: str = "left",
|
995 |
+
num_beams: int = 1,
|
996 |
+
):
|
997 |
+
"""
|
998 |
+
Insert the vision tokens directly into the language stream/
|
999 |
+
This requires us to modify the input_ids, attention_mask, and labels.
|
1000 |
+
"""
|
1001 |
+
if past_key_values is not None:
|
1002 |
+
past_len = past_key_values[0][0].shape[2]
|
1003 |
+
assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
|
1004 |
+
"Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
|
1005 |
+
+ "Check that you've expanded the attention mask to account for past image tokens."
|
1006 |
+
)
|
1007 |
+
|
1008 |
+
if vision_tokens is None:
|
1009 |
+
return {
|
1010 |
+
"input_ids": lang_x,
|
1011 |
+
"attention_mask": attention_mask,
|
1012 |
+
"labels": labels,
|
1013 |
+
}
|
1014 |
+
|
1015 |
+
# get the language embeddings
|
1016 |
+
lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
|
1017 |
+
|
1018 |
+
# build up the multimodal embeddings
|
1019 |
+
B = lang_x.shape[0]
|
1020 |
+
has_labels = labels is not None
|
1021 |
+
multimodal_embeds = []
|
1022 |
+
multimodal_attention_mask = []
|
1023 |
+
multimodal_labels = [] if has_labels else None
|
1024 |
+
for i in range(B):
|
1025 |
+
# get index of <image> tokens in lang_x[i]
|
1026 |
+
image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
|
1027 |
+
|
1028 |
+
if len(image_token_idxs) == 0:
|
1029 |
+
multimodal_embeds.append(lang_embeds[i].clone())
|
1030 |
+
multimodal_attention_mask.append(attention_mask[i].clone())
|
1031 |
+
if has_labels:
|
1032 |
+
multimodal_labels.append(labels[i].clone())
|
1033 |
+
continue
|
1034 |
+
|
1035 |
+
# # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
|
1036 |
+
# for j, img_idx in enumerate(image_token_idxs):
|
1037 |
+
# image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j
|
1038 |
+
|
1039 |
+
# loop through the image_token_idxs and insert the vision tokens
|
1040 |
+
new_embed = lang_embeds[i].clone()
|
1041 |
+
new_attention_mask = (
|
1042 |
+
attention_mask[i].clone() if attention_mask is not None else None
|
1043 |
+
)
|
1044 |
+
if has_labels:
|
1045 |
+
new_label = labels[i].clone()
|
1046 |
+
|
1047 |
+
for img_num, img_idx in enumerate(image_token_idxs):
|
1048 |
+
new_embed = torch.cat(
|
1049 |
+
(
|
1050 |
+
new_embed[:img_idx],
|
1051 |
+
vision_tokens[i][img_num],
|
1052 |
+
new_embed[img_idx + self.num_tokens_per_vis :],
|
1053 |
+
),
|
1054 |
+
dim=0,
|
1055 |
+
)
|
1056 |
+
new_attention_mask = torch.cat(
|
1057 |
+
(
|
1058 |
+
new_attention_mask[:img_idx],
|
1059 |
+
torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
|
1060 |
+
attention_mask.device
|
1061 |
+
),
|
1062 |
+
new_attention_mask[img_idx + self.num_tokens_per_vis :],
|
1063 |
+
),
|
1064 |
+
dim=0,
|
1065 |
+
)
|
1066 |
+
if has_labels:
|
1067 |
+
new_label = torch.cat(
|
1068 |
+
(
|
1069 |
+
new_label[:img_idx],
|
1070 |
+
torch.ones(self.num_tokens_per_vis, dtype=torch.long).to(
|
1071 |
+
labels.device
|
1072 |
+
)
|
1073 |
+
* -100,
|
1074 |
+
new_label[img_idx + self.num_tokens_per_vis :],
|
1075 |
+
),
|
1076 |
+
dim=0,
|
1077 |
+
)
|
1078 |
+
multimodal_embeds.append(new_embed)
|
1079 |
+
multimodal_attention_mask.append(new_attention_mask)
|
1080 |
+
if has_labels:
|
1081 |
+
multimodal_labels.append(new_label)
|
1082 |
+
|
1083 |
+
# stack
|
1084 |
+
multimodal_embeds = stack_with_padding(
|
1085 |
+
multimodal_embeds,
|
1086 |
+
padding_value=self.pad_token_id,
|
1087 |
+
padding_side=padding_side,
|
1088 |
+
)
|
1089 |
+
multimodal_attention_mask = stack_with_padding(
|
1090 |
+
multimodal_attention_mask,
|
1091 |
+
padding_value=0,
|
1092 |
+
padding_side=padding_side,
|
1093 |
+
)
|
1094 |
+
if has_labels:
|
1095 |
+
multimodal_labels = stack_with_padding(
|
1096 |
+
multimodal_labels,
|
1097 |
+
padding_value=-100,
|
1098 |
+
padding_side=padding_side,
|
1099 |
+
)
|
1100 |
+
|
1101 |
+
return {
|
1102 |
+
"inputs_embeds": multimodal_embeds,
|
1103 |
+
"attention_mask": multimodal_attention_mask,
|
1104 |
+
"labels": multimodal_labels,
|
1105 |
+
}
|
1106 |
+
|
1107 |
+
def _postprocess_outputs_from_forward(
|
1108 |
+
self,
|
1109 |
+
output: CausalLMOutputWithPast,
|
1110 |
+
lang_x: torch.Tensor,
|
1111 |
+
vision_tokens: torch.Tensor,
|
1112 |
+
past_vision_tokens: torch.Tensor,
|
1113 |
+
past_media_locations: torch.Tensor,
|
1114 |
+
use_cache: bool = False,
|
1115 |
+
):
|
1116 |
+
# Include the past vision tokens and past media locations in the output
|
1117 |
+
updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
|
1118 |
+
lang_x=lang_x,
|
1119 |
+
vision_tokens=vision_tokens,
|
1120 |
+
past_vision_tokens=past_vision_tokens,
|
1121 |
+
past_media_locations=past_media_locations,
|
1122 |
+
use_cache=use_cache,
|
1123 |
+
)
|
1124 |
+
|
1125 |
+
# return logits that are the same shape as the original input_ids
|
1126 |
+
logits = output.logits
|
1127 |
+
batch_logits = []
|
1128 |
+
B, T_txt = lang_x.shape
|
1129 |
+
for i in range(B):
|
1130 |
+
sequence_logits = []
|
1131 |
+
logits_j = 0
|
1132 |
+
for j in range(T_txt):
|
1133 |
+
if lang_x[i, j] != self.media_token_id:
|
1134 |
+
sequence_logits.append(logits[i, logits_j])
|
1135 |
+
logits_j += 1
|
1136 |
+
else:
|
1137 |
+
# append the logit for the first image token, then skip over the rest
|
1138 |
+
# note: the model actually learns to predict <im_patch>, not <image>
|
1139 |
+
sequence_logits.append(logits[i, logits_j])
|
1140 |
+
logits_j += self.num_tokens_per_vis
|
1141 |
+
sequence_logits = torch.stack(sequence_logits, dim=0) # (B, vocab_size)
|
1142 |
+
batch_logits.append(sequence_logits)
|
1143 |
+
|
1144 |
+
batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
|
1145 |
+
# The final logits shape should be the same as the original input_ids shape
|
1146 |
+
assert batch_logits.shape[:2] == (B, T_txt)
|
1147 |
+
|
1148 |
+
# assemble the output
|
1149 |
+
output = VLMOutputWithPast(
|
1150 |
+
loss=output.loss,
|
1151 |
+
logits=batch_logits,
|
1152 |
+
past_key_values=output.past_key_values,
|
1153 |
+
hidden_states=output.hidden_states,
|
1154 |
+
attentions=output.attentions,
|
1155 |
+
past_media_locations=updated_media_locations,
|
1156 |
+
past_vision_tokens=updated_vision_tokens,
|
1157 |
+
)
|
1158 |
+
|
1159 |
+
return output
|
1160 |
+
|
1161 |
+
def _post_forward_hook(self):
|
1162 |
+
pass
|
1163 |
+
|
1164 |
+
|
1165 |
+
@property
|
1166 |
+
def num_params_per_module(self):
|
1167 |
+
"""Print the number of parameters per module in the model"""
|
1168 |
+
return "\n".join(
|
1169 |
+
[
|
1170 |
+
f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
|
1171 |
+
f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
|
1172 |
+
f"Language model: {num_params(self.lang_model):,} parameters",
|
1173 |
+
]
|
1174 |
+
)
|
1175 |
+
|
1176 |
+
@property
|
1177 |
+
def num_trainable_params_per_module(self):
|
1178 |
+
"""Print the number of trainable parameters per module in the model"""
|
1179 |
+
return "\n".join(
|
1180 |
+
[
|
1181 |
+
f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
|
1182 |
+
f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
|
1183 |
+
f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
|
1184 |
+
]
|
1185 |
+
)
|
1186 |
+
|
1187 |
+
|
1188 |
+
class Kosmos(VLMWithLanguageStream):
|
1189 |
+
def __init__(
|
1190 |
+
self,
|
1191 |
+
vision_encoder: nn.Module,
|
1192 |
+
vision_tokenizer: nn.Module,
|
1193 |
+
lang_model: nn.Module,
|
1194 |
+
initial_tokenizer_len: int,
|
1195 |
+
pad_token_id: int,
|
1196 |
+
decoder_layers_attr_name: str = None,
|
1197 |
+
gradient_checkpointing: bool = False,
|
1198 |
+
):
|
1199 |
+
"""
|
1200 |
+
Args:
|
1201 |
+
vision_encoder (nn.Module): HF CLIPModel
|
1202 |
+
lang_encoder (nn.Module): HF causal language model
|
1203 |
+
vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
|
1204 |
+
initial_tokenizer_len (int): size of the tokenizer vocab
|
1205 |
+
padding_token_id (int): id of the padding token. None if no padding token; then a padding token
|
1206 |
+
will be inserted into self.special_tokens, which factory.py fills after creating new tokens
|
1207 |
+
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
|
1208 |
+
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
|
1209 |
+
"""
|
1210 |
+
self._special_tokens = {
|
1211 |
+
"media_token": "<image>",
|
1212 |
+
"image_placeholder_token": "<image placeholder>",
|
1213 |
+
"end_of_trunk_token": "<|endofchunk|>"
|
1214 |
+
}
|
1215 |
+
super().__init__(
|
1216 |
+
vision_encoder=vision_encoder,
|
1217 |
+
vision_tokenizer=vision_tokenizer,
|
1218 |
+
lang_model=lang_model,
|
1219 |
+
initial_tokenizer_len=initial_tokenizer_len,
|
1220 |
+
gradient_checkpointing=gradient_checkpointing,
|
1221 |
+
decoder_layers_attr_name=decoder_layers_attr_name,
|
1222 |
+
pad_token_id=pad_token_id
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
# def set_trainable(self):
|
1226 |
+
# """
|
1227 |
+
# Unfreeze everything except the vision_encoder
|
1228 |
+
# """
|
1229 |
+
# self.requires_grad_(True)
|
1230 |
+
# self.vision_encoder.requires_grad_(False)
|
1231 |
+
|
1232 |
+
def set_trainable(self, unfreeze_vision_encoder: bool = False):
|
1233 |
+
"""
|
1234 |
+
Unfreeze everything except the vision_encoder
|
1235 |
+
"""
|
1236 |
+
self.requires_grad_(True)
|
1237 |
+
self.vision_encoder.requires_grad_(unfreeze_vision_encoder)
|
1238 |
+
|
1239 |
+
def _should_apply_weight_decay(self, parameter_name):
|
1240 |
+
"""
|
1241 |
+
Kosmos applies 0.01 weight deacy to everything
|
1242 |
+
"""
|
1243 |
+
return True
|
1244 |
+
|
1245 |
+
def generate(
|
1246 |
+
self,
|
1247 |
+
vision_x: torch.Tensor,
|
1248 |
+
lang_x: torch.Tensor,
|
1249 |
+
attention_mask: torch.Tensor = None,
|
1250 |
+
past_key_values: Optional[
|
1251 |
+
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
|
1252 |
+
] = None,
|
1253 |
+
past_media_locations: Optional[torch.Tensor] = None,
|
1254 |
+
past_vision_tokens: Optional[torch.Tensor] = None,
|
1255 |
+
**kwargs
|
1256 |
+
):
|
1257 |
+
"""
|
1258 |
+
Generate text conditioned on vision and language inputs.
|
1259 |
+
Args:
|
1260 |
+
vision_x (torch.Tensor): Vision input
|
1261 |
+
shape (B, T_img, F, C, H, W)
|
1262 |
+
see documentation for forward
|
1263 |
+
lang_x (torch.Tensor): Language input
|
1264 |
+
shape (B, T_txt)
|
1265 |
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
1266 |
+
**kwargs: see generate documentation in Hugging Face CausalLM models.
|
1267 |
+
Returns:
|
1268 |
+
torch.Tensor: lang_x with generated tokens appended to it
|
1269 |
+
"""
|
1270 |
+
num_beams = kwargs.pop("num_beams", 1)
|
1271 |
+
|
1272 |
+
# convert pixels to vision tokens
|
1273 |
+
if vision_x is not None:
|
1274 |
+
vision_features = self._encode_vision_x(vision_x=vision_x)
|
1275 |
+
vision_tokens = self.vision_tokenizer(vision_features)
|
1276 |
+
else:
|
1277 |
+
vision_tokens = None
|
1278 |
+
|
1279 |
+
# fuse the vision and language tokens
|
1280 |
+
# for xattn, vision_x and media_location are repeat_interleaved s.t.
|
1281 |
+
# the total batch size is B * num_beams
|
1282 |
+
new_inputs = self._prepare_inputs_for_forward(
|
1283 |
+
vision_tokens=vision_tokens,
|
1284 |
+
lang_x=lang_x,
|
1285 |
+
attention_mask=attention_mask,
|
1286 |
+
past_key_values=past_key_values,
|
1287 |
+
past_media_locations=past_media_locations,
|
1288 |
+
past_vision_tokens=past_vision_tokens,
|
1289 |
+
padding_side="left",
|
1290 |
+
num_beams=num_beams,
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
if transformers.__version__ == '4.41.0.dev0':
|
1294 |
+
output = self.lang_model.generate(
|
1295 |
+
**new_inputs,
|
1296 |
+
num_beams=num_beams,
|
1297 |
+
use_cache=True,
|
1298 |
+
eos_token_id=self.end_of_trunk_token_id,
|
1299 |
+
**kwargs)
|
1300 |
+
else:
|
1301 |
+
output = self.lang_model.generate(
|
1302 |
+
**new_inputs,
|
1303 |
+
past_key_values=past_key_values,
|
1304 |
+
num_beams=num_beams,
|
1305 |
+
use_cache=True,
|
1306 |
+
eos_token_id=self.end_of_trunk_token_id,
|
1307 |
+
**kwargs)
|
1308 |
+
return output
|