Update README.md
Browse files
README.md
CHANGED
@@ -216,7 +216,179 @@ with open("studio/unsloth_studio/chat.py", "r") as chat_module:
|
|
216 |
)
|
217 |
exec(code)
|
218 |
```
|
|
|
219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
|
221 |
|
222 |
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
|
|
|
216 |
)
|
217 |
exec(code)
|
218 |
```
|
219 |
+
- Change the `chat.py`
|
220 |
|
221 |
+
```py
|
222 |
+
# Unsloth Studio
|
223 |
+
# Copyright (C) 2024-present the Unsloth AI team. All rights reserved.
|
224 |
+
|
225 |
+
# This program is free software: you can redistribute it and/or modify
|
226 |
+
# it under the terms of the GNU Affero General Public License as published
|
227 |
+
# by the Free Software Foundation, either version 3 of the License, or
|
228 |
+
# (at your option) any later version.
|
229 |
+
|
230 |
+
# This program is distributed in the hope that it will be useful,
|
231 |
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
232 |
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
233 |
+
# GNU Affero General Public License for more details.
|
234 |
+
|
235 |
+
# You should have received a copy of the GNU Affero General Public License
|
236 |
+
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
237 |
+
|
238 |
+
from IPython.display import clear_output
|
239 |
+
import subprocess
|
240 |
+
import os
|
241 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
242 |
+
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
243 |
+
MODEL_NAME = "vutuka/Llama-3.1-8B-Instruct-African-Ultrachat"
|
244 |
+
|
245 |
+
print("Installing packages for 🦥 Unsloth Studio ... Please wait 1 minute ...")
|
246 |
+
|
247 |
+
install_first = [
|
248 |
+
"pip", "install",
|
249 |
+
"huggingface_hub[hf_transfer]",
|
250 |
+
]
|
251 |
+
install_first = subprocess.Popen(install_first)
|
252 |
+
install_first.wait()
|
253 |
+
|
254 |
+
install_second = [
|
255 |
+
"pip", "install",
|
256 |
+
"gradio",
|
257 |
+
"unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git",
|
258 |
+
]
|
259 |
+
install_second = subprocess.Popen(install_second)
|
260 |
+
|
261 |
+
from huggingface_hub import snapshot_download
|
262 |
+
import warnings
|
263 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
264 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
|
265 |
+
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
|
266 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
|
267 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
|
268 |
+
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
|
269 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
|
270 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
|
271 |
+
|
272 |
+
from huggingface_hub.utils import disable_progress_bars
|
273 |
+
disable_progress_bars()
|
274 |
+
snapshot_download(repo_id = MODEL_NAME, repo_type = "model")
|
275 |
+
|
276 |
+
install_second.wait()
|
277 |
+
|
278 |
+
install_dependencies = [
|
279 |
+
"pip", "install", "--no-deps",
|
280 |
+
"xformers<0.0.27", "trl<0.9.0", "peft", "accelerate", "bitsandbytes",
|
281 |
+
]
|
282 |
+
install_dependencies = subprocess.Popen(install_dependencies)
|
283 |
+
install_dependencies.wait()
|
284 |
+
clear_output()
|
285 |
+
|
286 |
+
|
287 |
+
from contextlib import redirect_stdout
|
288 |
+
import io
|
289 |
+
import logging
|
290 |
+
logging.getLogger("transformers.utils.hub").setLevel(logging.CRITICAL+1)
|
291 |
+
|
292 |
+
print("Loading model ... Please wait 1 more minute! ...")
|
293 |
+
|
294 |
+
with redirect_stdout(io.StringIO()):
|
295 |
+
from unsloth import FastLanguageModel
|
296 |
+
import torch
|
297 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
298 |
+
model_name = MODEL_NAME,
|
299 |
+
max_seq_length = None,
|
300 |
+
dtype = None,
|
301 |
+
load_in_4bit = True,
|
302 |
+
)
|
303 |
+
FastLanguageModel.for_inference(model)
|
304 |
+
pass
|
305 |
+
clear_output()
|
306 |
+
|
307 |
+
import gradio
|
308 |
+
gradio.strings.en["SHARE_LINK_DISPLAY"] = ""
|
309 |
+
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
|
310 |
+
from threading import Thread
|
311 |
+
|
312 |
+
class StopOnTokens(StoppingCriteria):
|
313 |
+
def __init__(self, stop_token_ids):
|
314 |
+
self.stop_token_ids = tuple(set(stop_token_ids))
|
315 |
+
pass
|
316 |
+
|
317 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
318 |
+
return input_ids[0][-1].item() in self.stop_token_ids
|
319 |
+
pass
|
320 |
+
pass
|
321 |
+
|
322 |
+
def async_process_chatbot(message, history):
|
323 |
+
eos_token = tokenizer.eos_token
|
324 |
+
stop_on_tokens = StopOnTokens([eos_token,])
|
325 |
+
text_streamer = TextIteratorStreamer(tokenizer, skip_prompt = True)
|
326 |
+
|
327 |
+
# From https://www.gradio.app/guides/creating-a-chatbot-fast
|
328 |
+
history_transformer_format = history + [[message, ""]]
|
329 |
+
messages = []
|
330 |
+
for item in history_transformer_format:
|
331 |
+
messages.append({"role": "user", "content": item[0]})
|
332 |
+
messages.append({"role": "assistant", "content": item[1]})
|
333 |
+
pass
|
334 |
+
# Remove last assistant and instead use add_generation_prompt
|
335 |
+
messages.pop(-1)
|
336 |
+
|
337 |
+
input_ids = tokenizer.apply_chat_template(
|
338 |
+
messages,
|
339 |
+
add_generation_prompt = True,
|
340 |
+
return_tensors = "pt",
|
341 |
+
).to("cuda", non_blocking = True)
|
342 |
+
|
343 |
+
# Add stopping criteria - will not output EOS / EOT
|
344 |
+
generation_kwargs = dict(
|
345 |
+
input_ids = input_ids,
|
346 |
+
streamer = text_streamer,
|
347 |
+
max_new_tokens = 1024,
|
348 |
+
stopping_criteria = StoppingCriteriaList([stop_on_tokens,]),
|
349 |
+
temperature = 0.7,
|
350 |
+
do_sample = True,
|
351 |
+
)
|
352 |
+
thread = Thread(target = model.generate, kwargs = generation_kwargs)
|
353 |
+
thread.start()
|
354 |
+
|
355 |
+
# Yield will save the output to history!
|
356 |
+
generated_text = ""
|
357 |
+
for new_text in text_streamer:
|
358 |
+
if new_text.endswith(eos_token):
|
359 |
+
new_text = new_text[:len(new_text) - len(eos_token)]
|
360 |
+
generated_text += new_text
|
361 |
+
yield generated_text
|
362 |
+
pass
|
363 |
+
pass
|
364 |
+
|
365 |
+
studio_theme = gradio.themes.Soft(
|
366 |
+
primary_hue = "teal",
|
367 |
+
)
|
368 |
+
|
369 |
+
scene = gradio.ChatInterface(
|
370 |
+
async_process_chatbot,
|
371 |
+
chatbot = gradio.Chatbot(
|
372 |
+
height = 325,
|
373 |
+
label = "Unsloth Studio Chat",
|
374 |
+
),
|
375 |
+
textbox = gradio.Textbox(
|
376 |
+
placeholder = "Message Unsloth Chat",
|
377 |
+
container = False,
|
378 |
+
),
|
379 |
+
title = None,
|
380 |
+
theme = studio_theme,
|
381 |
+
examples = None,
|
382 |
+
cache_examples = False,
|
383 |
+
retry_btn = None,
|
384 |
+
undo_btn = "Remove Previous Message",
|
385 |
+
clear_btn = "Restart Entire Chat",
|
386 |
+
)
|
387 |
+
|
388 |
+
scene.launch(quiet = True)
|
389 |
+
```
|
390 |
+
|
391 |
+
|
392 |
This llama model was trained 2x faster with [Unsloth](https://github.com/unslothai/unsloth) and Huggingface's TRL library.
|
393 |
|
394 |
[<img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20made%20with%20love.png" width="200"/>](https://github.com/unslothai/unsloth)
|