yuhuili commited on
Commit
151160e
1 Parent(s): 8449c2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -90,7 +90,7 @@ def warmup(model):
90
  if args.model_type == "llama-2-chat":
91
  prompt += " "
92
  input_ids = model.tokenizer([prompt]).input_ids
93
- input_ids = torch.as_tensor(input_ids).to(model.device)
94
  for output_ids in model.ea_generate(input_ids):
95
  ol=output_ids.shape[1]
96
  @spaces.GPU(duration=30)
@@ -144,7 +144,7 @@ def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_stat
144
  prompt += " "
145
 
146
  input_ids = model.tokenizer([prompt]).input_ids
147
- input_ids = torch.as_tensor(input_ids).to(model.device)
148
  input_len = input_ids.shape[1]
149
  naive_text = []
150
  cu_len = input_len
 
90
  if args.model_type == "llama-2-chat":
91
  prompt += " "
92
  input_ids = model.tokenizer([prompt]).input_ids
93
+ input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
94
  for output_ids in model.ea_generate(input_ids):
95
  ol=output_ids.shape[1]
96
  @spaces.GPU(duration=30)
 
144
  prompt += " "
145
 
146
  input_ids = model.tokenizer([prompt]).input_ids
147
+ input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
148
  input_len = input_ids.shape[1]
149
  naive_text = []
150
  cu_len = input_len