Update transformers_inference.py
Browse files
transformers_inference.py
CHANGED
@@ -10,7 +10,7 @@ tokenizer = LlamaTokenizer.from_pretrained('teknium/OpenHermes-2.5-Mistral-7B',
 10      model = MistralForCausalLM.from_pretrained(
 11          "teknium/OpenHermes-2.5-Mistral-7B",
 12          torch_dtype=torch.float16,
 13  -       device_map=
 14          load_in_8bit=False,
 15          load_in_4bit=True,
 16          use_flash_attention_2=True
 10      model = MistralForCausalLM.from_pretrained(
 11          "teknium/OpenHermes-2.5-Mistral-7B",
 12          torch_dtype=torch.float16,
 13  +       device_map="auto",#{'': 'cuda:0'},
 14          load_in_8bit=False,
 15          load_in_4bit=True,
 16          use_flash_attention_2=True