File size: 2,389 Bytes
8b04d55
dfce08c
 
8b04d55
dfce08c
e233fc5
8b04d55
dfce08c
e233fc5
dfce08c
e233fc5
dfce08c
 
 
 
 
8b04d55
 
 
 
dfce08c
 
 
9508422
8b04d55
 
 
 
 
 
 
 
 
 
dfce08c
 
8b04d55
dfce08c
 
 
 
 
 
 
 
 
 
9508422
dfce08c
 
 
 
8b04d55
dfce08c
 
8b04d55
 
 
 
 
 
 
dfce08c
9508422
 
8b04d55
 
 
 
 
 
 
9508422
8b04d55
 
 
 
 
e233fc5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# --- Model / tokenizer setup (runs once, at import time) --------------------
# Hugging Face Hub repository id of the checkpoint served by this app.
model_name = "yuchenlin/Rex-v0.1-1.5B"
# Device that receives the model weights; respond() also moves its tokenized
# inputs here.
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    # NOTE(review): `rex_size` is forwarded to the tokenizer unchanged; its
    # meaning is not visible in this file — confirm against the model repo.
    rex_size=3,
)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.to(device)

def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Tuple-style entries ``(user, assistant)`` (Gradio's default history format)
    yield a ``user`` and/or ``assistant`` message; a falsy side is skipped.
    Dict-style entries (Gradio ``type="messages"`` history) are passed through
    using their ``role`` and ``content`` keys, again skipping falsy content.
    A ``None`` history is treated as empty.
    """
    converted = []
    for turn in history or ():
        if isinstance(turn, dict):
            if turn.get("content"):
                converted.append({"role": turn["role"], "content": turn["content"]})
            continue
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            converted.append({"role": "user", "content": user_text})
        if assistant_text:
            converted.append({"role": "assistant", "content": assistant_text})
    return converted


@spaces.GPU(enable_queue=True)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens=512,
    temperature=0.5,
    top_p=1.0,
    repetition_penalty=1.1,
):
    """Generate one assistant reply for the Gradio chat UI.

    Builds a transcript (system prompt, prior turns, new user message), renders
    it with the tokenizer's chat template, samples a continuation from the
    module-level ``model`` and returns only the newly generated text.

    Args:
        message: The latest user message.
        history: Prior turns, either as ``[(user, assistant), ...]`` pairs or
            as ``[{"role": ..., "content": ...}, ...]`` dicts.
        system_message: Content of the leading ``system`` message.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.

    Returns:
        The decoded reply with special tokens removed.
    """
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        # Forward input_ids *and* attention_mask instead of input_ids alone.
        **model_inputs,
        max_new_tokens=max_tokens,
        # Enable sampling explicitly: under greedy decoding the temperature and
        # top_p values from the UI sliders would have no effect, and this file
        # cannot rely on the checkpoint's generation config enabling sampling.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    # generate() returns prompt + continuation; keep only the new tokens of the
    # single sequence in the batch.
    prompt_len = model_inputs.input_ids.shape[1]
    return tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    respond,
    # Each extra control's value is passed to respond() after
    # (message, history), in list order: system_message, max_tokens,
    # temperature, top_p, repetition_penalty.
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant and your name is RexLM.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.5, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0.5, maximum=1.5, value=1.1, step=0.1, label="Repetition Penalty"),
    ],
)


if __name__ == "__main__":
    # share=True asks Gradio for a public share link when run as a script.
    demo.launch(share=True)