File size: 3,917 Bytes
e55bd08
 
 
 
f60e921
e55bd08
8cffdb8
 
 
 
 
c9dcbb2
8cffdb8
 
 
 
e55bd08
e44fa8d
e55bd08
 
 
 
 
8cffdb8
 
e55bd08
 
 
 
 
f60e921
78da2d6
f60e921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e55bd08
f60e921
131a07a
 
 
 
 
 
 
e55bd08
131a07a
f60e921
131a07a
 
f60e921
131a07a
f60e921
15d1015
 
f60e921
 
e55bd08
f60e921
 
 
 
e55bd08
 
78da2d6
e55bd08
8cffdb8
 
 
78da2d6
 
 
 
8cffdb8
78da2d6
 
8cffdb8
e55bd08
8cffdb8
 
 
e55bd08
0474700
e55bd08
e7e3b25
e55bd08
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import spaces
from threading import Thread
from typing import Iterator

# Markdown banner for the demo page: project title, a short summary of the
# MagpieLM alignment recipe, and links/contacts. It is handed to
# gr.ChatInterface as its `description`.
header = """
# 🐦‍⬛ MagpieLMs: Open LLMs with Fully Transparent Alignment Recipes

💬 We've aligned Llama-3.1-8B and a 4B version (distilled by NVIDIA) using purely synthetic data generated by our [Magpie](https://arxiv.org/abs/2406.08464) method. Our open-source post-training recipe includes: SFT and DPO data, all training configs + logs. This allows everyone to reproduce the alignment process for their own research. Note that our data does not contain any GPT-generated data, and has a much friendly license for both commercial and academic use.
🔗 Links: [**Magpie Collection**](https://huggingface.co/collections/Magpie-Align/magpielm-66e2221f31fa3bf05b10786a); [**Magpie Paper**](https://arxiv.org/abs/2406.08464) 📮 Contact: [Zhangchen Xu](https://zhangchenxu.com) and [Bill Yuchen Lin](https://yuchenlin.xyz).

---
"""

# Hugging Face Hub checkpoint served by this demo.
model_name = "Magpie-Align/MagpieLM-8B-Chat-v0.1"

# Upper bound on the prompt length in tokens; adjust to suit the deployment.
MAX_INPUT_TOKEN_LENGTH = 4096

# Device the model weights are moved onto.
device = "cuda"

# Tokenizer and model are created once at import time and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", ignore_mismatched_sizes=True
).to(device)

@spaces.GPU
def respond(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the assistant's reply to ``message``.

    The system prompt (if non-empty), the prior turns in ``chat_history`` and
    the new user message are rendered through the tokenizer's chat template,
    then generation runs on a background thread while this generator yields
    the text produced so far.

    Args:
        message: The new user message.
        chat_history: Previous ``(user, assistant)`` turns, oldest first.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Passed to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Repetition penalty passed to ``model.generate``.

    Yields:
        The accumulated reply text after each streamed chunk (each yield
        contains everything generated so far, not just the newest piece).
    """
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True makes the template end with the tokens that
    # open an assistant turn; without it the prompt stops after the user
    # message and the model is not cued to answer as the assistant.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the most recent tokens (the tail), dropping the oldest context.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True so only newly generated text is streamed back.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # generate() blocks until completion, so it runs on a worker thread while
    # this generator consumes the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Chat UI. The additional inputs are passed to `respond` positionally after
# (message, chat_history), so their order must mirror its signature:
# system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are Magpie, a helpful AI assistant. For simple queries, try to answer them directly; for complex questions, try to think step-by-step before providing an answer.", label="System message"),
        gr.Slider(minimum=128, maximum=2048, value=512, step=64, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.1,
            label="Top-p (nucleus sampling)",
        ),
        # Previously missing: without this slot the repetition-penalty slider
        # value was delivered to `respond` as `top_k`.
        gr.Slider(minimum=1, maximum=1000, value=50, step=1, label="Top-k"),
        gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Repetition Penalty"),
    ],
    description=header,  # Markdown banner shown in the interface
    title="MagpieLM-8B Chat (v0.1)",
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    # Enable request queuing, then start the app with a public share link.
    demo.queue().launch(share=True)