---
datasets:
- liuhaotian/LLaVA-Pretrain
- liuhaotian/LLaVA-Instruct-150K
language:
- en
tags:
- llava
- phi
license: mit
library_name: transformers
---

# LLaVA-3b

<a target="_blank" href="https://colab.research.google.com/drive/1W7JQrFXwFunAY1XvS31mwC7mrXBgGD_M">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Model details

LLaVA-3b is a model fine-tuned from [Dolphin 2.6 Phi](https://huggingface.co/cognitivecomputations/dolphin-2_6-phi-2) in the LLaVA fashion, using the vision tower from
[SigLIP 400M](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384). It differs from the original LLaVA setup in three ways:

1. Multiple image tokens. The multimodal projector generates embeddings of shape [5, 2560] instead of [1, 2560] for images. The idea is that more tokens
carry more information from the image into the language model.
2. The model uses the output from the last layer of the vision encoder instead of an intermediate one.
3. The context length during training was 1200 tokens, as the L4 GPUs I used could not fit more.

Like Dolphin 2.6 Phi, LLaVA-3b uses the ChatML prompt format:

```
<|im_start|>system
You are Dolphin, a helpful AI assistant.<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
```
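
If you build prompts programmatically, the template above can be wrapped in a small helper. This is only a sketch: the function name `build_chatml_prompt` and its parameters are hypothetical and not part of the model's code, and the placement of the `<image>` placeholder (on its own line before the user text) is an assumption that mirrors the usage example in this card.

```python
def build_chatml_prompt(
    user_message: str,
    system_message: str = "You are Dolphin, a helpful AI assistant.",
    with_image: bool = False,
) -> str:
    """Build a single-turn ChatML prompt that ends with an open assistant turn.

    Hypothetical helper for illustration; not part of the LLaVA-3b code.
    """
    # Assumption: the image placeholder goes on its own line before the user text.
    user_content = f"<image>\n{user_message}" if with_image else user_message
    return (
        f"<|im_start|>system\n{system_message}<|im_end|>\n"
        f"<|im_start|>user\n{user_content}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


prompt = build_chatml_prompt("Describe the image.", with_image=True)
print(prompt)
```

The prompt deliberately stops right after the assistant header so that the model continues with the reply.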

## How to use

**Install dependencies**

```bash
# In a notebook such as Colab, prefix the command with "!".
pip install -q open_clip_torch timm einops
```

**Download modeling files**

```python
from huggingface_hub import hf_hub_download

# The model relies on custom code, so fetch these files into the working directory.
for filename in [
    "configuration_llava.py",
    "configuration_phi.py",
    "modeling_llava.py",
    "modeling_phi.py",
    "processing_llava.py",
]:
    hf_hub_download(repo_id="visheratin/LLaVA-3b", filename=filename, local_dir="./", force_download=True)
```

**Create a model**

```python
from modeling_llava import LlavaForConditionalGeneration
import torch

# Load the weights in half precision and move the model to the GPU.
model = LlavaForConditionalGeneration.from_pretrained("visheratin/LLaVA-3b", torch_dtype=torch.float16)
model = model.to("cuda")
```

**Create processors**

```python
from transformers import AutoTokenizer
from processing_llava import LlavaProcessor, OpenCLIPImageProcessor

tokenizer = AutoTokenizer.from_pretrained("visheratin/LLaVA-3b")
image_processor = OpenCLIPImageProcessor(model.config.preprocess_config)
processor = LlavaProcessor(image_processor, tokenizer)
```

**Set image and text**

```python
from PIL import Image
import requests

image_file = "https://images.unsplash.com/photo-1439246854758-f686a415d9da"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

prompt = """<|im_start|>system
A chat between a curious human and an artificial intelligence assistant.
The assistant gives helpful, detailed, and polite answers to the human's questions.
The assistant does not hallucinate and pays very close attention to the details.<|im_end|>
<|im_start|>user
<image>
Describe the image.<|im_end|>
<|im_start|>assistant
"""
```

**Process inputs**

```python
inputs = processor(prompt, raw_image, model, return_tensors='pt')

# Move the text tensors to the same device as the model.
inputs['input_ids'] = inputs['input_ids'].to(model.device)
inputs['attention_mask'] = inputs['attention_mask'].to(model.device)
```

**Generate the output**

```python
output = model.generate(**inputs, max_new_tokens=200, do_sample=True, top_p=0.5, temperature=1.2, eos_token_id=tokenizer.eos_token_id)
```
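
`generate` returns token IDs, so the result still has to be decoded and trimmed to the assistant's turn. The sketch below is a minimal, hypothetical helper (`extract_reply` is not part of the model's code) that works on the decoded string. Whether the decoded text still contains the prompt depends on the custom modeling code, so the helper handles both cases. The sample string is made up for illustration.

```python
def extract_reply(decoded: str) -> str:
    """Return the assistant's reply from decoded ChatML text.

    Hypothetical helper for illustration; not part of the LLaVA-3b code.
    """
    marker = "<|im_start|>assistant"
    # If the decoded text still contains the prompt, keep only what follows
    # the last assistant header.
    if marker in decoded:
        decoded = decoded.rsplit(marker, 1)[1]
    # Drop the end-of-turn token and anything after it.
    return decoded.split("<|im_end|>", 1)[0].strip()


# With the objects from the steps above, this would be used roughly as:
#   print(extract_reply(tokenizer.decode(output[0])))
sample = "<|im_start|>user\nDescribe the image.<|im_end|>\n<|im_start|>assistant\nA dog on a beach.<|im_end|>"
print(extract_reply(sample))  # A dog on a beach.
```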

## Benchmarks

- TextVQA - 33.25%
- GQA - 47.15%
- VQAv2 - 63.1%
- VizWiz - 24.03%

## Acknowledgments

Thanks to [ML Collective](https://mlcollective.org/) for providing credits for computing resources.