import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
import spaces
from threading import Thread
from typing import Iterator

model_id = "mistralai/Mistral-Nemo-Instruct-2407"

MAX_INPUT_TOKEN_LENGTH = 4096

# Load tokenizer and model (load_in_8bit requires the bitsandbytes package)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    load_in_8bit=True
)

# spaces.GPU requests GPU time for each call when running on Hugging Face ZeroGPU Spaces.
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9
) -> Iterator[str]:
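    """Stream a chat completion for `message`, conditioned on `chat_history`."""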
    conversation = []
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})

    # Tokenize the conversation with the model's chat template and keep only the
    # most recent MAX_INPUT_TOKEN_LENGTH tokens if the prompt is too long.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Run model.generate in a background thread and stream tokens as they are produced.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text so the chat UI updates incrementally.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Set up the Gradio chat interface; the additional input sliders map, in order,
# to generate()'s max_new_tokens, temperature, and top_p arguments.
iface = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Enter your message here...", container=False, scale=7),
    title="Chat with Mistral Nemo",
    description="This is a chat interface for the Mistral Nemo model. Ask questions and get answers!",
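    # Custom button labels (these ChatInterface options are available in Gradio 4.x).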
    retry_btn="Retry",
    undo_btn="Undo Last",
    clear_btn="Clear",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Maximum number of new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)

# Launch the interface
iface.launch()