---
license: other
license_name: intel-research-use-license
license_link: LICENSE
---
# LLaVA-Llama3 Model Card
_This model card corresponds to the instruction-tuned 8B version of the model with the CLIP-based vision encoder._
## Overview
`llava-llama-3-8b` is a large multimodal model (LMM) trained using the [LLaVA-v1.5 framework](https://arxiv.org/abs/2310.03744) with the 8-billion-parameter [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model as the language backbone.
## Uses
The model has been fine-tuned for multimodal benchmark evaluations, but it can also be used as a multimodal chatbot.
## Bias, Risks, and Limitations
This model has not been assessed for harm or biases, and should not be used for sensitive applications where it may cause harm.
## Training Details
The `llava-llama-3-8b` model was trained on a 4-node cluster with a total of 32 Gaudi 2 accelerators.
### Training Data
The model was trained using the LLaVA-v1.5 data mixture, which consists of:
- 558K filtered image-text pairs from LAION/CC/SBU, captioned by BLIP.
- 158K GPT-generated multimodal instruction-following data.
- 450K academic-task-oriented VQA data mixture.
- 40K ShareGPT data.
## Evaluation
| Benchmark | Score |
|-----------|-------|
| ScienceQA | 72.9797 |
| MMVet | 31.9725 |
| llavaw | 56.9/61.9/73.6/65.7 |
| POPE | Accuracy 87.33, F1 86.5 |
| GQA | 60.6138 |
| MMVP | 36 |
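
The POPE entry reports both accuracy and F1. As a rough illustration of how these two numbers relate, the sketch below computes them for a yes/no task. It assumes that "yes" is the positive class; the function name `accuracy_and_f1` and the toy answers are hypothetical and are not taken from the actual evaluation.

```python
# Illustrative only: toy answers, not real evaluation outputs.
def accuracy_and_f1(predictions, labels, positive="yes"):
    """Compute accuracy and F1 for binary yes/no answers."""
    pairs = list(zip(predictions, labels))
    tp = sum(p == positive and l == positive for p, l in pairs)
    fp = sum(p == positive and l != positive for p, l in pairs)
    fn = sum(p != positive and l == positive for p, l in pairs)
    correct = sum(p == l for p, l in pairs)

    accuracy = correct / len(pairs)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall == 0:
        return accuracy, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return accuracy, f1


# Hypothetical ground-truth labels and model answers.
labels = ["yes", "yes", "no", "no", "yes", "no"]
predictions = ["yes", "no", "no", "no", "yes", "yes"]

accuracy, f1 = accuracy_and_f1(predictions, labels)
print(f"accuracy={accuracy:.4f}, f1={f1:.4f}")
```

Accuracy counts every correct answer, while F1 only balances precision and recall on the positive class, so the two can diverge when a model answers "yes" too often or too rarely.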
## License
The weights are released under the Intel Research Use License Agreement (see the LICENSE file).
All usage code is licensed under Apache 2.0.
## Usage
Note that we provide only the difference between the trained weights and the base model weights, not a copy of the base [`meta-llama/Meta-Llama-3-8B-Instruct`](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model. Any use of these weights requires a separate download of the base model.
```python
# Copyright 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForPreTraining
import transformers


def expand2square(pil_img, background_color):
    """Pad a PIL image to a square, centering it on a solid background."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def add_model_a_to_b(model_a, model_b):
    """Add model_a's weights to model_b's weights, updating model_b in place."""
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()

    # Ensure keys match before addition
    if set(state_dict_a.keys()) != set(state_dict_b.keys()):
        raise ValueError("Model state dicts do not have the same keys.")

    for key in state_dict_a:
        if state_dict_a[key].shape != state_dict_b[key].shape:
            raise ValueError(f"Shape mismatch for key '{key}': {state_dict_a[key].shape} vs {state_dict_b[key].shape}")
        # Add model_a's weights to model_b's for the matching key
        state_dict_b[key] = state_dict_b[key] + state_dict_a[key]

    # Update model_b with the new weights
    model_b.load_state_dict(state_dict_b)


output_checkpoint = ""  # set if you don't want to merge every time
hf_checkpoint = "Intel/llava-llama-3-8b"

processor = AutoProcessor.from_pretrained(hf_checkpoint)
model = AutoModelForPreTraining.from_pretrained(hf_checkpoint)

# The released checkpoint stores only weight differences; a zero last
# embedding row indicates that the base Llama 3 weights have not been added yet.
if model.language_model.model.embed_tokens.weight[-1].sum() == 0:
    print("adding llama3 weights")
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    pipeline = transformers.pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="cpu",
    )
    llama3 = pipeline.model
    add_model_a_to_b(llama3, model.language_model)
    if output_checkpoint:
        print("saving weights, so no adding is needed again")
        model.save_pretrained(output_checkpoint)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

prompt = processor.tokenizer.apply_chat_template(
    [{'role': 'user', 'content': "<image>\nWhat's the content of the image?"}],
    tokenize=False,
    add_generation_prompt=True
)

url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# The original LLaVA pads with the image mean; HF LLaVA pads with zeros
image = expand2square(image, tuple(int(x*255) for x in processor.image_processor.image_mean))
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)

# Generate; max_new_tokens limits only the generated tokens, so the
# limit does not depend on how many tokens the prompt and image occupy
generate_ids = model.generate(**inputs, max_new_tokens=30)
output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
print(output)
```
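
The weight-merging step in the example above amounts to an element-wise addition of two sets of parameters with identical names and shapes. The following minimal sketch shows the same idea on toy data, using plain Python lists in place of tensors. The function name `merge_delta` and all parameter names and values are hypothetical and serve only as an illustration.

```python
# Toy illustration of delta-weight merging; lists stand in for tensors.
def merge_delta(base, delta):
    """Return a new dict holding base[key] + delta[key] for every parameter."""
    if set(base) != set(delta):
        raise ValueError("Parameter names differ between base and delta.")
    merged = {}
    for key, base_values in base.items():
        delta_values = delta[key]
        if len(base_values) != len(delta_values):
            raise ValueError(f"Shape mismatch for key '{key}'")
        # Element-wise addition recovers the fine-tuned parameter values.
        merged[key] = [b + d for b, d in zip(base_values, delta_values)]
    return merged


# Hypothetical base weights and released differences.
base_weights = {"layer.weight": [1.0, 2.0, 3.0], "layer.bias": [0.5]}
delta_weights = {"layer.weight": [0.25, -0.5, 0.0], "layer.bias": [0.25]}

merged_weights = merge_delta(base_weights, delta_weights)
print(merged_weights)
```

Checking names and shapes before adding mirrors the checks in `add_model_a_to_b`: a mismatch usually means the wrong base model was downloaded.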