---
library_name: transformers
license: apache-2.0
language:
  - fa
base_model: llava-hf/llava-1.5-7b-hf
datasets:
  - BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
---

LLaVA Model Card

Model details

This model is llava-hf/llava-1.5-7b-hf fine-tuned on Basalam product data to extract the visual attributes of products. Its output is JSON and can be parsed directly (for example with `json.loads`).
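Because the output is meant to be machine-readable, it can help to check its structure before using it downstream. The sketch below assumes the shape shown in this card's example output: a list of objects, each with an `attributes` mapping from attribute names to lists of string values. It also accepts a single `{"attributes": {...}}` object, the form described in the prompt. `validate_attributes` is a hypothetical helper, not part of this repository. The sample keys are Persian: رنگ ("color") and سایز ("size").

```python
from typing import Any


def validate_attributes(parsed: Any) -> list[dict[str, list[str]]]:
    """Check parsed model output against the assumed shape and return the attribute maps.

    Hypothetical helper. Accepts a single {"attributes": {...}} object or a list of them,
    and raises ValueError if the structure does not match.
    """
    items = parsed if isinstance(parsed, list) else [parsed]
    result = []
    for item in items:
        if not isinstance(item, dict) or not isinstance(item.get("attributes"), dict):
            raise ValueError(f"expected an object with an 'attributes' mapping, got {item!r}")
        attrs = item["attributes"]
        for name, values in attrs.items():
            # Assumption: every attribute maps to a list of strings, as in the example output.
            if not isinstance(values, list) or not all(isinstance(v, str) for v in values):
                raise ValueError(f"attribute {name!r} should map to a list of strings")
        result.append(attrs)
    return result


sample = [{"attributes": {"رنگ": ["زرد", "آبی"], "سایز": ["L", "XL"]}}]
print(validate_attributes(sample))
```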

How to use the model

Below is an example script to run generation in float16 precision on a GPU device:

```python
import json

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)


def prompt_formatter(entity):
    # Keep the prompt text unchanged; rewording it may affect the fine-tuned model's output.
    # English gloss: "For the given product, extract the product's visual attributes in
    # JSON format. The JSON structure should look like this: {json_format}. The product
    # is from an Iranian online marketplace, so the JSON output must be in Persian.
    # Product: '{entity}'."
    json_format = """attributes': {'attribute_name_1' : <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    final_prompt = f"""برای محصول داده شده، ویژگی‌های تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt


prompt = prompt_formatter(entity='تیشرت مردانه')  # "men's T-shirt"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]

# Build the full chat prompt (ending with the assistant turn marker)
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)

inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

output = model.generate(**inputs, max_new_tokens=384, do_sample=False)
# Decode only the newly generated tokens, skipping the prompt tokens
generated_ids = output[0][inputs["input_ids"].shape[-1]:]
generated_text = processor.decode(generated_ids, skip_special_tokens=True)
json_output = json.loads(generated_text.strip())
print(json.dumps(json_output, ensure_ascii=False, indent=2))
```
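Parsed output for the example image (attribute names and values are in Persian: نوع "type", طرح چاپی "print design", رنگ "color", سایز "size"):

```json
[
  {
    "attributes": {
      "نوع": [
        "تیشرت مردانه"
      ],
      "طرح چاپی": [
        "MVP"
      ],
      "رنگ": [
        "زرد",
        "آبی",
        "سفید",
        "مشکی",
        "کرم",
        "سبز"
      ],
      "سایز": [
        "L",
        "XL",
        "2XL",
        "3XL"
      ]
    }
  }
]
```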
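Calling `json.loads` on the decoded text fails if anything surrounds the JSON, or if generation stopped at `max_new_tokens` before the JSON was closed. A somewhat more forgiving sketch parses the first JSON array or object in the text and ignores anything after it. `extract_json` is a hypothetical helper, not part of this repository. Truncated output still raises `ValueError` (`json.JSONDecodeError` is a subclass), so retrying with a larger `max_new_tokens` may be needed.

```python
import json
from typing import Any


def extract_json(generated_text: str) -> Any:
    """Parse the first JSON array or object in the generated text (hypothetical helper).

    Skips any leading text (e.g. an "ASSISTANT:" marker) and ignores trailing text.
    Raises ValueError if no JSON start is found or the JSON value is incomplete.
    """
    starts = [i for i in (generated_text.find("["), generated_text.find("{")) if i != -1]
    if not starts:
        raise ValueError("no JSON array or object found in the generated text")
    # raw_decode returns (value, index_where_parsing_stopped) and tolerates trailing text
    value, _end = json.JSONDecoder().raw_decode(generated_text[min(starts):])
    return value


raw = 'ASSISTANT: [{"attributes": {"سایز": ["L", "XL"]}}] extra text'
print(extract_json(raw))
```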

Model optimization

4-bit quantization with the bitsandbytes library

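First install bitsandbytes (`pip install bitsandbytes`) and make sure you have access to a CUDA-compatible GPU. Then change the model-loading code in the snippet above to:

```diff
+ from transformers import BitsAndBytesConfig
+
  model = LlavaForConditionalGeneration.from_pretrained(
      model_id,
      torch_dtype=torch.float16,
      low_cpu_mem_usage=True,
+     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
- ).to(0)
+ )
```

The `.to(0)` call is dropped because bitsandbytes places the quantized weights on the GPU when the model is loaded. Passing `load_in_4bit=True` directly to `from_pretrained` is deprecated in recent transformers releases in favor of `quantization_config`.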
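As a rough illustration of why 4-bit loading helps, the arithmetic below estimates the memory taken by the weights alone. It assumes about 7 billion parameters (rounded from the "7b" in the base model's name) and ignores activations, the KV cache, quantization constants, and any modules bitsandbytes keeps in higher precision, so real usage will be higher than these figures.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate memory for model weights only, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # assumption: rounded from the "7b" in the model name

for label, bits in [("float16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
# float16: ~14.0 GB
# 4-bit: ~3.5 GB
```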

Use Flash Attention 2 to further speed up generation

First install flash-attn; see the original Flash Attention repository for installation instructions. Then change the model-loading code in the snippet above to:

```diff
  model = LlavaForConditionalGeneration.from_pretrained(
      model_id,
      torch_dtype=torch.float16,
      low_cpu_mem_usage=True,
+     attn_implementation="flash_attention_2",
  ).to(0)
```

Older transformers releases used `use_flash_attention_2=True` for this, which is now deprecated in favor of `attn_implementation`.
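If the same script has to run on machines with and without flash-attn, one option is to pick the backend at load time. This is a minimal sketch: `pick_attn_implementation` is a hypothetical helper, it only checks that the `flash_attn` package is importable (not that the GPU supports it), and it assumes your transformers version supports PyTorch SDPA attention (`"sdpa"`) for LLaVA; if it does not, `"eager"` is the fallback to use instead.

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Choose a value for `attn_implementation` (hypothetical helper).

    Returns "flash_attention_2" when the flash-attn package is installed,
    otherwise falls back to PyTorch scaled-dot-product attention ("sdpa").
    """
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


print(pick_attn_implementation())
```

The result would then be passed as `attn_implementation=pick_attn_implementation()` to `LlavaForConditionalGeneration.from_pretrained`.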