Update README.md
Browse files
Fix for running on CPU
README.md
CHANGED
@@ -34,8 +34,8 @@ tokenizer = AutoTokenizer.from_pretrained("GeneZC/MiniChat-3B", use_fast=False)
|
|
34 |
# GPU.
|
35 |
model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="auto", torch_dtype=torch.float16).eval()
|
36 |
# CPU.
|
37 |
-
# model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="cpu", torch_dtype=torch.float16).eval()
|
38 |
-
|
39 |
conv = get_default_conv_template("minichat")
|
40 |
|
41 |
question = "Implement a program to find the common elements in two arrays without using any extra data structures."
|
@@ -44,7 +44,7 @@ conv.append_message(conv.roles[1], None)
|
|
44 |
prompt = conv.get_prompt()
|
45 |
input_ids = tokenizer([prompt]).input_ids
|
46 |
output_ids = model.generate(
|
47 |
-
torch.as_tensor(input_ids).cuda(),
|
48 |
do_sample=True,
|
49 |
temperature=0.7,
|
50 |
max_new_tokens=1024,
|
|
|
34 |
# GPU.
|
35 |
model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="auto", torch_dtype=torch.float16).eval()
|
36 |
# CPU.
|
37 |
+
# model = AutoModelForCausalLM.from_pretrained("GeneZC/MiniChat-3B", use_cache=True, device_map="cpu", torch_dtype=torch.float32).eval()
|
38 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
39 |
conv = get_default_conv_template("minichat")
|
40 |
|
41 |
question = "Implement a program to find the common elements in two arrays without using any extra data structures."
|
|
|
44 |
prompt = conv.get_prompt()
|
45 |
input_ids = tokenizer([prompt]).input_ids
|
46 |
output_ids = model.generate(
|
47 |
+
torch.as_tensor(input_ids).to(device),
|
48 |
do_sample=True,
|
49 |
temperature=0.7,
|
50 |
max_new_tokens=1024,
|