Master88 committed on
Commit
871f7db
1 Parent(s): f909466

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -127
app.py CHANGED
@@ -1,129 +1,3 @@
1
  import gradio as gr
2
- import os
3
- import spaces
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
- from threading import Thread
7
 
8
-
9
- DESCRIPTION = '''
10
- <div>
11
- <h1 style="text-align: center;">OpenChat 3.6</h1>
12
- </div>
13
- '''
14
-
15
-
16
- PLACEHOLDER = """
17
- <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
18
- <img src="https://raw.githubusercontent.com/imoneoi/openchat/master/assets/logo_new.png" style="width: 80%; max-width: 550px; height: auto; opacity: 0.55; ">
19
- <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">OpenChat 3.6</h1>
20
- <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
21
- </div>
22
- """
23
-
24
-
25
- css = """
26
- h1 {
27
- text-align: center;
28
- display: block;
29
- }
30
-
31
- #duplicate-button {
32
- margin: auto;
33
- color: white;
34
- background: #1565c0;
35
- border-radius: 100vh;
36
- }
37
-
38
- footer {
39
- visibility: hidden
40
- }
41
- """
42
-
43
- # Load the tokenizer and model
44
- tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.6-8b-20240522")
45
- model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.6-8b-20240522", device_map="auto") # to("cuda:0")
46
- terminators = [
47
- tokenizer.eos_token_id,
48
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
49
- ]
50
-
51
- @spaces.GPU(duration=120)
52
- def chat_openchat_36(message: str,
53
- history: list,
54
- temperature: float,
55
- max_new_tokens: int
56
- ) -> str:
57
- """
58
- Generate a streaming response using the openchat-3.6 model.
59
- Args:
60
- message (str): The input message.
61
- history (list): The conversation history used by ChatInterface.
62
- temperature (float): The temperature for generating the response.
63
- max_new_tokens (int): The maximum number of new tokens to generate.
64
- Returns:
65
- str: The generated response.
66
- """
67
- conversation = []
68
- for user, assistant in history:
69
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
70
- conversation.append({"role": "user", "content": message})
71
-
72
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
73
-
74
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
75
-
76
- generate_kwargs = dict(
77
- input_ids= input_ids,
78
- streamer=streamer,
79
- max_new_tokens=max_new_tokens,
80
- do_sample=True,
81
- temperature=temperature,
82
- eos_token_id=terminators,
83
- )
84
- # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
85
- if temperature == 0:
86
- generate_kwargs['do_sample'] = False
87
-
88
- t = Thread(target=model.generate, kwargs=generate_kwargs)
89
- t.start()
90
-
91
- outputs = []
92
- for text in streamer:
93
- outputs.append(text)
94
- #print(outputs)
95
- yield "".join(outputs)
96
-
97
-
98
- # Gradio block
99
- chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, show_label=False, layout="panel", avatar_images=(None, "bot.png"), likeable=True, show_copy_button=True)
100
-
101
- with gr.Blocks(fill_height=True, css=css, theme="theme-repo/STONE_Theme") as demo:
102
-
103
- gr.Markdown(DESCRIPTION)
104
- gr.ChatInterface(
105
- fn=chat_openchat_36,
106
- chatbot=chatbot,
107
- fill_height=True,
108
- additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
109
- additional_inputs=[
110
- gr.Slider(minimum=0,
111
- maximum=1,
112
- step=0.1,
113
- value=0.95,
114
- label="Temperature",
115
- render=False),
116
- gr.Slider(minimum=128,
117
- maximum=4096,
118
- step=1,
119
- value=512,
120
- label="Max new tokens",
121
- render=False ),
122
- ],
123
- cache_examples=False,
124
- )
125
-
126
-
127
- if __name__ == "__main__":
128
- demo.launch()
129
-
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
+ gr.load("models/openchat/openchat-3.6-8b-20240522").launch()