Spaces:
Running
Running
add more models
Browse files
README.md
CHANGED
@@ -6,8 +6,7 @@ colorTo: purple
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.19.2
|
8 |
app_file: app.py
|
9 |
-
pinned: true
|
10 |
-
fullWidth: true
|
11 |
---
|
12 |
|
13 |
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
|
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.19.2
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
|
|
10 |
---
|
11 |
|
12 |
An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
|
app.py
CHANGED
@@ -85,6 +85,18 @@ def respond(
|
|
85 |
_model_name = "meta-llama/Llama-3-8b-hf"
|
86 |
elif model_name == "Llama-3-70B":
|
87 |
_model_name = "meta-llama/Llama-3-70b-hf"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
else:
|
89 |
raise ValueError("Invalid model name")
|
90 |
# _model_name = "meta-llama/Llama-3-8b-hf"
|
@@ -105,31 +117,69 @@ def respond(
|
|
105 |
for msg in request:
|
106 |
# print(msg.choices[0].delta.keys())
|
107 |
token = msg.choices[0].delta["content"]
|
108 |
-
response += token
|
109 |
should_stop = False
|
110 |
for _stop in stop_str:
|
111 |
-
if _stop in response:
|
112 |
should_stop = True
|
113 |
break
|
114 |
if should_stop:
|
115 |
break
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
with gr.Row():
|
120 |
with gr.Column():
|
121 |
-
gr.
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
124 |
with gr.Column():
|
|
|
125 |
with gr.Column():
|
126 |
with gr.Row():
|
127 |
max_tokens = gr.Textbox(value=1024, label="Max tokens")
|
128 |
temperature = gr.Textbox(value=0.5, label="Temperature")
|
129 |
-
with gr.Column():
|
130 |
-
|
131 |
top_p = gr.Textbox(value=0.9, label="Top-p")
|
132 |
rp = gr.Textbox(value=1.1, label="Repetition penalty")
|
|
|
133 |
|
134 |
chat = gr.ChatInterface(
|
135 |
respond,
|
@@ -139,6 +189,8 @@ with gr.Blocks() as demo:
|
|
139 |
)
|
140 |
chat.chatbot.height = 600
|
141 |
|
|
|
|
|
142 |
|
143 |
if __name__ == "__main__":
|
144 |
demo.launch()
|
|
|
85 |
_model_name = "meta-llama/Llama-3-8b-hf"
|
86 |
elif model_name == "Llama-3-70B":
|
87 |
_model_name = "meta-llama/Llama-3-70b-hf"
|
88 |
+
elif model_name == "Llama-2-7B":
|
89 |
+
_model_name = "meta-llama/Llama-2-7b-hf"
|
90 |
+
elif model_name == "Llama-2-70B":
|
91 |
+
_model_name = "meta-llama/Llama-2-70b-hf"
|
92 |
+
elif model_name == "Mistral-7B-v0.1":
|
93 |
+
_model_name = "mistralai/Mistral-7B-v0.1"
|
94 |
+
elif model_name == "mistralai/Mixtral-8x22B":
|
95 |
+
_model_name = "mistralai/Mixtral-8x22B"
|
96 |
+
elif model_name == "Qwen1.5-72B":
|
97 |
+
_model_name = "Qwen/Qwen1.5-72B"
|
98 |
+
elif model_name == "Yi-34B":
|
99 |
+
_model_name = "zero-one-ai/Yi-34B"
|
100 |
else:
|
101 |
raise ValueError("Invalid model name")
|
102 |
# _model_name = "meta-llama/Llama-3-8b-hf"
|
|
|
117 |
for msg in request:
|
118 |
# print(msg.choices[0].delta.keys())
|
119 |
token = msg.choices[0].delta["content"]
|
|
|
120 |
should_stop = False
|
121 |
for _stop in stop_str:
|
122 |
+
if _stop in response + token:
|
123 |
should_stop = True
|
124 |
break
|
125 |
if should_stop:
|
126 |
break
|
127 |
+
response += token
|
128 |
+
if response.endswith('\n"'):
|
129 |
+
response = response[:-1]
|
130 |
+
elif response.endswith('\n""'):
|
131 |
+
response = response[:-2]
|
132 |
+
yield response
|
133 |
|
134 |
+
js_code_label = """
|
135 |
+
function addApiKeyLink() {
|
136 |
+
// Select the div with id 'api_key'
|
137 |
+
const apiKeyDiv = document.getElementById('api_key');
|
138 |
+
|
139 |
+
// Find the span within that div with data-testid 'block-info'
|
140 |
+
const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');
|
141 |
+
|
142 |
+
// Create the new link element
|
143 |
+
const newLink = document.createElement('a');
|
144 |
+
newLink.href = 'https://api.together.ai/settings/api-keys';
|
145 |
+
newLink.textContent = ' View your keys here.';
|
146 |
+
newLink.target = '_blank'; // Open link in new tab
|
147 |
+
newLink.style = 'color: #007bff; text-decoration: underline;';
|
148 |
+
|
149 |
+
// Create the additional text
|
150 |
+
const additionalText = document.createTextNode(' (new account will have free credits to use.)');
|
151 |
+
|
152 |
+
// Append the link and additional text to the api_key div (only when the block-info span exists)
|
153 |
+
if (blockInfoSpan) {
|
154 |
+
// add a br
|
155 |
+
apiKeyDiv.appendChild(document.createElement('br'));
|
156 |
+
apiKeyDiv.appendChild(newLink);
|
157 |
+
apiKeyDiv.appendChild(additionalText);
|
158 |
+
} else {
|
159 |
+
console.error('Span with data-testid "block-info" not found');
|
160 |
+
}
|
161 |
+
}
|
162 |
+
"""
|
163 |
+
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
|
164 |
with gr.Row():
|
165 |
with gr.Column():
|
166 |
+
gr.Markdown("""# 💬 BaseChat: Chat with Base LLMs with URIAL
|
167 |
+
[Paper](https://arxiv.org/abs/2312.01552) | [Website](https://allenai.github.io/re-align/) | [GitHub](https://github.com/Re-Align/urial) | Contact: [Yuchen Lin](https://yuchenlin.xyz/)
|
168 |
+
|
169 |
+
**Talk with __BASE__ LLMs which are not fine-tuned at all.**
|
170 |
+
""")
|
171 |
+
model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", "mistralai/Mixtral-8x22B", "Yi-34B", "Llama-2-7B", "Llama-2-70B"], value="Llama-3-8B", label="Base LLM name")
|
172 |
with gr.Column():
|
173 |
+
together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password", elem_id="api_key")
|
174 |
with gr.Column():
|
175 |
with gr.Row():
|
176 |
max_tokens = gr.Textbox(value=1024, label="Max tokens")
|
177 |
temperature = gr.Textbox(value=0.5, label="Temperature")
|
178 |
+
# with gr.Column():
|
179 |
+
# with gr.Row():
|
180 |
top_p = gr.Textbox(value=0.9, label="Top-p")
|
181 |
rp = gr.Textbox(value=1.1, label="Repetition penalty")
|
182 |
+
|
183 |
|
184 |
chat = gr.ChatInterface(
|
185 |
respond,
|
|
|
189 |
)
|
190 |
chat.chatbot.height = 600
|
191 |
|
192 |
+
|
193 |
+
|
194 |
|
195 |
if __name__ == "__main__":
|
196 |
demo.launch()
|