Update app.py
app.py CHANGED
@@ -11,7 +11,7 @@ except:
 import torch
 from fastchat.model import get_conversation_template
 import re
-
+

 def truncate_list(lst, num):
     if num not in lst:
@@ -73,7 +73,7 @@ def highlight_text(text, text_list,color="black"):

     return result

-@spaces.GPU(duration=
+@spaces.GPU(duration=60)
 def warmup(model):
     model.cuda()
     conv = get_conversation_template(args.model_type)
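Note on the @spaces.GPU(duration=60) changes: on ZeroGPU Spaces, the spaces.GPU decorator attaches a GPU only while the decorated function runs, and duration caps the allocation in seconds. A minimal sketch of the pattern (the pipeline and function below are illustrative stand-ins, not code from this Space):

    import spaces
    import torch
    from transformers import pipeline

    # Stand-in workload; any CUDA-using callable is decorated the same way.
    pipe = pipeline("text-generation", model="gpt2", torch_dtype=torch.float16)

    @spaces.GPU(duration=60)  # GPU is held only for the duration of this call
    def generate(prompt):
        pipe.model.cuda()
        return pipe(prompt, max_new_tokens=32)[0]["generated_text"]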
@@ -90,12 +90,13 @@ def warmup(model):
     prompt = conv.get_prompt()
     if args.model_type == "llama-2-chat":
         prompt += " "
-    input_ids = tokenizer([prompt]).input_ids
+    input_ids = model.tokenizer([prompt]).input_ids
     input_ids = torch.as_tensor(input_ids).to(model.base_model.device)
-
-
-@spaces.GPU(duration=
+    for output_ids in model.ea_generate(input_ids):
+        ol=output_ids.shape[1]
+@spaces.GPU(duration=60)
 def bot(history, temperature, top_p, use_EaInfer, highlight_EaInfer,session_state,):
+    model.cuda()
     if not history:
         return history, "0.00 tokens/s", "0.00", session_state
     pure_history = session_state.get("pure_history", [])
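The rewritten warmup drains one full model.ea_generate(input_ids) pass. ea_generate is a generator that yields the cumulative output_ids tensor at each speculative step, so consuming it once forces CUDA kernels and the EAGLE draft buffers to initialize before the first user request. A self-contained sketch of that consumption pattern, with a stub standing in for ea_generate (the stub is illustrative only):

    import torch

    def fake_ea_generate(input_ids, steps=3):
        # Mimics the yield-cumulative-ids interface of EAGLE's ea_generate.
        out = input_ids
        for _ in range(steps):
            out = torch.cat([out, torch.randint(0, 100, (1, 4))], dim=1)
            yield out

    input_ids = torch.randint(0, 100, (1, 8))
    seen = input_ids.shape[1]
    for output_ids in fake_ea_generate(input_ids):
        new_tokens = output_ids[0, seen:]  # only the freshly generated slice
        seen = output_ids.shape[1]
        print(new_tokens.tolist())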
@@ -270,17 +271,17 @@ parser.add_argument(
 args = parser.parse_args()
 a=torch.tensor(1).cuda()
 print(a)
-model =
-    args.base_model_path,
+model = EaModel.from_pretrained(
+    base_model_path=args.base_model_path,
+    ea_model_path=args.ea_model_path,
+    total_token=args.total_token,
     torch_dtype=torch.float16,
     low_cpu_mem_usage=True,
     load_in_4bit=args.load_in_4bit,
     load_in_8bit=args.load_in_8bit,
     device_map="auto",
 )
-
 model.eval()
-tokenizer=AutoTokenizer.from_pretrained(args.base_model_path)
 warmup(model)

 custom_css = """
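The new EaModel.from_pretrained(...) call reads args.base_model_path, args.ea_model_path, and args.total_token, so the parser set up above this hunk must define matching options. A sketch of argparse definitions consistent with those attribute names (flag spellings and defaults are assumptions, not copied from this repo):

    import argparse

    parser = argparse.ArgumentParser()
    # argparse maps --foo-bar to args.foo_bar, matching the attributes in the diff.
    parser.add_argument("--base-model-path", type=str, required=True,
                        help="path or hub id of the base LLM")
    parser.add_argument("--ea-model-path", type=str, required=True,
                        help="path or hub id of the EAGLE draft head")
    parser.add_argument("--total-token", type=int, default=60,
                        help="token budget of the draft tree per step")
    parser.add_argument("--load-in-4bit", action="store_true")
    parser.add_argument("--load-in-8bit", action="store_true")
    parser.add_argument("--model-type", type=str, default="llama-2-chat")
    args = parser.parse_args()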
@@ -327,4 +328,4 @@ with gr.Blocks(css=custom_css) as demo:
     )
     stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event])
 demo.queue()
-demo.launch(
+demo.launch()
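On the last hunk: bot yields partial results, and generator handlers in Gradio run through the event queue, so demo.queue() has to stay ahead of demo.launch(); dropping the arguments from launch() leaves host and port to the defaults the Spaces runtime expects. Minimal shape of the same pattern (the echo handler is a placeholder):

    import gradio as gr

    def echo(message):
        # Placeholder streaming handler; yields partials like bot() does.
        for i in range(1, len(message) + 1):
            yield message[:i]

    with gr.Blocks() as demo:
        box = gr.Textbox(label="prompt")
        out = gr.Textbox(label="reply")
        box.submit(echo, inputs=box, outputs=out)

    demo.queue()   # event queue; required for generator handlers
    demo.launch()  # no args: platform defaults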