yeswondwerr
commited on
Commit
•
43f1a0b
1
Parent(s):
6cd45a2
Update __main__.py
Browse files- __main__.py +15 -23
__main__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import argparse
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
from PIL import Image
|
@@ -12,26 +13,14 @@ from transformers import (
|
|
12 |
from transformers import TextStreamer
|
13 |
|
14 |
|
15 |
-
def tokenizer_image_token(
|
16 |
-
|
17 |
-
)
|
18 |
-
|
19 |
-
|
20 |
-
def insert_separator(X, sep):
|
21 |
-
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
22 |
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
len(prompt_chunks) > 0
|
27 |
-
and len(prompt_chunks[0]) > 0
|
28 |
-
and prompt_chunks[0][0] == tokenizer.bos_token_id
|
29 |
-
):
|
30 |
-
offset = 1
|
31 |
-
input_ids.append(prompt_chunks[0][0])
|
32 |
-
|
33 |
-
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
34 |
-
input_ids.extend(x[offset:])
|
35 |
|
36 |
return torch.tensor(input_ids, dtype=torch.long)
|
37 |
|
@@ -126,18 +115,20 @@ def answer_question(
|
|
126 |
tokenizer.eos_token = "<|eot_id|>"
|
127 |
|
128 |
try:
|
129 |
-
q = input("
|
130 |
except EOFError:
|
131 |
q = ""
|
132 |
if not q:
|
133 |
print("no input detected. exiting.")
|
|
|
|
|
134 |
|
135 |
question = "<image>" + q
|
136 |
|
137 |
-
prompt = f"<|
|
138 |
|
139 |
input_ids = (
|
140 |
-
tokenizer_image_token(prompt, tokenizer
|
141 |
.unsqueeze(0)
|
142 |
.to(model.device)
|
143 |
)
|
@@ -183,13 +174,14 @@ def answer_question(
|
|
183 |
}
|
184 |
|
185 |
while True:
|
|
|
186 |
generated_ids = model.generate(
|
187 |
inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
|
188 |
)[0]
|
189 |
|
190 |
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
|
191 |
try:
|
192 |
-
q = input("
|
193 |
except EOFError:
|
194 |
q = ""
|
195 |
if not q:
|
|
|
1 |
import argparse
|
2 |
+
import sys
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
from PIL import Image
|
|
|
13 |
from transformers import TextStreamer
|
14 |
|
15 |
|
16 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=-200):
|
17 |
+
prompt_chunks = prompt.split("<image>")
|
18 |
+
tokenized_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks]
|
19 |
+
input_ids = tokenized_chunks[0]
|
|
|
|
|
|
|
20 |
|
21 |
+
for chunk in tokenized_chunks[1:]:
|
22 |
+
input_ids.append(image_token_index)
|
23 |
+
input_ids.extend(chunk[1:]) # Exclude BOS token on nonzero index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
return torch.tensor(input_ids, dtype=torch.long)
|
26 |
|
|
|
115 |
tokenizer.eos_token = "<|eot_id|>"
|
116 |
|
117 |
try:
|
118 |
+
q = input("\nuser: ")
|
119 |
except EOFError:
|
120 |
q = ""
|
121 |
if not q:
|
122 |
print("no input detected. exiting.")
|
123 |
+
sys.exit()
|
124 |
+
|
125 |
|
126 |
question = "<image>" + q
|
127 |
|
128 |
+
prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
129 |
|
130 |
input_ids = (
|
131 |
+
tokenizer_image_token(prompt, tokenizer)
|
132 |
.unsqueeze(0)
|
133 |
.to(model.device)
|
134 |
)
|
|
|
174 |
}
|
175 |
|
176 |
while True:
|
177 |
+
print('assistant: ')
|
178 |
generated_ids = model.generate(
|
179 |
inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
|
180 |
)[0]
|
181 |
|
182 |
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
|
183 |
try:
|
184 |
+
q = input("\nuser: ")
|
185 |
except EOFError:
|
186 |
q = ""
|
187 |
if not q:
|