Updated "How to Use" python example
README.md
CHANGED
@@ -56,96 +56,26 @@ from PIL import Image
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from urllib.request import urlopen
 import torch.nn as nn
-from
+from huggingface_hub import hf_hub_download
+
+# Loading some sources of the projection adapter and image encoder
+hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="models.py", local_dir='./')
+from models import CLIPVisionTower
 
 DEVICE = "cuda:0"
 PROMPT = "This is a dialog with AI assistant.\n"
-tokenizer = AutoTokenizer.from_pretrained("OmniMistral-tokenizer", use_fast=False)
-model = AutoModelForCausalLM.from_pretrained("OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)
 
+tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-tokenizer", use_fast=False)
+model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)
+
+hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="projection", local_dir='./')
+hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="special_embeddings.pt", local_dir='./')
 projection = torch.load("projection", map_location=DEVICE)
 special_embs = torch.load("special_embeddings.pt", map_location=DEVICE)
 
 
 
-
-
-class CLIPVisionTower(nn.Module):
-    def __init__(self, vision_tower, args, delay_load=False):
-        super().__init__()
-
-        self.is_loaded = False
-
-        self.vision_tower_name = vision_tower
-        self.select_layer = args.mm_vision_select_layer
-        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
-
-        if not delay_load:
-            self.load_model()
-        else:
-            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
-
-    def load_model(self):
-        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
-        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
-        self.vision_tower.requires_grad_(False)
-
-        self.is_loaded = True
-
-    def feature_select(self, image_forward_outs):
-        image_features = image_forward_outs.hidden_states[self.select_layer]
-        if self.select_feature == 'patch':
-            image_features = image_features[:, 1:]
-        elif self.select_feature == 'cls_patch':
-            image_features = image_features
-        else:
-            raise ValueError(f'Unexpected select feature: {self.select_feature}')
-        return image_features
-
-    @torch.no_grad()
-    def forward(self, images):
-        if type(images) is list:
-            image_features = []
-            for image in images:
-                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
-                image_feature = self.feature_select(image_forward_out).to(image.dtype)
-                image_features.append(image_feature)
-        else:
-            image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
-            image_features = self.feature_select(image_forward_outs).to(images.dtype)
-
-        return image_features
-
-    @property
-    def dummy_feature(self):
-        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
-
-    @property
-    def dtype(self):
-        return self.vision_tower.dtype
-
-    @property
-    def device(self):
-        return self.vision_tower.device
-
-    @property
-    def config(self):
-        if self.is_loaded:
-            return self.vision_tower.config
-        else:
-            return self.cfg_only
-
-    @property
-    def hidden_size(self):
-        return self.config.hidden_size
-
-
-class ClipTowerCfg:
-    def __init__(self):
-        self.mm_vision_select_feature = 'patch'
-        self.mm_vision_select_layer = -2
-
-clip = CLIPVisionTower("openai/clip-vit-large-patch14-336", ClipTowerCfg())
+clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
 clip.load_model()
 clip = clip.to(device=DEVICE, dtype=torch.bfloat16)
 
@@ -169,11 +99,11 @@ def gen_answer(model, tokenizer, clip, projection, query, special_embs, image=No
     with torch.no_grad():
         image_features = clip.image_processor(image, return_tensors='pt')
         image_embedding = clip(image_features['pixel_values']).to(device=DEVICE, dtype=torch.bfloat16)
-
+
         projected_vision_embeddings = projection(image_embedding).to(device=DEVICE, dtype=torch.bfloat16)
         prompt_ids = tokenizer.encode(f"{PROMPT}", add_special_tokens=False, return_tensors="pt").to(device=DEVICE)
         question_ids = tokenizer.encode(query, add_special_tokens=False, return_tensors="pt").to(device=DEVICE)
-
+
         prompt_embeddings = model.model.embed_tokens(prompt_ids).to(torch.bfloat16)
         question_embeddings = model.model.embed_tokens(question_ids).to(torch.bfloat16)
 
@@ -200,11 +130,11 @@ img = Image.open(urlopen(img_url))
 
 answer = gen_answer(
     model,
-    tokenizer,
-    clip,
-    projection,
-    query=question,
-    special_embs=special_embs,
+    tokenizer,
+    clip,
+    projection,
+    query=question,
+    special_embs=special_embs,
     image=img
 )
 
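For readers who just want the end result, here is the setup portion of the updated "How to Use" example, assembled from the added lines above. The repo IDs, filenames, subfolder names, and the CLIPVisionTower call are taken verbatim from the diff; the import torch line is an assumption, taken from the unchanged part of the README that the diff only shows as context.

import torch  # assumed: imported earlier in the README, needed for torch.load and torch.bfloat16
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download

# Download the helper module that now provides the CLIP image-encoder wrapper
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="models.py", local_dir='./')
from models import CLIPVisionTower

DEVICE = "cuda:0"
PROMPT = "This is a dialog with AI assistant.\n"

# Tokenizer and base LLM are loaded directly from the Hub repo subfolders
tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-tokenizer", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)

# Projection adapter and special embeddings: download once, then load locally
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="projection", local_dir='./')
hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="special_embeddings.pt", local_dir='./')
projection = torch.load("projection", map_location=DEVICE)
special_embs = torch.load("special_embeddings.pt", map_location=DEVICE)

# CLIP ViT-L/14-336 image encoder, moved to the same device in bfloat16
clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
clip.load_model()
clip = clip.to(device=DEVICE, dtype=torch.bfloat16)

Compared with the previous version, the inline CLIPVisionTower class and the local OmniMistral-tokenizer/OmniMistral-model paths are gone: the helper module, the projection adapter, and the special embeddings are all fetched from the AIRI-Institute/OmniFusion repo at run time.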