Update modeling_GOT.py
Browse files- modeling_GOT.py +10 -4
modeling_GOT.py
CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
-
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
@@ -495,7 +495,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
495 |
|
496 |
image_token_len = 256
|
497 |
|
498 |
-
|
|
|
|
|
|
|
499 |
|
500 |
w, h = image.size
|
501 |
|
@@ -713,7 +716,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
713 |
return processed_images
|
714 |
|
715 |
|
716 |
-
def chat_crop(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False):
|
717 |
# Model
|
718 |
self.disable_torch_init()
|
719 |
multi_page=False
|
@@ -749,7 +752,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
749 |
|
750 |
else:
|
751 |
qs = 'OCR with format upon the patch reference: '
|
752 |
-
|
|
|
|
|
|
|
753 |
sub_images = self.dynamic_preprocess(img)
|
754 |
ll = len(sub_images)
|
755 |
|
|
|
484 |
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
|
485 |
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
|
486 |
|
487 |
+
def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False):
|
488 |
|
489 |
self.disable_torch_init()
|
490 |
|
|
|
495 |
|
496 |
image_token_len = 256
|
497 |
|
498 |
+
if gradio_input:
|
499 |
+
image = image_file.copy()
|
500 |
+
else:
|
501 |
+
image = self.load_image(image_file)
|
502 |
|
503 |
w, h = image.size
|
504 |
|
|
|
716 |
return processed_images
|
717 |
|
718 |
|
719 |
+
def chat_crop(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False, gradio_input=False):
|
720 |
# Model
|
721 |
self.disable_torch_init()
|
722 |
multi_page=False
|
|
|
752 |
|
753 |
else:
|
754 |
qs = 'OCR with format upon the patch reference: '
|
755 |
+
if gradio_input:
|
756 |
+
img = image_file.copy()
|
757 |
+
else:
|
758 |
+
img = self.load_image(image_file)
|
759 |
sub_images = self.dynamic_preprocess(img)
|
760 |
ll = len(sub_images)
|
761 |
|