---
language:
  - en
license: mit
library_name: transformers
tags:
  - vision
  - image-to-text
  - image-captioning
pipeline_tag: image-to-text
base_model: Salesforce/blip2-opt-2.7b
---

# VLRM

This repository contains the weights of the BLIP-2 OPT-2.7B model fine-tuned with the reinforcement learning method introduced in the paper *VLRM: Vision-Language Models act as Reward Models for Image Captioning*.

The RL-tuned model generates longer and more comprehensive descriptions with zero computational overhead compared to the original model.

You can find further details in the GitHub repository (to be done).

## Running the model

### Option 1

Load the whole model from this repo:

```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
```
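The processor also accepts an optional text prompt, which is part of the standard BLIP-2 interface rather than anything specific to this checkpoint. A minimal sketch reusing the objects above (the prompt text is just an illustrative example):

```python
# Prompted generation (standard BLIP-2 usage; the prompt below is an illustrative example)
prompt = "Question: what is the dog doing? Answer:"
inputs = processor(raw_image, text=prompt, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True).strip())
```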

### Option 2

Since the fine-tuned layers make up only a small part of the whole model, you can first load the original model and then load the RL-tuned weights on top of it.

Step 1. Load the original model:

```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman sitting on the beach with a dog'
```
Step 2. Load the RL-tuned weights. Available checkpoints:

- `vlrm-blip2-opt-2.7b.pt` (VLRM in the paper)
- `vlrm-rs-blip2-opt-2.7b.pt` (VLRM-RS in the paper)
```python
from huggingface_hub import hf_hub_download

# Download the chosen checkpoint (use filename="vlrm-rs-blip2-opt-2.7b.pt" for VLRM-RS)
# and load it on top of the original model; strict=False leaves the remaining layers untouched.
finetuned_weights_state_dict = torch.load(hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt"))
model.load_state_dict(finetuned_weights_state_dict, strict=False)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
```
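Since only a small fraction of the parameters are replaced, it can be useful to sanity-check how much of the model the checkpoint actually covers. A quick illustrative sketch (not part of the original model card), reusing `model` and `finetuned_weights_state_dict` from above:

```python
# Illustrative check: how many parameters does the RL-tuned checkpoint cover?
n_finetuned = sum(v.numel() for v in finetuned_weights_state_dict.values())
n_total = sum(p.numel() for p in model.parameters())
print(f"RL-tuned parameters: {n_finetuned:,} / {n_total:,} "
      f"({100 * n_finetuned / n_total:.2f}% of the full model)")
```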