Making the Code Runnable on CPU
Please check GPU functionality. This was a quick fix tested on a Mac :)
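The patch swaps hard-coded CUDA calls (.cuda(), torch.autocast("cuda", ...)) for device-agnostic ones (.to(self.device), torch.autocast(str(self.device), ...)). A minimal sketch of the underlying pattern in isolation; the names here are illustrative and not part of the patch:

import torch

# Use the accelerator if one is present, otherwise fall back to CPU
# (e.g. on a Mac without CUDA).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(4, 8).to(device)  # instead of x.cuda()
w = torch.randn(8, 8).to(device)

# torch.autocast expects a device *type* string such as "cuda" or "cpu";
# str(device) matches that only while the device carries no index, so
# device.type is the safer spelling.
with torch.autocast(device.type, dtype=torch.bfloat16):
    y = x @ w  # matmul runs under autocast on both CPU and CUDA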
modeling_GOT.py (CHANGED): +39 -21
@@ -1,25 +1,37 @@
-from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+import dataclasses
+from io import BytesIO
 from typing import List, Optional, Tuple, Union
-from transformers.cache_utils import Cache
+
 import requests
-from PIL import Image
-from io import BytesIO
 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn import CrossEntropyLoss
-from .got_vision_b import build_GOT_vit_b
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-import dataclasses
+from transformers import (
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2Model,
+    StoppingCriteria,
+    TextStreamer,
+)
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
+from .got_vision_b import build_GOT_vit_b
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
 DEFAULT_IM_START_TOKEN = '<img>'
 DEFAULT_IM_END_TOKEN = '</img>'
 
-from enum import auto, Enum
+from enum import Enum, auto
+
+
 class SeparatorStyle(Enum):
     """Different separator style."""
     SINGLE = auto()
@@ -164,7 +176,7 @@ class GOTQwenModel(Qwen2Model):
         use_im_start_end=False,
         vision_select_layer=-1,
         dtype=torch.float16,
-        device='cuda'
+        device=None
     ):
 
 
@@ -453,7 +465,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         tokenizer,
         freeze_lm_model=False,
         pretrained_stage1_model=None,
-        device='cuda'
+        device=None
     ):
         config = self.get_model().config
 
@@ -488,6 +500,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         self.disable_torch_init()
 
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
         image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
@@ -558,7 +571,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         image_tensor_1 = image_processor_high(image)
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -566,10 +579,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
@@ -578,10 +591,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                 )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
@@ -599,7 +612,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         if render:
             print('==============rendering===============')
-            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
+            from .render_tools import (
+                content_mmd_to_html,
+                svg_to_html,
+                tik_html,
+                translation_table,
+            )
 
             if '**kern' in outputs:
                 import verovio
@@ -812,7 +830,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         inputs = tokenizer([prompt])
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -820,10 +838,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
@@ -832,10 +850,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                 )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
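With these edits the model no longer assumes a CUDA device at inference time. A quick smoke test, sketched after the GOT-OCR2.0 model card usage; the checkpoint id, image path, and ocr_type value below are assumptions, not part of this commit:

from transformers import AutoModel, AutoTokenizer

# Assumed checkpoint id; point this at the repo that actually carries the patch.
MODEL_ID = "ucaslcl/GOT-OCR2_0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model = model.eval()  # note: no .cuda() -- the patched chat() resolves self.device itself

# chat() is the method edited above; "ocr" requests plain-text OCR per the model card.
result = model.chat(tokenizer, "path/to/image.png", ocr_type="ocr")
print(result)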