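"""Streaming Gradio chat demo for CohereForAI/aya-expanse-8b on a Hugging Face GPU Space."""
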
from threading import Thread

import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

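# Load the tokenizer and model once at startup; .to("cuda") assumes a GPU is available.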
model_id = "CohereForAI/aya-expanse-8b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
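
# Assumed cap on prompt tokens; adjust to the model's actual context window.
MAX_INPUT_TOKEN_LENGTH = 4096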


@spaces.GPU
def generate(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
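    # Rebuild the chat history as role/content messages for the chat template.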
    conversation = [{"role": "system", "content": system_message}]
    for user, assistant in history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt")

    # Trim the oldest tokens if the conversation exceeds the prompt-length cap.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

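    # Stream generated tokens so the UI can update incrementally.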
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
    )
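    # Run generation in a background thread; the streamer feeds tokens back to this one.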
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

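    # Accumulate streamed chunks and yield the growing reply so Gradio updates live.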
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


"""
For information on how to customize the ChatInterface, see the Gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    generate,
    cache_examples=False,
    additional_inputs=[
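        # Default system prompt is Dutch for "You are a friendly, helpful chatbot".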
        gr.Textbox(value="Je bent een vriendelijke, behulpzame chatbot", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
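    # Dutch example prompts, kept in Dutch since the demo targets Dutch usage. Translations:
    # 1. "Quick question: which word doesn't belong in this list: 'car, airplane, goat, bus'?"
    # 2. "Write a news article for De Speld about the Nederlands Forensisch Instituut deploying a herd of goats"
    # 3. "What are 3 fun things to do on a weekend trip to Friesland?"
    # 4. "Who did clown Bassie perform with?"
    # 5. "Can you cycle to the moon?"
    # 6. "What is the importance of open-source language models?"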
    examples=[
        ["""Vraagje: welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geit, bus"?"""],
        ["Schrijf een nieuwsbericht voor De Speld over de inzet van een kudde geiten door het Nederlands Forensisch Instituut"],
        ["Wat zijn 3 leuke dingen om te doen als ik een weekendje naar Friesland ga?"],
        ["Met wie trad clown Bassie op?"],
        ["Kan je naar de maan fietsen?"],
        ["Wat is het belang van open source taalmodellen?"],
    ],
    title="Aya Expanse 8B demo",
)


if __name__ == "__main__":
    demo.launch()