5to9 committed on
Commit 75b1a69
1 Parent(s): 2569b24

0.11 simplifying wo pharia

Files changed (1): app.py (+14, -25)
app.py CHANGED
@@ -8,13 +8,15 @@ import os
 
 from threading import Thread
 
+# Status: Breaks during generation
+
 logging.basicConfig(level=logging.DEBUG)
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 login(token=HF_TOKEN)
 
 models_available = [
-    "Aleph-Alpha/Pharia-1-LLM-7B-control-hf",
+    "NousResearch/Meta-Llama-3.1-8B-Instruct",
     "mistralai/Mistral-7B-Instruct-v0.3",
 ]
 
@@ -58,7 +60,6 @@ def load_model_a(model_id):
         device_map="auto",
         trust_remote_code=True,
     ).eval()
-    model_a.tie_weights()
     return gr.update(label=model_id)
 
 def load_model_b(model_id):
@@ -97,29 +98,17 @@ def generate_both(system_prompt, input_text, chatbot_a, chatbot_b, max_new_token
     new_messages_a = system_prompt_list + chat_history_a + input_text_list
     new_messages_b = system_prompt_list + chat_history_b + input_text_list
 
-    if "Pharia" in model_id_a:
-        logging.debug("***** Model a is Pharia based, applying own template")
-        formatted_message_a = apply_chat_template(new_messages_a, add_generation_prompt=True)
-        logging.debug(f"***** formatted message is {formatted_message_a}")
-        input_ids_a = tokenizer_b(formatted_message_a, return_tensors="pt").input_ids.to(model_a.device)
-    else:
-        input_ids_a = tokenizer_a.apply_chat_template(
-            new_messages_a,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(model_a.device)
-
-    if "Pharia" in model_id_b:
-        logging.debug("model b is Pharia based, applying own template")
-        formatted_message_b = apply_chat_template(new_messages_a, add_generation_prompt=True)
-        logging.debug(f"***** formatted message is {formatted_message_b}")
-        input_ids_b = tokenizer_b(formatted_message_b, return_tensors="pt").input_ids.to(model_b.device)
-    else:
-        input_ids_b = tokenizer_b.apply_chat_template(
-            new_messages_b,
-            add_generation_prompt=True,
-            return_tensors="pt"
-        ).to(model_b.device)
+    input_ids_a = tokenizer_a.apply_chat_template(
+        new_messages_a,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model_a.device)
+
+    input_ids_b = tokenizer_b.apply_chat_template(
+        new_messages_b,
+        add_generation_prompt=True,
+        return_tensors="pt"
+    ).to(model_b.device)
 
     generation_kwargs_a = dict(
         input_ids=input_ids_a,
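The simplified path drops the hand-rolled Pharia formatter and relies entirely on each tokenizer's built-in chat template. A minimal standalone sketch of that call (the model choice and message content here are illustrative, not taken from app.py; gated checkpoints may additionally need the HF_TOKEN login the app performs):

from transformers import AutoTokenizer

# Illustrative model; app.py loads whichever entry of models_available is selected.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

messages = [
    {"role": "user", "content": "Hello, who are you?"},
]

# apply_chat_template wraps the messages in the model's own prompt format and,
# with add_generation_prompt=True, appends the assistant-turn marker so the
# model continues as the assistant. return_tensors="pt" returns input_ids directly.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
)
print(input_ids.shape)  # e.g. torch.Size([1, N])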
 
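The diff cuts off before generation_kwargs_a is fully defined, but the `from threading import Thread` import together with a kwargs dict is the usual sign of the transformers TextIteratorStreamer pattern for streaming tokens into a Gradio chatbot. A sketch of that pattern for a single model, assuming standard transformers APIs; none of the names below come from app.py beyond what the diff shows:

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "mistralai/Mistral-7B-Instruct-v0.3"  # placeholder; any chat model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto").eval()

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarize the plot of Hamlet in one line."}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# skip_prompt=True streams only newly generated tokens, not the prompt itself.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generation_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=256,
)

# generate() blocks, so it runs in a worker thread; the main thread consumes
# the streamer iterator and can push partial text to the UI as it arrives.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

partial = ""
for chunk in streamer:
    partial += chunk
thread.join()
print(partial)

If the "# Status: Breaks during generation" note refers to this stage, the device placement of input_ids and the tokenizer passed to the streamer are the usual first things to check.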