yuhuili committed
Commit 7858d5a
1 parent: fd169f4

Update app.py

Files changed (1): app.py +5 -7
app.py CHANGED
@@ -11,7 +11,7 @@ except:
 import torch
 from fastchat.model import get_conversation_template
 import re
-
+from transformers import LlamaForCausalLM
 
 def truncate_list(lst, num):
     if num not in lst:
@@ -91,8 +91,8 @@ def warmup(model):
     prompt += " "
     input_ids = model.tokenizer([prompt]).input_ids
     input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
-    for output_ids in model.ea_generate(input_ids):
-        ol=output_ids.shape[1]
+    outs=model.generate(input_ids)
+    print(outs)
 @spaces.GPU(duration=30)
 def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_state,):
     if not history:
@@ -269,10 +269,8 @@ parser.add_argument(
 args = parser.parse_args()
 a=torch.tensor(1).cuda()
 print(a)
-model = EaModel.from_pretrained(
-    base_model_path=args.base_model_path,
-    ea_model_path=args.ea_model_path,
-    total_token=args.total_token,
+model = LlamaForCausalLM.from_pretrained(
+    args.base_model_path,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     load_in_4bit=args.load_in_4bit,
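
Taken together, the commit replaces EAGLE's EaModel wrapper and its streaming ea_generate loop with a stock transformers LlamaForCausalLM driven by the standard generate API. A minimal sketch of the new load-and-warmup path, assuming a hypothetical checkpoint path in place of args.base_model_path (the tokenizer handling here is illustrative, not taken from app.py):

import torch
from transformers import AutoTokenizer, LlamaForCausalLM

# Hypothetical path; in app.py this comes from args.base_model_path.
base_model_path = "meta-llama/Llama-2-7b-chat-hf"

model = LlamaForCausalLM.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_path)

# Warm up once so model init and kernel compilation happen before the
# first user request, mirroring the patched warmup(): encode a prompt,
# generate, and print the raw output ids.
input_ids = tokenizer(["Hello"], return_tensors="pt").input_ids.to(model.device)
outs = model.generate(input_ids, max_new_tokens=8)
print(outs)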
 
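One behavioral difference worth noting: judging from the removed loop, ea_generate appears to be a generator that yields progressively longer output_ids tensors as speculative tokens are accepted (the old code only tracked their length), whereas generate returns a single final tensor of token ids, so the new print(outs) shows ids rather than decoded text.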