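"""Streaming chat demo for a Hugging Face Space.

Wraps a 4-bit-quantized causal language model in a Gradio ChatInterface.
All configuration (model ID, chat template, UI theme) is read from
environment variables, so the same app.py can be reused across Spaces.
"""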
import os
import subprocess
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Install FlashAttention at startup; keep the current environment and only add
# the flag that skips the slow CUDA build (a prebuilt wheel is used instead).
subprocess.run('pip install flash-attn --no-build-isolation',
               env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

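# Configuration comes from the Space's environment variables; MODEL_ID and
# CONTEXT_LENGTH must be set, or the app fails at startup.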
MODEL_ID = os.environ.get("MODEL_ID")
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = int(os.environ.get("CONTEXT_LENGTH"))
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")


@spaces.GPU()  # on ZeroGPU Spaces, this attaches a GPU for the duration of each call
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
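    """Format the history with the configured chat template and stream the reply."""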
    # Format history with a given chat template
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        instruction = '<|im_start|>system\n' + system_prompt + '<|im_end|>\n'
        for human, assistant in history:
            # Close every turn with <|im_end|>, including past assistant replies.
            instruction += ('<|im_start|>user\n' + human + '<|im_end|>\n'
                            '<|im_start|>assistant\n' + assistant + '<|im_end|>\n')
        instruction += '<|im_start|>user\n' + message + '<|im_end|>\n<|im_start|>assistant\n'
    elif CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        # Mistral has no system role, so the system prompt is folded into the first user turn.
        instruction = '<s>[INST] ' + system_prompt + '\n'
        for human, assistant in history:
            instruction += human + ' [/INST] ' + assistant + '</s>[INST] '
        instruction += message + ' [/INST]'
    else:
        raise ValueError(f"Unsupported CHAT_TEMPLATE {CHAT_TEMPLATE!r}; use 'ChatML' or 'Mistral Instruct'")
    print(instruction)  # log the fully formatted prompt for debugging
    
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    # Keep only the most recent CONTEXT_LENGTH tokens, truncating the attention
    # mask in step with the ids so generate() receives matching shapes.
    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )
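    # Run generation on a background thread; the streamer yields text as it is produced.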
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)


# Load model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the weights in 4-bit precision (bitsandbytes) to cut GPU memory use.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Many causal LMs ship without a pad token; reuse EOS so padding=True works.
    tokenizer.pad_token = tokenizer.eos_token
# Causal LM, not seq2seq: the ChatML/Mistral templates above are decoder-only formats.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    trust_remote_code=True
)

# Create Gradio interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["هل يمكنك حل المعادلة 2x + 3 = 11 للمتغير x؟"],  # Can you solve the equation 2x + 3 = 11 for x?
        ["اكتب قصيدة ملحمية عن مدينة القدس الشريفة."],  # Write an epic poem about the noble city of Jerusalem.
        ["من كان أول شخص يمشي على سطح القمر؟"],  # Who was the first person to walk on the Moon?
        ["أوصِ ببعض كتب الخيال العلمي الشهيرة."],  # Recommend some famous science fiction books.
        ["هل يمكنك كتابة قصة قصيرة عن محقق يسافر عبر الزمن؟"],  # Can you write a short story about a time-traveling detective?
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox("Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(0, 1, 0.8, label="Temperature"),
        gr.Slider(128, 4096, 1024, label="Max new tokens"),
        gr.Slider(1, 80, 40, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()
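
# Example local launch (hypothetical values; a deployed Space sets these in its settings):
#   MODEL_ID="teknium/OpenHermes-2.5-Mistral-7B" CHAT_TEMPLATE="ChatML" \
#   CONTEXT_LENGTH=4096 COLOR="blue" EMOJI="🤖" DESCRIPTION="Chat demo" python app.py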