File size: 4,161 Bytes
988d041
 
 
 
 
 
 
 
 
 
3fe043e
988d041
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import json
import subprocess
from threading import Thread

import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Install flash-attn when the app starts, with FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE
# set for the pip process (presumably so the package installs without compiling
# its CUDA extension -- confirm against flash-attn's setup script).
# The variable is merged into a copy of the current environment: a dict passed as
# `env` REPLACES the child's environment, so passing only this one key would drop
# PATH, HOME, proxy settings, etc. for the pip process.
# The result is intentionally not checked: the install stays best-effort.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

def _require_env(name):
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: If the variable is not set, naming the missing variable
            instead of failing later on a ``None`` value.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"Environment variable {name!r} must be set")
    return value


# Required settings: the app cannot start without them.
MODEL_ID = _require_env("MODEL_ID")
# Display name: the last path component of the model id.
MODEL_NAME = MODEL_ID.split("/")[-1]
# Maximum number of prompt tokens kept before generation.
CONTEXT_LENGTH = int(_require_env("CONTEXT_LENGTH"))

# Optional settings: left as None when unset.
CHAT_TEMPLATE = os.environ.get("CHAT_TEMPLATE")
COLOR = os.environ.get("COLOR")
EMOJI = os.environ.get("EMOJI")
DESCRIPTION = os.environ.get("DESCRIPTION")


def _build_prompt(message, history, system_prompt):
    """Render the conversation as a single prompt for the configured chat template.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_text, assistant_text)`` pairs for earlier turns.
        system_prompt: Instruction text placed at the start of the prompt.

    Returns:
        A ``(prompt, stop_tokens)`` tuple: the formatted prompt string and the
        list of strings that end streaming when produced by the model.

    Raises:
        ValueError: If ``CHAT_TEMPLATE`` is neither ``"ChatML"`` nor
            ``"Mistral Instruct"``.
    """
    if CHAT_TEMPLATE == "ChatML":
        stop_tokens = ["<|endoftext|>", "<|im_end|>"]
        parts = ['<|im_start|>system\n' + system_prompt + '\n<|im_end|>\n']
        for human, assistant in history:
            # Every past turn, including the assistant's, is closed with
            # <|im_end|>; otherwise the next <|im_start|> marker is glued
            # directly onto the previous assistant text.
            parts.append(
                '<|im_start|>user\n' + human + '\n<|im_end|>\n'
                '<|im_start|>assistant\n' + assistant + '\n<|im_end|>\n'
            )
        # The leading newline is kept so the prompt for an empty history is
        # unchanged.
        parts.append('\n<|im_start|>user\n' + message + '\n<|im_end|>\n<|im_start|>assistant\n')
        return "".join(parts), stop_tokens

    if CHAT_TEMPLATE == "Mistral Instruct":
        stop_tokens = ["</s>", "[INST]", "[INST] ", "<s>", "[/INST]", "[/INST] "]
        parts = ['<s>[INST] ' + system_prompt]
        for human, assistant in history:
            # A space separates each user message from what precedes it (the
            # system prompt or an "[INST]" marker), matching how the final
            # message is appended below.
            parts.append(' ' + human + ' [/INST] ' + assistant + '</s>[INST]')
        parts.append(' ' + message + ' [/INST]')
        return "".join(parts), stop_tokens

    raise ValueError("Incorrect chat template, select 'ChatML' or 'Mistral Instruct'")


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    """Generate a reply to *message* and stream it back incrementally.

    Args:
        message: The new user message.
        history: Earlier ``(user_text, assistant_text)`` pairs.
        system_prompt: Instruction text placed at the start of the prompt.
        temperature: Sampling temperature; a value of 0 or less selects
            greedy decoding instead of sampling.
        max_new_tokens: Upper bound on generated tokens (coerced to ``int``).
        top_k: Top-k sampling cutoff (coerced to ``int``).
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Nucleus sampling threshold.

    Yields:
        The response text accumulated so far, once per streamed chunk.

    Raises:
        ValueError: If ``CHAT_TEMPLATE`` is not a supported template.
    """
    instruction, stop_tokens = _build_prompt(message, history, system_prompt)
    print(instruction)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer([instruction], return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    if input_ids.shape[1] > CONTEXT_LENGTH:
        # Keep only the most recent CONTEXT_LENGTH tokens. The mask is cut the
        # same way so both tensors keep identical shapes.
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    # The UI slider allows temperature 0, which is not a valid sampling
    # temperature; treat it as a request for greedy decoding.
    use_sampling = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=use_sampling,
        # Slider values may arrive as floats; token counts must be integers.
        max_new_tokens=int(max_new_tokens),
        repetition_penalty=repetition_penalty,
    )
    if use_sampling:
        generate_kwargs.update(
            temperature=temperature,
            top_k=int(top_k),
            top_p=top_p,
        )

    # Generation runs in a background thread; the streamer hands text chunks
    # back to this generator as they are produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs)


# Load the tokenizer and the 8-bit quantized model once at import time.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Quantization settings handed to from_pretrained below.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="flash_attention_2",
    quantization_config=quantization_config,
    device_map="auto",
)

# Create the Gradio chat interface and start serving it.
# The additional inputs are passed to `predict` after (message, history), in the
# order listed: system prompt, temperature, max new tokens, top-k,
# repetition penalty, top-p.
gr.ChatInterface(
    predict,
    # EMOJI is optional; fall back to an empty string so an unset variable does
    # not raise a TypeError on str + None at startup.
    title=(EMOJI or "") + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ["Calculate the volume of a sphere with a radius of 5 centimeters."],
        ["Describe a dystopian world where water is more valuable than gold."],
        ["Tips for first-time travelers to Japan."],
        ["Write a Python function to check if a number is prime."],
        ["Explain the process of photosynthesis to a 10-year-old."],
        ["Create a story about a lost civilization found beneath the Pacific Ocean."],
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(value="Perform the task to the best of your ability.", label="System prompt"),
        gr.Slider(minimum=0, maximum=1, value=0.8, label="Temperature"),
        # Token counts and top-k are integers, so these sliders use an explicit
        # step of 1. NOTE(review): without `step`, Gradio derives one from the
        # slider range, which may be fractional -- confirm against the Slider docs.
        gr.Slider(minimum=128, maximum=4096, value=1024, step=1, label="Max new tokens"),
        gr.Slider(minimum=1, maximum=80, value=40, step=1, label="Top K sampling"),
        gr.Slider(minimum=0, maximum=2, value=1.1, label="Repetition penalty"),
        gr.Slider(minimum=0, maximum=1, value=0.95, label="Top P sampling"),
    ],
    theme=gr.themes.Soft(primary_hue=COLOR),
).queue().launch()