---
library_name: transformers
license: apache-2.0
language:
  - fa
base_model: llava-hf/llava-1.5-7b-hf
datasets:
  - BaSalam/vision-catalogs-llava-format-v3
pipeline_tag: image-text-to-text
---

# LLaVA Model Card

## Model details

This model is [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf), fine-tuned on the [Basalam product](https://huggingface.co/datasets/BaSalam/vision-catalogs-llava-format-v3) dataset to extract the visual attributes of products. The outputs are in JSON format and can be parsed.

## How to use the model

Below is an example script that runs generation in `float16` precision on a GPU device:

```python
import json

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "BaSalam/Llava-1.5-7b-hf-bslm-product-attributes-v0"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)

processor = AutoProcessor.from_pretrained(model_id)


def prompt_formatter(entity):
    json_format = """'attributes': {'attribute_name_1': , 'attribute_name_2': , ...}"""
    # Persian instruction, roughly: "For the given product, extract the product's
    # visual attributes in JSON format. The JSON structure must look like:
    # {json_format}. The product is from an Iranian online marketplace, so the
    # JSON output must be in Persian. Product: '{entity}'."
    final_prompt = f"""برای محصول داده شده، ویژگی‌های تصویری محصول را در قالب جیسون (json) استخراج کن. ساختار JSON باید به این شکل باشد: {json_format}. محصول از یک بازار اینترنتی ایرانی است پس خروجی Json باید به زبان فارسی باشد. محصول: '{entity}'."""
    return final_prompt


prompt = prompt_formatter(entity='تیشرت مردانه')  # "men's t-shirt"
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

image_file = "https://statics.basalam.com/public-16/users/6eOEg/01-24/qJ34XziHu7Orp3GToVWTms1nKvCv0X86Ux7tQLtuRoyTXTxyQ4.jpg_800X800X70.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float16)

output = model.generate(**inputs, max_new_tokens=384, do_sample=False)
generated_text = processor.decode(output[0], skip_special_tokens=True)
# The decoded string contains the prompt followed by "ASSISTANT: <answer>";
# keep only the part after the assistant marker.
answer = generated_text.split('ASSISTANT:')[-1].strip()
json_output = json.loads(answer)
print(json_output)
```

Example output (attribute names and values are in Persian; here نوع = type, طرح چاپی = print design, رنگ = color, سایز = size):

```
[
    {
        "attributes": {
            "نوع": ["تیشرت مردانه"],
            "طرح چاپی": ["MVP"],
            "رنگ": ["زرد", "آبی", "سفید", "مشکی", "کرم", "سبز"],
            "سایز": ["L", "XL", "2XL", "3XL"]
        }
    }
]
```

### Model optimization

#### 4-bit quantization through the `bitsandbytes` library

First make sure to install `bitsandbytes` (`pip install bitsandbytes`) and that you have access to a CUDA-compatible GPU device. Then change the snippet above as follows:

```diff
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   load_in_4bit=True
)
```

#### Use Flash-Attention 2 to further speed up generation

First make sure to install `flash-attn`; refer to the [original repository of Flash Attention](https://github.com/Dao-AILab/flash-attention) for installation instructions. Then change the snippet above as follows:

```diff
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
+   use_flash_attention_2=True
).to(0)
```
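Note that on recent `transformers` releases, `use_flash_attention_2` is deprecated in favor of the `attn_implementation` argument. A minimal sketch of the equivalent call, assuming such a release:

```python
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2",  # newer spelling of use_flash_attention_2=True
).to(0)
```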
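Likewise, recent releases express 4-bit loading through a `BitsAndBytesConfig` object rather than the bare `load_in_4bit` flag. A minimal sketch, where the `bnb_4bit_compute_dtype` value is an assumption chosen to match the `float16` setup above:

```python
import torch
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # assumption: match the float16 precision used above
)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    quantization_config=quantization_config,
)
# No .to(0) here: the quantized weights are placed on the GPU automatically.
```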
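Finally, the usage script's `json.loads(answer)` assumes the generation is exactly one well-formed JSON value. If you want to be defensive against occasional stray text around the JSON, a small hypothetical helper (`extract_json` is illustrative, not part of the model's API) can pull out the first parseable value:

```python
import json


def extract_json(generated: str):
    """Best-effort extraction of the first JSON value in a generated string.

    Illustrative helper, not part of the model card's API: scan for a '{'
    or '[' and let JSONDecoder.raw_decode parse from there, which ignores
    any trailing text after the value. Returns None if nothing parses.
    """
    decoder = json.JSONDecoder()
    for start, char in enumerate(generated):
        if char in '{[':
            try:
                value, _ = decoder.raw_decode(generated[start:])
                return value
            except json.JSONDecodeError:
                continue
    return None


json_output = extract_json(answer)  # drop-in replacement for json.loads(answer)
```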