matthewlyleolson committed 3371fdb (parent: a678b33): Update README.md

README.md CHANGED
The previous README.md contained only the metadata header, with `license_name: intel-research-use`; the updated file renames this to `intel-research-use-license` and adds the full model card below.
---
license: other
license_name: intel-research-use-license
license_link: LICENSE
---

# LLaVA-Llama3 Model Card

_This model card corresponds to the instruction-tuned 8B version of the model with the CLIP-based vision encoder._

## Overview

`llava-llama-3-8b` is a large multimodal model (LMM) trained with the [LLaVA-v1.5 framework](https://arxiv.org/abs/2310.03744), using the 8-billion-parameter [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model as its language backbone.

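For orientation, the sketch below inspects the vision/language composition through the Hugging Face config API. It is a minimal illustration, not part of the official usage flow: it assumes the checkpoint exposes a standard HF LLaVA-style config, and it reuses the `Intel/llava-llama-3-8b-old` repo id from the Usage section below.

```python
# Minimal sketch (assumption: the checkpoint ships a standard HF LLaVA config).
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Intel/llava-llama-3-8b-old")
print(config.model_type)                # expected: "llava"
print(config.vision_config.model_type)  # the CLIP-based vision encoder
print(config.text_config.model_type)    # expected: "llama" (the Meta-Llama-3 backbone)
```
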
## Uses

The model has been finetuned for multimodal benchmark evaluations, but it can also be used as a multimodal chatbot.

## Bias, Risks, and Limitations

This model has not been assessed for harm or biases, and it should not be used for sensitive applications where it may cause harm.

## Training Details

The `llava-llama-3-8b` model was trained on a 4-node cluster with a total of 32 Gaudi 2 accelerators.

### Training Data

The model was trained on the LLaVA-v1.5 data mixture, which consists of:

- 558K filtered image-text pairs from LAION/CC/SBU, captioned by BLIP.
- 158K GPT-generated multimodal instruction-following examples.
- 450K examples from an academic-task-oriented VQA data mixture.
- 40K ShareGPT examples.

## Evaluation

| Benchmark | Score |
|-----------|-------|
| ScienceQA | 72.9797 |
| MMVet | 31.9725 |
| llavaw (LLaVA-Bench in-the-Wild) | 56.9 / 61.9 / 73.6 / 65.7 |
| POPE | Acc 87.33, F1 86.5 |
| GQA | 60.6138 |
| MMVP | 36 |

## License

The weights are released under the Intel Research Use License Agreement (see the LICENSE file).
All usage code is licensed under Apache 2.0.

## Usage

Please note that we only provide the trained weight difference (delta weights) and do not provide a copy of the base [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model. Any use of these weights requires a separate download of the base model.

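To make the merging step in the script below easier to follow, here is a toy illustration of the delta-weights idea using made-up tensors (purely illustrative; these names are not part of the released checkpoint):

```python
import torch

# Toy illustration of the delta-weights scheme (hypothetical tensors, not real model weights).
base = torch.randn(4, 4)                     # a base Llama-3 weight matrix
finetuned = base + 0.01 * torch.randn(4, 4)  # the corresponding finetuned weight
delta = finetuned - base                     # what this repository ships
restored = delta + base                      # what the script below reconstructs
assert torch.allclose(restored, finetuned)
```
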
```python
# Copyright 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import requests
import torch
import transformers
from PIL import Image
from transformers import AutoProcessor, AutoModelForPreTraining


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square with the given background color."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def add_model_a_to_b(model_a, model_b):
    """Add model_a's weights into model_b in place (base + delta)."""
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()

    # Ensure the key sets match before adding
    if set(state_dict_a.keys()) != set(state_dict_b.keys()):
        raise ValueError("Model state dicts do not have the same keys.")

    for key in state_dict_a:
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f"Shape mismatch for key '{key}': {state_dict_a[key].shape} vs {state_dict_b[key].shape}")
        # Add model_a's weights to model_b for the matching key
        state_dict_b[key] = state_dict_b[key] + state_dict_a[key]

    # Update model_b with the merged weights
    model_b.load_state_dict(state_dict_b)


output_checkpoint = ""  # set this to a local path if you don't want to merge every time
hf_checkpoint = "Intel/llava-llama-3-8b-old"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(hf_checkpoint)
model = AutoModelForPreTraining.from_pretrained(hf_checkpoint)

# The repo ships only delta weights, so the base Llama-3 weights still need to be
# added to the language model; a zeroed final embedding row indicates this has not happened yet.
if model.language_model.model.embed_tokens.weight[-1].sum() == 0:
    print("adding llama3 weights")
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cpu",
    )
    llama3 = pipeline.model
    add_model_a_to_b(llama3, model.language_model)
    if output_checkpoint:
        print("saving weights, so no adding is needed again")
        model.save_pretrained(output_checkpoint)

model.to(device)

prompt = processor.tokenizer.apply_chat_template(
    [{'role': 'user', 'content': "<image>\nWhat's the content of the image?"}],
    tokenize=False,
    add_generation_prompt=True,
)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The original LLaVA pads with the image mean, while HF LLaVA pads with zeros,
# so pad to a square with the processor's image mean before preprocessing.
image = expand2square(image, tuple(int(x * 255) for x in processor.image_processor.image_mean))
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

# Generate
generate_ids = model.generate(**inputs, max_length=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
```
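
If you set `output_checkpoint` above, later runs can load the merged weights directly and skip the merging branch entirely. A minimal sketch, assuming a hypothetical local path `./llava-llama-3-8b-merged` was used as `output_checkpoint`:

```python
# Minimal follow-up sketch: reload a previously merged checkpoint.
# "./llava-llama-3-8b-merged" is a hypothetical local path (whatever you passed as output_checkpoint).
import torch
from transformers import AutoProcessor, AutoModelForPreTraining

merged_path = "./llava-llama-3-8b-merged"
processor = AutoProcessor.from_pretrained("Intel/llava-llama-3-8b-old")  # processor config still comes from the hub repo
model = AutoModelForPreTraining.from_pretrained(merged_path, torch_dtype=torch.bfloat16)
model.to("cuda" if torch.cuda.is_available() else "cpu")
```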