Update modeling_GOT.py
modeling_GOT.py  (+16 -15)
@@ -249,7 +249,7 @@ class GOTQwenModel(Qwen2Model):
                 image_patches_features = []
                 for image_patch in image_patches:
                     image_p = torch.stack([image_patch])
-
+
                     with torch.set_grad_enabled(False):
                         cnn_feature_p = vision_tower_high(image_p)
                         cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -257,7 +257,6 @@ class GOTQwenModel(Qwen2Model):
                     image_patches_features.append(image_feature_p)
                 image_feature = torch.cat(image_patches_features, dim=1)
             image_features.append(image_feature)
-            exit()
 
 
             dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
@@ -485,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
+    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False):
 
         self.disable_torch_init()
 
@@ -549,7 +548,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
 
-
+        if print_prompt:
+            print(prompt)
 
         inputs = tokenizer([prompt])
 
@@ -570,7 +570,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 do_sample=False,
                 num_beams = 1,
                 no_repeat_ngram_size = 20,
-                streamer=streamer,
+                # streamer=streamer,
                 max_new_tokens=4096,
                 stopping_criteria=[stopping_criteria]
                 )
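Taken together, these hunks remove a stray exit() debug call left in GOTQwenModel and extend chat() with a print_prompt flag that, when set, echoes the assembled conversation prompt before generation; the streamer argument to generate() is commented out at the same time, so output is returned only once decoding finishes. A minimal usage sketch of the new chat() signature follows; the checkpoint id and image path are assumptions for illustration, and loading uses the standard trust_remote_code pattern required by this custom modeling file:

from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id; any GOT checkpoint shipping this modeling_GOT.py works.
tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = model.eval().cuda()

# New in this commit: print_prompt=True prints the full conversation prompt
# before decoding. Output is no longer streamed since streamer= is commented out.
result = model.chat(tokenizer, 'page.png', ocr_type='format', print_prompt=True)
print(result)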
@@ -715,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_plus(self, tokenizer,
+    def chat_plus(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -730,8 +730,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         image_list = []
 
-        if len(image_file_list)>1:
-            multi_page = True
+        # if len(image_file_list)>1:
+        #     multi_page = True
 
         if multi_page:
             qs = 'OCR with format across multi pages: '
@@ -739,19 +739,19 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             import glob
             # from natsort import natsorted
             # patches = glob.glob(image_file + '/*png')
-            patches =
+            patches = image_file
             # patches = natsorted(patches)
             sub_images = []
             for sub_image in patches:
                 sub_images.append(self.load_image(sub_image))
 
             ll = len(patches)
-            print(patches)
-            print("len ll: ", ll)
+            # print(patches)
+            # print("len ll: ", ll)
 
         else:
             qs = 'OCR with format upon the patch reference: '
-            img = self.load_image(
+            img = self.load_image(image_file)
             sub_images = self.dynamic_preprocess(img)
             ll = len(sub_images)
 
@@ -762,7 +762,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         image_list = torch.stack(image_list)
 
-        print('====new images batch size======: ',image_list.shape)
+        print('====new images batch size======: \n',image_list.shape)
 
 
         if use_im_start_end:
@@ -788,7 +788,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         conv.append_message(conv.roles[1], None)
         prompt = conv.get_prompt()
 
-
+        if print_prompt:
+            print(prompt)
 
         inputs = tokenizer([prompt])
 
@@ -807,7 +808,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 do_sample=False,
                 num_beams = 1,
                 # no_repeat_ngram_size = 20,
-                streamer=streamer,
+                # streamer=streamer,
                 max_new_tokens=4096,
                 stopping_criteria=[stopping_criteria]
                 )
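chat_plus() gets the same treatment: the multi-page detection on image_file_list is commented out (multi_page stays hard-coded to False), the debug prints of patches are silenced, the batch-size print gains a newline, and the signature gains the same print_prompt flag. A sketch of a call under the new signature, reusing the model and tokenizer from the sketch above; since the multi_page branch is disabled, a single image path is cropped via dynamic_preprocess() and run through formatted OCR:

# print_prompt=True echoes the assembled prompt here as well (added in this commit).
result = model.chat_plus(tokenizer, 'page.png', print_prompt=True)
print(result)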
|