aria-dev committed
Commit 39d9c7b
1 Parent(s): e83fa52

update readme

Files changed (2)
  1. README.md +64 -0
  2. inference.py +0 -43
README.md CHANGED
@@ -1,3 +1,67 @@
  ---
  license: apache-2.0
  ---
+
+ This repository provides int8 weight-only quantized weights for the [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) model, produced with the [TorchAO](https://github.com/pytorch/ao) quantization framework, enabling inference within 30 GB of GPU memory.
+
+
+ ## Quick Start
+ ### Installation
+ ```
+ pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torch==2.5.0 torchao==0.6.1 torchvision requests Pillow
+ pip install flash-attn --no-build-isolation
+ ```
+
+ ### Inference
+
+ ```python
+ import requests
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForCausalLM, AutoProcessor
+
+ model_id_or_path = "rhymes-ai/Aria-torchao-int8wo"
+
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id_or_path,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+     attn_implementation="flash_attention_2",
+ )
+
+ processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)
+
+ image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
+
+ image = Image.open(requests.get(image_path, stream=True).raw)
+
+ messages = [
+     {
+         "role": "user",
+         "content": [
+             {"text": None, "type": "image"},
+             {"text": "what is in the image?", "type": "text"},
+         ],
+     }
+ ]
+
+ text = processor.apply_chat_template(messages, add_generation_prompt=True)
+ inputs = processor(text=text, images=image, return_tensors="pt")
+ inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+ with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
+     output = model.generate(
+         **inputs,
+         max_new_tokens=500,
+         stop_strings=["<|im_end|>"],
+         tokenizer=processor.tokenizer,
+         do_sample=True,
+         temperature=0.9,
+     )
+ output_ids = output[0][inputs["input_ids"].shape[1] :]
+ result = processor.decode(output_ids, skip_special_tokens=True)
+
+ print(result)
+ ```
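The commit does not include the script that produced the quantized weights. For readers who want to reproduce a checkpoint like this, below is a minimal sketch using torchao 0.6's standard int8 weight-only recipe; the exact settings used for this repository, the output path, and the use of non-safetensors serialization are assumptions, not taken from the commit.

```python
# Sketch: produce an int8 weight-only checkpoint with torchao.
# Assumptions: the bf16 base model fits in memory, and this repo was
# built with the standard int8_weight_only recipe.
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import quantize_, int8_weight_only

base = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Swap each eligible linear layer's weight for an int8 weight-only
# quantized tensor subclass, in place.
quantize_(base, int8_weight_only())

# Quantized tensor subclasses serialize via torch.save rather than
# safetensors, hence safe_serialization=False. Output path is hypothetical.
base.save_pretrained("Aria-torchao-int8wo", safe_serialization=False)
```

The resulting directory can then be loaded exactly as in the README snippet above.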
inference.py DELETED
@@ -1,43 +0,0 @@
- import torch
- from PIL import Image
- from transformers import AutoProcessor, AutoModelForCausalLM
- import requests
-
- model_id_or_path = "./"
- tokenizer_id_or_path = "./"
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_id_or_path,
-     device_map="cuda",
-     torch_dtype=torch.bfloat16,
-     trust_remote_code=True,
-     attn_implementation="flash_attention_2",
- )
-
- model = torch.compile(model, mode="max-autotune", fullgraph=True)
-
- messages = [
-     {
-         "role": "user",
-         "content": [
-             {"text": None, "type": "image"},
-             {"text": "what's in the image?", "type": "text"},
-         ],
-     }
- ]
-
- image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
-
- image = Image.open(requests.get(image_path, stream=True).raw)
-
- processor = AutoProcessor.from_pretrained(tokenizer_id_or_path, trust_remote_code=True)
- text = processor.apply_chat_template(messages, add_generation_prompt=True)
- inputs = processor(text=text, images=image, return_tensors="pt")
- inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
- out = model.generate(**inputs, max_new_tokens=100, tokenizer=processor.tokenizer, stop_strings=["<|im_end|>"])
-
- output_ids = out[0][inputs["input_ids"].shape[1] :]
- result = processor.decode(output_ids, skip_special_tokens=True)
- print(result)
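The deleted script duplicated the README's new inference example, which replaces it. To sanity-check the claim that inference fits within 30 GB, peak allocator statistics can be read after generation; this check is a sketch using standard PyTorch CUDA APIs and is not part of the commit.

```python
# Sketch: measure peak GPU memory around the README's generate() call.
import torch

torch.cuda.reset_peak_memory_stats()
# ... run the README inference snippet here ...
peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"peak allocated GPU memory: {peak_gib:.1f} GiB")
```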