---
license: apache-2.0
language:
- en
library_name: transformers
pipeline_tag: image-text-to-text
tags:
- multimodal
- aria
---
This is a fork of the rhymes-ai/Aria model. The main modification replaces grouped GEMM with a sequential MLP, in which each expert is implemented as a torch.nn.Linear layer and the experts are executed in sequence. This simplifies quantization with current open-source libraries, which are optimized for nn.Linear layers.
The sequential MLP makes quantization easier, but grouped GEMM gives faster inference.
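The trade-off described above can be illustrated with a toy sketch. This is not the model's actual implementation: the class name `SequentialExperts`, the layer sizes, the GELU activation, and top-1 routing are all assumptions chosen for brevity, and `batched_reference` is only a stand-in for a real grouped GEMM kernel. It shows that looping over per-expert `nn.Linear` layers yields the same result as one batched computation, while exposing plain `nn.Linear` modules that quantization tools can find.

```python
import torch
from torch import nn
from torch.nn import functional as F


class SequentialExperts(nn.Module):
    """Toy MoE feed-forward block (hypothetical): nn.Linear layers per expert, run in a loop."""

    def __init__(self, hidden_size, ffn_size, num_experts):
        super().__init__()
        self.up = nn.ModuleList(
            [nn.Linear(hidden_size, ffn_size, bias=False) for _ in range(num_experts)]
        )
        self.down = nn.ModuleList(
            [nn.Linear(ffn_size, hidden_size, bias=False) for _ in range(num_experts)]
        )

    def forward(self, tokens, expert_ids):
        # tokens: (num_tokens, hidden_size); expert_ids: (num_tokens,), top-1 routing assumed.
        out = torch.zeros_like(tokens)
        for e, (up, down) in enumerate(zip(self.up, self.down)):
            mask = expert_ids == e
            if mask.any():
                out[mask] = down(F.gelu(up(tokens[mask])))
        return out


def batched_reference(block, tokens, expert_ids):
    """Same math in one batched pass; a stand-in for a grouped GEMM kernel."""
    w_up = torch.stack([layer.weight for layer in block.up])      # (E, ffn, hidden)
    w_down = torch.stack([layer.weight for layer in block.down])  # (E, hidden, ffn)
    h = F.gelu(torch.einsum("th,tfh->tf", tokens, w_up[expert_ids]))
    return torch.einsum("tf,thf->th", h, w_down[expert_ids])


torch.manual_seed(0)
block = SequentialExperts(hidden_size=8, ffn_size=16, num_experts=4)
tokens = torch.randn(10, 8)
expert_ids = torch.randint(0, 4, (10,))

with torch.no_grad():
    seq_out = block(tokens, expert_ids)
    ref_out = batched_reference(block, tokens, expert_ids)

# Every expert weight lives in a plain nn.Linear that a quantizer can swap out.
num_linear = sum(isinstance(m, nn.Linear) for m in block.modules())
```

The loop launches one small matrix multiply per expert, which is why the sequential form tends to be slower than a fused kernel on the same hardware.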
Quick Start
Installation
pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torchvision requests torch Pillow
pip install flash-attn --no-build-isolation
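Because the install command pins exact versions, it can help to confirm what is actually installed before running inference. The following is a minimal sketch using only the standard library; the helper name `check_pins` is hypothetical, and the pinned versions are copied from the command above.

```python
from importlib import metadata

# Versions pinned by the install command above.
PINNED = {
    "transformers": "4.45.0",
    "accelerate": "0.34.1",
    "sentencepiece": "0.2.0",
}


def check_pins(pinned=PINNED):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in check_pins().items():
        print(f"{name}: expected {expected}, found {installed or 'not installed'}")
```

An empty result means every pinned package matches.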
Inference
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
model_id_or_path = "rhymes-ai/Aria-sequential_mlp"
model = AutoModelForCausalLM.from_pretrained(model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
image = Image.open(requests.get(image_path, stream=True).raw)
messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},
            {"text": "what is the image?", "type": "text"},
        ],
    }
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
inputs = {k: v.to(model.device) for k, v in inputs.items()}
# Generate under inference mode with bfloat16 autocast on the GPU.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    output = model.generate(
        **inputs,
        max_new_tokens=500,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Drop the prompt tokens so only the newly generated text is decoded.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    result = processor.decode(output_ids, skip_special_tokens=True)

print(result)
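To ask different questions about an image, the message structure from the example can be wrapped in a small helper. This is a sketch only: `build_messages` is a hypothetical name, and the schema (one image placeholder followed by one text entry) is copied from the example above rather than from any documented specification.

```python
def build_messages(question):
    """Return a single-turn chat: one image placeholder followed by a text question."""
    if not question or not question.strip():
        raise ValueError("question must be a non-empty string")
    return [
        {
            "role": "user",
            "content": [
                {"text": None, "type": "image"},
                {"text": question.strip(), "type": "text"},
            ],
        }
    ]


messages = build_messages("what is the image?")
```

The returned list can be passed to `processor.apply_chat_template` in place of the hand-written `messages` literal.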