velier commited on
Commit
7f908b9
1 Parent(s): e770f21

remove default "cuda" parameters and refactor torch.bfloat16 to torch.float16

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +6 -6
modeling_GOT.py CHANGED
@@ -164,7 +164,7 @@ class GOTQwenModel(Qwen2Model):
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
167
- device="cuda"
168
  ):
169
 
170
 
@@ -453,7 +453,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
453
  tokenizer,
454
  freeze_lm_model=False,
455
  pretrained_stage1_model=None,
456
- device="cuda"
457
  ):
458
  config = self.get_model().config
459
 
@@ -566,7 +566,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
  if stream_flag:
569
- with torch.autocast("cuda", dtype=torch.bfloat16):
570
  output_ids = self.generate(
571
  input_ids,
572
  images=[image_tensor_1.unsqueeze(0).half()],
@@ -578,7 +578,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
578
  stopping_criteria=[stopping_criteria]
579
  )
580
  else:
581
- with torch.autocast("cuda", dtype=torch.bfloat16):
582
  output_ids = self.generate(
583
  input_ids,
584
  images=[image_tensor_1.unsqueeze(0).half()],
@@ -820,7 +820,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
820
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
  if stream_flag:
823
- with torch.autocast("cuda", dtype=torch.bfloat16):
824
  output_ids = self.generate(
825
  input_ids,
826
  images=[image_list.half()],
@@ -832,7 +832,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
832
  stopping_criteria=[stopping_criteria]
833
  )
834
  else:
835
- with torch.autocast("cuda", dtype=torch.bfloat16):
836
  output_ids = self.generate(
837
  input_ids,
838
  images=[image_list.half()],
 
164
  use_im_start_end=False,
165
  vision_select_layer=-1,
166
  dtype=torch.float16,
167
+ device="mps"
168
  ):
169
 
170
 
 
453
  tokenizer,
454
  freeze_lm_model=False,
455
  pretrained_stage1_model=None,
456
+ device="mps"
457
  ):
458
  config = self.get_model().config
459
 
 
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
  if stream_flag:
569
+ with torch.autocast("mps", dtype=torch.float16):
570
  output_ids = self.generate(
571
  input_ids,
572
  images=[image_tensor_1.unsqueeze(0).half()],
 
578
  stopping_criteria=[stopping_criteria]
579
  )
580
  else:
581
+ with torch.autocast("mps", dtype=torch.float16):
582
  output_ids = self.generate(
583
  input_ids,
584
  images=[image_tensor_1.unsqueeze(0).half()],
 
820
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
  if stream_flag:
823
+ with torch.autocast("mps", dtype=torch.float16):
824
  output_ids = self.generate(
825
  input_ids,
826
  images=[image_list.half()],
 
832
  stopping_criteria=[stopping_criteria]
833
  )
834
  else:
835
+ with torch.autocast("mps", dtype=torch.float16):
836
  output_ids = self.generate(
837
  input_ids,
838
  images=[image_list.half()],