Spaces: Running on Zero · File size: 3,917 Bytes
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import spaces
from threading import Thread
from typing import Iterator
# Add markdown header
header = """
# 🐦⬛ MagpieLMs: Open LLMs with Fully Transparent Alignment Recipes
💬 We've aligned Llama-3.1-8B and a 4B version (distilled by NVIDIA) using purely synthetic data generated by our [Magpie](https://arxiv.org/abs/2406.08464) method. Our open-source post-training recipe includes the SFT and DPO data plus all training configs and logs, so anyone can reproduce the alignment process for their own research. Note that our data contains no GPT-generated content and carries a much friendlier license for both commercial and academic use.
🔗 Links: [**Magpie Collection**](https://huggingface.co/collections/Magpie-Align/magpielm-66e2221f31fa3bf05b10786a); [**Magpie Paper**](https://arxiv.org/abs/2406.08464) 📮 Contact: [Zhangchen Xu](https://zhangchenxu.com) and [Bill Yuchen Lin](https://yuchenlin.xyz).
---
"""
# Load model and tokenizer
model_name = "Magpie-Align/MagpieLM-8B-Chat-v0.1"
device = "cuda" # the device to load the model onto
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
ignore_mismatched_sizes=True
)
model.to(device)
MAX_INPUT_TOKEN_LENGTH = 4096  # Cap on prompt tokens for this demo; respond() trims older turns beyond this. Raise it to allow longer conversations at the cost of latency.
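# On a ZeroGPU Space, @spaces.GPU requests a GPU slice for the duration of each
# decorated call; elsewhere the decorator should be a no-op (per the `spaces`
# package docs).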
@spaces.GPU
def respond(
message: str,
chat_history: list[tuple[str, str]],
system_prompt: str,
max_new_tokens: int = 1024,
temperature: float = 0.6,
top_p: float = 0.9,
top_k: int = 50,
repetition_penalty: float = 1.2,
) -> Iterator[str]:
conversation = []
if system_prompt:
conversation.append({"role": "system", "content": system_prompt})
for user, assistant in chat_history:
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
input_ids = input_ids.to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
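    # The streamer is fed token-by-token by model.generate(); running generation
    # in a background thread below lets this generator yield partial text to the
    # UI while tokens are still being produced.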
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=True,
top_p=top_p,
top_k=top_k,
temperature=temperature,
num_beams=1,
repetition_penalty=repetition_penalty,
)
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
demo = gr.ChatInterface(
respond,
additional_inputs=[
gr.Textbox(value="You are Magpie, a helpful AI assistant. For simple queries, try to answer them directly; for complex questions, try to think step-by-step before providing an answer.", label="System message"),
gr.Slider(minimum=128, maximum=2048, value=512, step=64, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.1,
label="Top-p (nucleus sampling)",
),
        # additional_inputs are passed positionally to respond() after (message,
        # chat_history), so the order here must match the function signature above.
        gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Repetition penalty"),
],
description=header, # Add the header as the description
title="MagpieLM-8B Chat (v0.1)",
theme=gr.themes.Soft()
)
if __name__ == "__main__":
demo.queue()
demo.launch(share=True)