ybelkada commited on
Commit
ec505fb
1 Parent(s): 977b2d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -4
app.py CHANGED
@@ -1,7 +1,158 @@
 
 
 
 
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ import subprocess
2
+ import sys
3
+ import spaces
4
+
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
  import gradio as gr
8
+ from threading import Thread
9
+
10
+ MODEL = "tiiuae/falcon-mamba-7b-instruct"
11
+
12
+ TITLE = "<h1><center>FalconMamba-7b playground</center></h1>"
13
+ SUB_TITLE = """<center>FalconMamba is a new model released by Technology Innovation Institute (TII) in Abu Dhabi. The model is open source and available within the Hugging Face ecosystem for anyone to use it for their research or application purpose. Refer to <a href="https://hf.co/blog/falconmamba">the HF release blogpost</a> or <a href="https://www.tii.ae/news/uaes-technology-innovation-institute-revolutionizes-ai-language-models-new-architecture">the official announcement</a> for more details. This interface has been created for quick validation purposes, do not use it for production.</center>"""
14
+
15
+ CSS = """
16
+ .duplicate-button {
17
+ margin: auto !important;
18
+ color: white !important;
19
+ background: black !important;
20
+ border-radius: 100vh !important;
21
+ }
22
+ h3 {
23
+ text-align: center;
24
+ }
25
+ """
26
+
27
+ END_MESSAGE = """
28
+ \n
29
+ **The conversation has reached to its end, please press "Clear" to restart a new conversation**
30
+ """
31
+
32
+ device = "cuda" # for GPU usage or "cpu" for CPU usage
33
+
34
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ MODEL,
37
+ torch_dtype=torch.bfloat16,
38
+ ).to(device)
39
+
40
+ if device == "cuda":
41
+ model = torch.compile(model)
42
+
43
+ @spaces.GPU
44
+ def stream_chat(
45
+ message: str,
46
+ history: list,
47
+ temperature: float = 0.3,
48
+ max_new_tokens: int = 1024,
49
+ top_p: float = 1.0,
50
+ top_k: int = 20,
51
+ penalty: float = 1.2,
52
+ ):
53
+ print(f'message: {message}')
54
+ print(f'history: {history}')
55
+
56
+ conversation = []
57
+ for prompt, answer in history:
58
+ conversation.extend([
59
+ {"role": "user", "content": prompt},
60
+ {"role": "assistant", "content": answer},
61
+ ])
62
+
63
+
64
+ conversation.append({"role": "user", "content": message})
65
+
66
+
67
+ input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
68
+ input_text += "<|im_start|>assistant\n"
69
+
70
+ inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
71
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
72
+
73
+ generate_kwargs = dict(
74
+ input_ids=inputs,
75
+ max_new_tokens = max_new_tokens,
76
+ do_sample = False if temperature == 0 else True,
77
+ top_p = top_p,
78
+ top_k = top_k,
79
+ temperature = temperature,
80
+ streamer=streamer,
81
+ pad_token_id = 10,
82
+ )
83
+
84
+ with torch.no_grad():
85
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
86
+ thread.start()
87
+
88
+ buffer = ""
89
+ for new_text in streamer:
90
+ buffer += new_text
91
+ yield buffer
92
+
93
+
94
+ print(f'response: {buffer}')
95
+
96
+ chatbot = gr.Chatbot(height=600)
97
+
98
+ with gr.Blocks(css=CSS, theme="soft") as demo:
99
+ gr.HTML(TITLE)
100
+ gr.HTML(SUB_TITLE)
101
+ gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
102
+ gr.ChatInterface(
103
+ fn=stream_chat,
104
+ chatbot=chatbot,
105
+ fill_height=True,
106
+ additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
107
+ additional_inputs=[
108
+ gr.Slider(
109
+ minimum=0,
110
+ maximum=1,
111
+ step=0.1,
112
+ value=0.3,
113
+ label="Temperature",
114
+ render=False,
115
+ ),
116
+ gr.Slider(
117
+ minimum=128,
118
+ maximum=8192,
119
+ step=1,
120
+ value=1024,
121
+ label="Max new tokens",
122
+ render=False,
123
+ ),
124
+ gr.Slider(
125
+ minimum=0.0,
126
+ maximum=1.0,
127
+ step=0.1,
128
+ value=1.0,
129
+ label="top_p",
130
+ render=False,
131
+ ),
132
+ gr.Slider(
133
+ minimum=1,
134
+ maximum=20,
135
+ step=1,
136
+ value=20,
137
+ label="top_k",
138
+ render=False,
139
+ ),
140
+ gr.Slider(
141
+ minimum=0.0,
142
+ maximum=2.0,
143
+ step=0.1,
144
+ value=1.2,
145
+ label="Repetition penalty",
146
+ render=False,
147
+ ),
148
+ ],
149
+ examples=[
150
+ ["Hello there, can you suggest few places to visit in UAE?"],
151
+ ["What UAE is known for?"],
152
+ ],
153
+ cache_examples=False,
154
+ )
155
 
 
 
156
 
157
+ if __name__ == "__main__":
158
+ demo.launch()