|
---
language: en
datasets:
- laion2b
---
|
|
|
# OpenFlamingo-3B (CLIP ViT-L/14, MPT-1B) |
|
|
|
[Blog post]() | [Code](https://github.com/mlfoundations/open_flamingo) | [Demo](https://huggingface.co/spaces/openflamingo/OpenFlamingo) |
|
|
|
OpenFlamingo is an open source implementation of DeepMind's [Flamingo](https://www.deepmind.com/blog/tackling-multiple-tasks-with-a-single-visual-language-model) models. |
|
This 3B-parameter model uses a [CLIP ViT-L/14](https://huggingface.co/openai/clip-vit-large-patch14) vision encoder and an [MPT-1B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) language model.
|
|
|
## Model Details |
|
We follow the Flamingo modeling paradigm, outfitting the layers of a pretrained, frozen language model such that they cross-attend to visual features when decoding. Following Flamingo, we freeze the vision encoder and language model but train the connecting modules on web-scraped image-text sequences. Specifically, we trained this model on a mixture of [LAION-2B](https://arxiv.org/abs/2210.08402) and [Multimodal C4](https://arxiv.org/abs/2304.06939). |
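For intuition, the sketch below illustrates the kind of tanh-gated cross-attention block that Flamingo inserts between frozen decoder layers. It is a simplified illustration, not the OpenFlamingo implementation: the module name, dimensions, and the omission of the perceiver resampler and per-image attention masking are assumptions made for brevity.

``` python
import torch
import torch.nn as nn

class GatedCrossAttentionBlock(nn.Module):
    """Simplified Flamingo-style block (illustrative only): text hidden states
    cross-attend to visual features, and the result is added back through tanh
    gates initialized at zero, so the frozen language model is unchanged at the
    start of training."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        self.attn_gate = nn.Parameter(torch.zeros(1))
        self.ff_gate = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden, visual_features):
        # text_hidden: (batch, text_len, d_model)
        # visual_features: (batch, num_visual_tokens, d_model)
        attn_out, _ = self.cross_attn(text_hidden, visual_features, visual_features)
        text_hidden = text_hidden + torch.tanh(self.attn_gate) * attn_out
        text_hidden = text_hidden + torch.tanh(self.ff_gate) * self.ff(text_hidden)
        return text_hidden
```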
|
|
|
This model has cross-attention modules inserted in *every* decoder block. It was trained using DistributedDataParallel across 64 A100 80GB GPUs at FP32 precision. |
|
|
|
The [MPT-1B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) modeling code does not accept the `labels` kwarg or compute cross-entropy loss within `forward()`. To train with the OpenFlamingo codebase, we suggest using a modified version that accepts the `labels` kwarg, available [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b).
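For reference, `labels` handling in causal language models typically follows the standard Hugging Face pattern: shift logits and labels by one position and compute token-level cross-entropy, ignoring masked positions. A minimal sketch of that pattern (not the exact patched MPT code):

``` python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    # Predict token t+1 from positions <= t: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,  # e.g. padding or prompt tokens set to -100 are skipped
    )
```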
|
|
|
## Uses |
|
OpenFlamingo models process arbitrarily interleaved sequences of images and text to output text. This allows the models to accept in-context examples and undertake tasks like captioning, visual question answering, and image classification. |
|
### Initialization |
|
|
|
``` python
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1
)

# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch

checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
```
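Note that `strict=False` is used because the released checkpoint is expected to contain only the trained connecting modules, not the frozen CLIP and MPT weights. If you want to confirm what was skipped, you can inspect the return value of `load_state_dict` (an optional sketch):

``` python
# Optional: list which parameters were not found in the checkpoint. The frozen
# vision encoder and language model weights are expected to appear as "missing".
result = model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"), strict=False)
print(f"missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")
```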
|
### Generation example |
|
Below is an example of generating text conditioned on interleaved images/text. In particular, let's try few-shot image captioning. |
|
|
|
``` python
from PIL import Image
import requests

"""
Step 1: Load images
"""
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)

demo_image_two = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028137.jpg",
        stream=True
    ).raw
)

query_image = Image.open(
    requests.get(
        "http://images.cocodataset.org/test-stuff2017/000000028352.jpg",
        stream=True
    ).raw
)


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape
 batch_size x num_media x num_frames x channels x height x width.
 In this case batch_size = 1, num_media = 3, num_frames = 1,
 channels = 3, height = 224, width = 224.
"""
vision_x = [
    image_processor(demo_image_one).unsqueeze(0),
    image_processor(demo_image_two).unsqueeze(0),
    image_processor(query_image).unsqueeze(0),
]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
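# Sanity check (illustrative): with the default CLIP ViT-L/14 preprocessing
# (224x224 images), vision_x should now have shape
# (batch_size, num_media, num_frames, channels, height, width) = (1, 3, 1, 3, 224, 224).
assert vision_x.shape == (1, 3, 1, 3, 224, 224)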
|
|
|
""" |
|
Step 3: Preprocessing text |
|
Details: In the text we expect an <image> special token to indicate where an image is. |
|
We also expect an <|endofchunk|> special token to indicate the end of the text |
|
portion associated with an image. |
|
""" |
|
tokenizer.padding_side = "left" # For generation padding tokens should be on the left |
|
lang_x = tokenizer( |
|
["<image>An image of two cats.<|endofchunk|><image>An image of a bathroom sink.<|endofchunk|><image>An image of"], |
|
return_tensors="pt", |
|
) |
|
|
|
|
|
""" |
|
Step 4: Generate text |
|
""" |
|
generated_text = model.generate( |
|
vision_x=vision_x, |
|
lang_x=lang_x["input_ids"], |
|
attention_mask=lang_x["attention_mask"], |
|
max_new_tokens=20, |
|
num_beams=3, |
|
) |
|
|
|
print("Generated text: ", tokenizer.decode(generated_text[0])) |
|
``` |
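The decoded string above includes the prompt as well as the continuation. To print only the newly generated caption, a small post-processing step like the following should work; this is a sketch that assumes `generate` returns the prompt tokens followed by the continuation, as Hugging Face decoder-only generation does.

``` python
# Keep only the tokens produced after the prompt and drop special tokens such as
# <image> and <|endofchunk|> when decoding.
prompt_length = lang_x["input_ids"].shape[1]
new_tokens = generated_text[0][prompt_length:]
print("Caption: ", tokenizer.decode(new_tokens, skip_special_tokens=True))
```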
|
|
|
### Bias, Risks, and Limitations |
|
OpenFlamingo models inherit the risks of their parent models, especially the language model. As an open-source research effort, we highly value open, accessible, and reproducible multimodal model research; however, these models are trained on web data, have not been finetuned for safety, and may therefore produce unintended, inappropriate, unreliable, and/or inaccurate outputs. Please use caution before deploying OpenFlamingo models in real applications. We also hope that OpenFlamingo enables further safety and reliability research to address these issues.
|
|
|
In an effort to mitigate current potential biases and harms, we have deployed a text content filter on model outputs in the OpenFlamingo demo. We continue to red-team the model to understand and improve its safety. |
|
|
|
## Evaluation |
|
|
|
<table> |
|
<tr> |
|
<th></th> |
|
<th>0-shot</th> |
|
<th>4-shot</th> |
|
<th>8-shot</th> |
|
<th>16-shot</th> |
|
<th>32-shot</th> |
|
</tr> |
|
<tr> |
|
<th>COCO (CIDEr)</th> |
|
<td>74.99 (0.2)</td> |
|
<td>77.37 (0.3)</td> |
|
<td>85.97 (0.6)</td> |
|
<td>90.02</td> |
|
<td>93.08</td> |
|
</tr> |
|
<tr> |
|
<th>Flickr-30K (CIDEr)</th> |
|
<td>52.33 (1.0)</td> |
|
<td>57.27 (0.4)</td> |
|
<td>58.63 (1.1)</td> |
|
<td>59.83</td> |
|
<td>32</td> |
|
</tr> |
|
<tr> |
|
<th>VQAv2 (Accuracy)</th> |
|
<td>44.6 (0.7)</td> |
|
<td>45.9 (0.7)</td> |
|
<td>45.8 (0.5)</td> |
|
<td>45.5 (0.2)</td> |
|
<td>45.8 (0.4)</td> |
|
</tr> |
|
<tr> |
|
<th>OK-VQA (Accuracy)</th> |
|
<td>26.8 (0.3)</td> |
|
<td>27.61 (0.2)</td> |
|
<td>27.72 (0.1)</td> |
|
<td>28.42 (0.1)</td> |
|
<td>29.37 (0.2)</td> |
|
</tr> |
|
<tr> |
|
<th>TextVQA (Accuracy)</th> |
|
<td>22.87 (0.2)</td> |
|
<td>25.89 (0.2)</td> |
|
<td>24.74 (0.05)</td> |
|
<td>25.25 (0.2)</td> |
|
<td>26.31 (0.2)</td> |
|
|
|
</tr> |
|
<tr> |
|
<th>Vizwiz (Accuracy)</th> |
|
<td>18.30 (0.6)</td> |
|
<td>23.38 (1.1)</td> |
|
<td>31.80 (0.7)</td> |
|
<td>38.46 (1.1)</td> |
|
<td>42.16 (0.6)</td> |
|
|
</tr> |
|
<tr> |
|
<th>Hateful Memes (ROC AUC)</th> |
|
<td>51.4</td> |
|
<td>51.4</td> |
|
<td>52.1</td> |
|
<td>51.6</td> |
|
<td>51.16</td> |
|
</tr> |
|
</table> |
|
|