change default "cuda" device parameters to "mps" and refactor torch.bfloat16 to torch.float16
Browse files- modeling_GOT.py +6 -6
modeling_GOT.py
CHANGED
@@ -164,7 +164,7 @@ class GOTQwenModel(Qwen2Model):
|
|
164 |
use_im_start_end=False,
|
165 |
vision_select_layer=-1,
|
166 |
dtype=torch.float16,
|
167 |
-
device="cuda"
|
168 |
):
|
169 |
|
170 |
|
@@ -453,7 +453,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
453 |
tokenizer,
|
454 |
freeze_lm_model=False,
|
455 |
pretrained_stage1_model=None,
|
456 |
-
device="cuda"
|
457 |
):
|
458 |
config = self.get_model().config
|
459 |
|
@@ -566,7 +566,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
566 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
567 |
|
568 |
if stream_flag:
|
569 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
570 |
output_ids = self.generate(
|
571 |
input_ids,
|
572 |
images=[image_tensor_1.unsqueeze(0).half()],
|
@@ -578,7 +578,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
578 |
stopping_criteria=[stopping_criteria]
|
579 |
)
|
580 |
else:
|
581 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
582 |
output_ids = self.generate(
|
583 |
input_ids,
|
584 |
images=[image_tensor_1.unsqueeze(0).half()],
|
@@ -820,7 +820,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
820 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
821 |
|
822 |
if stream_flag:
|
823 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
824 |
output_ids = self.generate(
|
825 |
input_ids,
|
826 |
images=[image_list.half()],
|
@@ -832,7 +832,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
|
|
832 |
stopping_criteria=[stopping_criteria]
|
833 |
)
|
834 |
else:
|
835 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
836 |
output_ids = self.generate(
|
837 |
input_ids,
|
838 |
images=[image_list.half()],
|
|
|
164 |
use_im_start_end=False,
|
165 |
vision_select_layer=-1,
|
166 |
dtype=torch.float16,
|
167 |
+
device="mps"
|
168 |
):
|
169 |
|
170 |
|
|
|
453 |
tokenizer,
|
454 |
freeze_lm_model=False,
|
455 |
pretrained_stage1_model=None,
|
456 |
+
device="mps"
|
457 |
):
|
458 |
config = self.get_model().config
|
459 |
|
|
|
566 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
567 |
|
568 |
if stream_flag:
|
569 |
+
with torch.autocast("mps", dtype=torch.float16):
|
570 |
output_ids = self.generate(
|
571 |
input_ids,
|
572 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
|
578 |
stopping_criteria=[stopping_criteria]
|
579 |
)
|
580 |
else:
|
581 |
+
with torch.autocast("mps", dtype=torch.float16):
|
582 |
output_ids = self.generate(
|
583 |
input_ids,
|
584 |
images=[image_tensor_1.unsqueeze(0).half()],
|
|
|
820 |
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
821 |
|
822 |
if stream_flag:
|
823 |
+
with torch.autocast("mps", dtype=torch.float16):
|
824 |
output_ids = self.generate(
|
825 |
input_ids,
|
826 |
images=[image_list.half()],
|
|
|
832 |
stopping_criteria=[stopping_criteria]
|
833 |
)
|
834 |
else:
|
835 |
+
with torch.autocast("mps", dtype=torch.float16):
|
836 |
output_ids = self.generate(
|
837 |
input_ids,
|
838 |
images=[image_list.half()],
|