yuhuili committed on
Commit
3ede498
1 Parent(s): 7858d5a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -11,7 +11,7 @@ except:
11
  import torch
12
  from fastchat.model import get_conversation_template
13
  import re
14
- from transformers import LlamaForCausalLM
15
 
16
  def truncate_list(lst, num):
17
  if num not in lst:
@@ -89,7 +89,7 @@ def warmup(model):
89
  prompt = conv.get_prompt()
90
  if args.model_type == "llama-2-chat":
91
  prompt += " "
92
- input_ids = model.tokenizer([prompt]).input_ids
93
  input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
94
  outs=model.generate(input_ids)
95
  print(outs)
@@ -278,6 +278,7 @@ model = LlamaForCausalLM.from_pretrained(
278
  device_map="auto",
279
  )
280
  model.eval()
 
281
  warmup(model)
282
 
283
  custom_css = """
 
11
  import torch
12
  from fastchat.model import get_conversation_template
13
  import re
14
+ from transformers import LlamaForCausalLM,AutoTokenizer
15
 
16
  def truncate_list(lst, num):
17
  if num not in lst:
 
89
  prompt = conv.get_prompt()
90
  if args.model_type == "llama-2-chat":
91
  prompt += " "
92
+ input_ids = tokenizer([prompt]).input_ids
93
  input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
94
  outs=model.generate(input_ids)
95
  print(outs)
 
278
  device_map="auto",
279
  )
280
  model.eval()
281
+ tokenizer=AutoTokenizer.from_pretrained(args.base_model_path)
282
  warmup(model)
283
 
284
  custom_css = """