Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ except:
|
|
11 |
import torch
|
12 |
from fastchat.model import get_conversation_template
|
13 |
import re
|
14 |
-
from transformers import LlamaForCausalLM
|
15 |
|
16 |
def truncate_list(lst, num):
|
17 |
if num not in lst:
|
@@ -89,7 +89,7 @@ def warmup(model):
|
|
89 |
prompt = conv.get_prompt()
|
90 |
if args.model_type == "llama-2-chat":
|
91 |
prompt += " "
|
92 |
-
input_ids =
|
93 |
input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
|
94 |
outs=model.generate(input_ids)
|
95 |
print(outs)
|
@@ -278,6 +278,7 @@ model = LlamaForCausalLM.from_pretrained(
|
|
278 |
device_map="auto",
|
279 |
)
|
280 |
model.eval()
|
|
|
281 |
warmup(model)
|
282 |
|
283 |
custom_css = """
|
|
|
11 |
import torch
|
12 |
from fastchat.model import get_conversation_template
|
13 |
import re
|
14 |
+
from transformers import LlamaForCausalLM,AutoTokenizer
|
15 |
|
16 |
def truncate_list(lst, num):
|
17 |
if num not in lst:
|
|
|
89 |
prompt = conv.get_prompt()
|
90 |
if args.model_type == "llama-2-chat":
|
91 |
prompt += " "
|
92 |
+
input_ids = tokenizer([prompt]).input_ids
|
93 |
input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
|
94 |
outs=model.generate(input_ids)
|
95 |
print(outs)
|
|
|
278 |
device_map="auto",
|
279 |
)
|
280 |
model.eval()
|
281 |
+
tokenizer=AutoTokenizer.from_pretrained(args.base_model_path)
|
282 |
warmup(model)
|
283 |
|
284 |
custom_css = """
|