---
library_name: transformers
license: apache-2.0
language:
- fa
base_model: llava-hf/llava-1.5-7b-hf
datasets:
- BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
---
# LLaVA Model Card
## Model details
This model is [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) fine-tuned on [Basalam product data](https://huggingface.co/datasets/BaSalam/vision-catalogs-llava-format-v3) to extract the visual attributes of a product from its image and name. The model returns its output as JSON text, which can be parsed directly.
## How to use the model
Below is an example script to run generation in `float16` precision on a GPU device:
```python
import json

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"

# Load the model in float16 and move it to the first CUDA device.
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)


def prompt_formatter(entity):
    # The instruction is in Persian. Roughly: "For the given product, extract
    # its visual attributes in JSON format. The JSON structure should look
    # like: {json_format}. The product comes from an Iranian online
    # marketplace, so the JSON output must be in Persian. Product: '{entity}'."
    json_format = """attributes': {'attribute_name_1' : <list of attribute values>, 'attribute_name_2': <list of attribute values>, ...}"""
    final_prompt = f"""برای محصول داده شده، ویژگی‌های تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد.
محصول: '{entity}'."""
    return final_prompt


# Product name in Persian ("men's T-shirt").
prompt = prompt_formatter(entity='تیشرت مردانه')
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]
# Wrap the instruction in the chat format the model expects.
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

# Greedy decoding keeps the JSON output deterministic.
output = model.generate(**inputs, max_new_tokens=384, do_sample=False)

# Decode only the newly generated tokens, i.e. everything after the prompt.
generated_ids = output[0][inputs["input_ids"].shape[1]:]
generated_text = processor.decode(generated_ids, skip_special_tokens=True).strip()

json_output = json.loads(generated_text)
print(json.dumps(json_output, ensure_ascii=False, indent=2))
```
```
[
  {
    "attributes": {
      "نوع": [
        "تیشرت مردانه"
      ],
      "طرح چاپی": [
        "MVP"
      ],
      "رنگ": [
        "زرد",
        "آبی",
        "سفید",
        "مشکی",
        "کرم",
        "سبز"
      ],
      "سایز": [
        "L",
        "XL",
        "2XL",
        "3XL"
      ]
    }
  }
]
```
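The decoded text may not always be valid JSON, for example when generation stops at `max_new_tokens` before the closing brackets, so it can help to parse it defensively. The following is a minimal sketch under that assumption; `extract_attributes` is a hypothetical helper name, not part of the model or of `transformers`, and it assumes the output has the shape shown above.

```python
import json
from typing import Optional


def extract_attributes(generated_text: str) -> Optional[dict]:
    """Return the 'attributes' mapping from the model's text output, or None."""
    # Keep only the span from the first '[' or '{' to the last ']' or '}',
    # in case the decoded text carries a prefix such as "ASSISTANT:".
    starts = [i for i in (generated_text.find("["), generated_text.find("{")) if i != -1]
    end = max(generated_text.rfind("]"), generated_text.rfind("}"))
    if not starts or end == -1:
        return None
    try:
        parsed = json.loads(generated_text[min(starts): end + 1])
    except json.JSONDecodeError:
        # Truncated or otherwise malformed output.
        return None
    # The sample output is a one-element list wrapping the object.
    if isinstance(parsed, list):
        parsed = parsed[0] if parsed else {}
    if not isinstance(parsed, dict):
        return None
    attrs = parsed.get("attributes")
    return attrs if isinstance(attrs, dict) else None


# Usage with a hand-written string in the same shape as the sample output.
sample = 'ASSISTANT: [{"attributes": {"رنگ": ["زرد", "آبی"], "سایز": ["L"]}}]'
print(extract_attributes(sample))
```

Returning `None` instead of raising lets a batch job skip or retry a product whose output could not be parsed.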
### Model optimization
#### 4-bit quantization through `bitsandbytes` library
First install `bitsandbytes` with `pip install bitsandbytes` and make sure you have access to a CUDA-compatible GPU. Then change the model-loading call in the snippet above as follows (note that the quantized variant drops the `.to(0)` call):
```diff
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   load_in_4bit=True
)
```
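Recent `transformers` releases warn that passing `load_in_4bit` directly is deprecated and suggest a `BitsAndBytesConfig` object instead. The following is a minimal sketch of the equivalent load, assuming a recent `transformers` with `bitsandbytes` installed. The helper name `load_model_4bit` and the choice of `float16` as compute dtype are illustrative assumptions, not part of the original card.

```python
def load_model_4bit(model_id="BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"):
    """Load the checkpoint with 4-bit weights (hypothetical helper)."""
    # Imported lazily so that defining the helper does not need the GPU stack.
    import torch
    from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,  # assumed compute dtype
    )
    # As in the diff for this section, there is no trailing .to(0) call.
    return LlavaForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=quant_config,
        low_cpu_mem_usage=True,
    )
```

Calling `model = load_model_4bit()` would then stand in for the model-loading call in the main example.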
#### Use Flash-Attention 2 to further speed up generation
First install `flash-attn`; see the [original Flash Attention repository](https://github.com/Dao-AILab/flash-attention) for installation instructions. Then change the model-loading call in the snippet above as follows:
```diff
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   use_flash_attention_2=True
).to(0)
```
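Recent `transformers` releases select the attention backend through the `attn_implementation` argument and flag `use_flash_attention_2` as deprecated. The following is a minimal sketch, assuming such a release, that requests Flash-Attention 2 only when the `flash_attn` package is importable and otherwise falls back to the default eager attention. Both helper names are hypothetical.

```python
import importlib.util


def pick_attn_implementation():
    """Return "flash_attention_2" if flash_attn is importable, else "eager"."""
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "eager"


def load_model_with_attention(model_id="BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"):
    """Load the checkpoint in float16 with the chosen attention backend."""
    # Imported lazily so that defining the helper does not need the GPU stack.
    import torch
    from transformers import LlavaForConditionalGeneration

    return LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation=pick_attn_implementation(),
    ).to(0)
```

The fallback keeps the same script usable on machines where `flash-attn` is not installed, at the cost of slower generation.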