Update modeling_GOT.py
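Add a `stream_flag` argument (default `False`) to `chat()` and `chat_crop()`: when it is `True`, the `TextStreamer` is passed to `generate()` so decoded text is printed token by token as it is produced; otherwise the streamer is left out and only the final decoded string is returned.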
modeling_GOT.py  +50 -26
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
         setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False):
+    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
 
         self.disable_torch_init()
 
@@ -565,18 +565,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            output_ids = self.generate(
-                input_ids,
-                images=[image_tensor_1.unsqueeze(0).half().cuda()],
-                do_sample=False,
-                num_beams = 1,
-                no_repeat_ngram_size = 20,
-                streamer=streamer,
-                max_new_tokens=4096,
-                stopping_criteria=[stopping_criteria]
-                )
-
+        if stream_flag:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    no_repeat_ngram_size = 20,
+                    streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
+        else:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
@@ -716,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False):
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -807,18 +819,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-        with torch.autocast("cuda", dtype=torch.bfloat16):
-            output_ids = self.generate(
-                input_ids,
-                images=[image_list.half().cuda()],
-                do_sample=False,
-                num_beams = 1,
-                # no_repeat_ngram_size = 20,
-                streamer=streamer,
-                max_new_tokens=4096,
-                stopping_criteria=[stopping_criteria]
-                )
-
+        if stream_flag:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_list.half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
+        else:
+            with torch.autocast("cuda", dtype=torch.bfloat16):
+                output_ids = self.generate(
+                    input_ids,
+                    images=[image_list.half().cuda()],
+                    do_sample=False,
+                    num_beams = 1,
+                    # no_repeat_ngram_size = 20,
+                    # streamer=streamer,
+                    max_new_tokens=4096,
+                    stopping_criteria=[stopping_criteria]
+                    )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
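For reference, a minimal usage sketch of the new flag. The repo id `ucaslcl/GOT-OCR2_0` and the image path are placeholders and the loading boilerplate follows the usual GOT-OCR2.0 README pattern, not this commit; adjust to your setup.

```python
# Minimal usage sketch for the new stream_flag argument.
# Repo id and image path are placeholders, not part of this commit.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True,
                                  low_cpu_mem_usage=True, device_map='cuda',
                                  use_safetensors=True)
model = model.eval().cuda()

# stream_flag=True takes the branch that passes the TextStreamer to
# generate(), printing decoded text as tokens are produced.
res = model.chat(tokenizer, 'image.png', ocr_type='ocr', stream_flag=True)

# The default (stream_flag=False) keeps a non-streaming call that only
# returns the final decoded string.
res = model.chat(tokenizer, 'image.png', ocr_type='ocr')

# chat_crop() gains the same switch for the multi-crop path.
res = model.chat_crop(tokenizer, 'image.png', ocr_type='ocr', stream_flag=True)
```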