import os
import gradio as gr
from huggingface_hub import InferenceClient
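
# Gradio chat demo: picks the first preferred model that is currently deployed
# on the Hugging Face serverless Inference API and streams its completions.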

HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Candidate models in order of preference; the first one currently deployed on
# the serverless Inference API is used for generation.
model2api = [
    "tiiuae/falcon-180B-chat",
    "meta-llama/Llama-2-70b-chat-hf",
    "codellama/CodeLlama-34b-Instruct-hf",
    "victor/CodeLlama-34b-Instruct-hf",
    "timdettmers/guanaco-33b-merged",
]

# Strings that mark the end of a turn; generation is cut off when one appears.
STOP_SEQUENCES = ["User:", "###", "<|endoftext|>", "</s>"]

EXAMPLES = [
    ["Hey LLAMA! Any recommendations for my holidays in Abu Dhabi?"],
    ["What's the Everett interpretation of quantum mechanics?"],
    ["Give me a list of the top 10 dive sites you would recommend around the world."],
    ["Can you tell me more about deep-water soloing?"],
    ["Can you write a short tweet about the release of our latest AI model, LLAMA LLM?"],
]

def format_prompt(message, history, system_prompt, bot_name):
    """Flatten the chat history into a plain-text prompt of User/<bot_name> turns."""
    prompt = ""
    if system_prompt:
        prompt += f"System: {system_prompt}\n"
    for user_prompt, bot_response in history:
        prompt += f"User: {user_prompt}\n"
        prompt += f"{bot_name}: {bot_response}\n"
    prompt += f"User: {message}\n{bot_name}:"
    return prompt

# Global seed, incremented on every request so repeated generations differ.
seed = 42

def generate(
    prompt, history, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    # Clamp temperature away from zero to keep sampling numerically stable.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    global seed
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        seed=seed,
    )
    seed += 1

    # Pick the first preferred model that is currently deployed behind
    # text-generation-inference; fall back to the first candidate otherwise.
    client = InferenceClient()
    deployed = client.list_deployed_models("text-generation-inference")["text-generation"]
    model = next((m for m in model2api if m in deployed), model2api[0])
    client = InferenceClient(model, token=HF_TOKEN)
    print(f"Chosen model: {model}")

    # The Falcon chat model is prompted with "Falcon" as its turn label;
    # other models get a generic "Assistant" label.
    bot_name = "Falcon" if model == model2api[0] else "Assistant"
    
    formatted_prompt = format_prompt(prompt, history, system_prompt, bot_name)
    
    try:
        stream = client.text_generation(
            formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
        )
        output = ""

        for response in stream:
            output += response.token.text

            # Strip any stop sequence the model emitted verbatim before yielding.
            for stop_str in STOP_SEQUENCES:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
            yield output
    except Exception as e:
        raise gr.Error(f"Client error while generating: {e}") from e
    return output

# Extra controls exposed in the chat interface's additional-inputs panel.
additional_inputs = [
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=3000,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.01,
        maximum=0.99,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

# Build the UI: ChatInterface wires generate() to the chat box and renders the
# sliders defined above as additional inputs.
with gr.Blocks() as demo:
    
    gr.ChatInterface(
        generate, 
        examples=EXAMPLES,
        additional_inputs=additional_inputs,
    ) 

#demo.queue(concurrency_count=100, api_open=False).launch(show_api=False)
demo.queue(concurrency_count=100).launch()