0.11 simplifying wo pharia
app.py
CHANGED
@@ -8,13 +8,15 @@ import os
 
 from threading import Thread
 
+# Status: Breaks during generation
+
 logging.basicConfig(level=logging.DEBUG)
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 login(token=HF_TOKEN)
 
 models_available = [
-    "
+    "NousResearch/Meta-Llama-3.1-8B-Instruct",
     "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
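Both entries that remain in models_available (the removed entry is truncated in the source view, but was presumably the Pharia checkpoint named in the commit title) ship a chat template with their tokenizer, which is what lets the later hunk drop the model-specific formatting branch. A minimal, standalone sketch of what that template produces, assuming nothing beyond the public tokenizer (the message content here is illustrative only):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
messages = [{"role": "user", "content": "Hello!"}]  # role/content dicts, as used by generate_both

# Renders the conversation with the tokenizer's built-in chat template and
# appends the assistant prefix so generation starts at the reply.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)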
@@ -58,7 +60,6 @@ def load_model_a(model_id):
         device_map="auto",
         trust_remote_code=True,
     ).eval()
-    model_a.tie_weights()
     return gr.update(label=model_id)
 
 def load_model_b(model_id):
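The dropped model_a.tie_weights() call is normally redundant here: from_pretrained already ties input and output embeddings when the checkpoint's config requests it, so removing the explicit call should not change behaviour for these models. For context, a sketch of a loader consistent with the visible tail of the call; the from_pretrained line, the globals, and the dtype are assumptions, since the diff only shows the last few lines of the function:

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model_a(model_id):
    # Assumed shape of the loader; only device_map / trust_remote_code / .eval()
    # and the gr.update return value are visible in the diff.
    global model_a, tokenizer_a
    tokenizer_a = AutoTokenizer.from_pretrained(model_id)
    model_a = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # assumption: dtype is not shown in the diff
        device_map="auto",
        trust_remote_code=True,
    ).eval()
    return gr.update(label=model_id)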
@@ -97,29 +98,17 @@ def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_token
     new_messages_a = system_prompt_list + chat_history_a + input_text_list
     new_messages_b = system_prompt_list + chat_history_b + input_text_list
 
-    [... 12 more removed lines, not rendered in the source view ...]
-    if "Pharia" in model_id_b:
-        logging.debug("model b is Pharia based, applying own template")
-        formatted_message_b = apply_chat_template(new_messages_a, add_generation_prompt=True)
-        logging.debug(f"***** formatted message is {formatted_message_b}")
-        input_ids_b = tokenizer_b(formatted_message_b, return_tensors="pt").input_ids.to(model_b.device)
-    else:
-        input_ids_b = tokenizer_b.apply_chat_template(
-            new_messages_b,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(model_b.device)
+    input_ids_a = tokenizer_a.apply_chat_template(
+        new_messages_a,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model_a.device)
+
+    input_ids_b = tokenizer_b.apply_chat_template(
+        new_messages_b,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model_b.device)
 
     generation_kwargs_a = dict(
         input_ids=input_ids_a,
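With the Pharia branch gone, both sides now build their prompts the same way: tokenizer.apply_chat_template takes the list of role/content dicts, appends the generation prompt, and returns input_ids on the model's device. The rest of generate_both is not shown in this diff, but the "from threading import Thread" import and the generation_kwargs_a dict suggest the usual streaming pattern, sketched below with a transformers TextIteratorStreamer; the streamer, sampling flags, and the max_new_tokens pass-through are assumptions:

from threading import Thread
from transformers import TextIteratorStreamer

# Assumed continuation of generate_both after input_ids_a / input_ids_b are built.
streamer_a = TextIteratorStreamer(tokenizer_a, skip_prompt=True, skip_special_tokens=True)
generation_kwargs_a = dict(
    input_ids=input_ids_a,
    streamer=streamer_a,
    max_new_tokens=max_new_tokens,  # assumption: forwarded from the UI parameter
    do_sample=True,                 # assumption: sampling settings are not visible
)

# generate() runs in a background thread; the streamer is consumed on the main
# thread so partial text can be yielded to the Gradio chatbot as it arrives.
thread_a = Thread(target=model_a.generate, kwargs=generation_kwargs_a)
thread_a.start()

partial_a = ""
for new_text in streamer_a:
    partial_a += new_text
    # yield the growing reply to chatbot_a here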