DRXD1000 committed
Commit: f001c41
Parent: da98dba

Making the Code Runnable on CPU

Please check GPU functionality. This was a quick fix tested on mac :)
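The fix swaps the hard-coded "cuda" defaults for device=None and moves runtime tensors with .to(self.device), so inputs follow wherever the model's parameters actually live. The hunks below show only the changed signatures and call sites, not any fallback logic, so here is a rough sketch of how a None default is commonly resolved up front (the resolve_device helper is hypothetical, not part of this diff):

import torch

def resolve_device(device=None):
    # Hypothetical helper: honor an explicit choice first, then prefer
    # CUDA, then Apple's MPS backend (the "tested on mac" case), else CPU.
    if device is not None:
        return torch.device(device)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")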

Files changed (1)
  1. modeling_GOT.py +39 -21
modeling_GOT.py CHANGED
@@ -1,25 +1,37 @@
-from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+import dataclasses
+from io import BytesIO
 from typing import List, Optional, Tuple, Union
-from transformers.cache_utils import Cache
+
 import requests
-from PIL import Image
-from io import BytesIO
 import torch
 import torch.nn as nn
+from PIL import Image
 from torch.nn import CrossEntropyLoss
-from .got_vision_b import build_GOT_vit_b
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-import dataclasses
+from transformers import (
+    Qwen2Config,
+    Qwen2ForCausalLM,
+    Qwen2Model,
+    StoppingCriteria,
+    TextStreamer,
+)
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+)
 
+from .got_vision_b import build_GOT_vit_b
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
 DEFAULT_IM_START_TOKEN = '<img>'
 DEFAULT_IM_END_TOKEN = '</img>'
 
-from enum import auto, Enum
+from enum import Enum, auto
+
+
 class SeparatorStyle(Enum):
     """Different separator style."""
     SINGLE = auto()
@@ -164,7 +176,7 @@ class GOTQwenModel(Qwen2Model):
         use_im_start_end=False,
         vision_select_layer=-1,
         dtype=torch.float16,
-        device="cuda"
+        device=None
     ):
 
 
@@ -453,7 +465,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         tokenizer,
         freeze_lm_model=False,
         pretrained_stage1_model=None,
-        device="cuda"
+        device=None
     ):
         config = self.get_model().config
 
@@ -488,6 +500,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         self.disable_torch_init()
 
+        tokenizer.pad_token_id = tokenizer.eos_token_id
 
         image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
@@ -558,7 +571,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         image_tensor_1 = image_processor_high(image)
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -566,10 +579,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
@@ -578,10 +591,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                 )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_tensor_1.unsqueeze(0).half().cuda()],
+                    images=[image_tensor_1.unsqueeze(0).half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     no_repeat_ngram_size = 20,
@@ -599,7 +612,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         if render:
             print('==============rendering===============')
-            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
+            from .render_tools import (
+                content_mmd_to_html,
+                svg_to_html,
+                tik_html,
+                translation_table,
+            )
 
             if '**kern' in outputs:
                 import verovio
@@ -812,7 +830,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         inputs = tokenizer([prompt])
 
-        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -820,10 +838,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
@@ -832,10 +850,10 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                 )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast(str(self.device), dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
-                    images=[image_list.half().cuda()],
+                    images=[image_list.half().to(self.device)],
                     do_sample=False,
                     num_beams = 1,
                     # no_repeat_ngram_size = 20,
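With the patch applied, loading the checkpoint without any .cuda() call keeps the parameters on CPU, and the .to(self.device) call sites above follow along automatically. A minimal usage sketch, assuming this modeling file backs a trust_remote_code checkpoint exposing a GOT-OCR2.0-style chat() entry point (the model id, image path, and chat() signature are assumptions, not taken from this diff):

from transformers import AutoModel, AutoTokenizer

MODEL_ID = "your-namespace/GOT-OCR2_0"  # placeholder repo id

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.eval()  # no .cuda(): parameters stay on CPU, so self.device is CPU

# chat() drives the generate() paths patched above; signature assumed
# from the GOT-OCR2.0 release.
result = model.chat(tokenizer, "sample_image.png", ocr_type="ocr")
print(result)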