ezcz committed on
Commit
0955aae
1 Parent(s): 333aeab

Update space

Browse files
Files changed (1) hide show
  1. app.py +148 -65
app.py CHANGED
@@ -1,71 +1,154 @@
1
  import torch
2
- import gradio as gr
3
  from transformers import pipeline
4
- import logging
5
- import warnings
6
- from threading import Lock
7
-
8
# Suppress non-critical warnings
warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")

# Set up module-level logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model / generation configuration
MODEL_ID = "ezcz/bright-llama-3b-chat"
MAX_NEW_TOKENS = 256
TEMPERATURE = 0.1
TOP_P = 0.9
TOP_K = 60
REPETITION_PENALTY = 1.0

# Pick GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Load the text-generation pipeline once at startup
logger.info("Loading model pipeline...")
pipe = pipeline(
    "text-generation",
    model=MODEL_ID,
    # bfloat16 is only useful on GPU; CPU stays in float32
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)


def chat_interface(user_input, history=None):
    """Generate one assistant reply for *user_input*.

    Returns a ("", history) pair so Gradio clears the input textbox and
    refreshes the Chatbot component with the appended (user, reply) tuple.
    On failure the error string is placed in the textbox and the history
    is returned unchanged.
    """
    if history is None:
        history = []

    # Only the latest user turn is sent to the model; prior history is
    # display-only in this version.
    messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]

    try:
        outputs = pipe(
            messages,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            # BUG FIX: with chat-style (list-of-messages) input,
            # "generated_text" contains the whole conversation list unless
            # return_full_text=False; we need only the new assistant string.
            return_full_text=False,
        )
        response = outputs[0]["generated_text"]
        history.append((user_input, response))
        return "", history
    except Exception as e:
        logger.error(f"Error during response generation: {e}")
        return "Error generating response.", history


# Build the Gradio interface
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Type your message...")
    clear_button = gr.Button("Clear Chat")
    submit_button = gr.Button("Send")

    submit_button.click(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
    user_input.submit(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
    clear_button.click(lambda: ([], ""), inputs=[], outputs=[chatbot, user_input])

# Launch the UI
demo.queue().launch(debug=True, share=True)
 
 
 
 
 
 
 
1
  import torch
 
2
  from transformers import pipeline
3
+ import gradio as gr
4
+ from datetime import datetime
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
class BrightLlamaChatbot:
    """Gradio front-end for the ezcz/bright-llama-3b-chat model.

    The running conversation is stored as a plain-text transcript of
    "User: ..." / "Assistant: ..." lines inside a read-only Textbox and
    re-parsed into a structured chat-message list on every turn.
    """

    def __init__(self):
        self.model_id = "ezcz/bright-llama-3b-chat"
        # device_map="auto" places the model on GPU when one is present.
        self.pipe = pipeline(
            "text-generation",
            model=self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.system_prompt = """You are a helpful AI assistant focused on coding and reasoning tasks.
You provide clear, accurate responses while maintaining a friendly tone."""

    def format_message(self, role, content):
        """Return *content* wrapped in a timestamped HTML bubble for *role*.

        NOTE(review): not currently wired into the UI; kept because the
        custom CSS styles the .message/.time classes it emits.
        """
        timestamp = datetime.now().strftime("%H:%M:%S")
        return f"<div class='message {role}'><span class='time'>[{timestamp}]</span>{content}</div>"

    def generate_response(self, user_input, chat_history, temperature=0.7, max_new_tokens=512):
        """Run one chat turn and return the updated transcript string.

        Args:
            user_input: new user message; blank/whitespace input is a no-op.
            chat_history: plain-text transcript accumulated so far.
            temperature, max_new_tokens: sampling settings, wired to the
                Settings sliders; defaults match the previous hard-coded
                values so existing two-argument callers behave identically.
        """
        # Ignore empty or whitespace-only submissions.
        if not user_input or not user_input.strip():
            return chat_history

        # Rebuild the structured message list from the plain-text transcript.
        messages = [{"role": "system", "content": self.system_prompt}]
        for msg in chat_history.split("\n"):
            if not msg.strip():
                continue
            if "User:" in msg:
                messages.append({"role": "user", "content": msg.replace("User:", "").strip()})
            elif "Assistant:" in msg:
                messages.append({"role": "assistant", "content": msg.replace("Assistant:", "").strip()})
        messages.append({"role": "user", "content": user_input})

        try:
            response = self.pipe(
                messages,
                max_new_tokens=int(max_new_tokens),
                return_full_text=False,  # only the newly generated assistant text
                temperature=float(temperature),
                top_p=0.9,
            )
            assistant_response = response[0]["generated_text"]
        except Exception as e:
            # Surface generation failures in the transcript instead of
            # crashing the Gradio event handler (original had no handling).
            assistant_response = f"[error] generation failed: {e}"

        # BUG FIX: the original unconditionally prefixed "\n", leaving a
        # blank first line in the transcript on the very first exchange.
        turn = f"User: {user_input}\nAssistant: {assistant_response}"
        return f"{chat_history}\n{turn}" if chat_history else turn

    def create_interface(self):
        """Build and return the Gradio Blocks UI."""
        with gr.Blocks(css=self.get_custom_css()) as interface:
            gr.HTML("<h1>🦙 Bright Llama Chatbot</h1>")
            gr.HTML("<p>An AI assistant specialized in coding and reasoning tasks.</p>")

            with gr.Row():
                with gr.Column(scale=4):
                    chatbot = gr.Textbox(
                        show_label=False,
                        placeholder="Conversation history...",
                        lines=15,
                        max_lines=15,
                        interactive=False,
                    )

                    with gr.Row():
                        with gr.Column(scale=8):
                            user_input = gr.Textbox(
                                show_label=False,
                                placeholder="Type your message here...",
                                lines=2,
                            )
                        with gr.Column(scale=1):
                            submit_btn = gr.Button("Send", variant="primary")
                            clear_btn = gr.Button("Clear")

                with gr.Column(scale=1):
                    gr.HTML("<h3>Features</h3>")
                    gr.HTML("""
                    <ul>
                        <li>💬 Natural conversation</li>
                        <li>🎨 Code syntax highlighting</li>
                        <li>⚡ Quick responses</li>
                    </ul>
                    """)

            with gr.Accordion("Settings", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
                max_length = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max Length",
                )

            # BUG FIX: the sliders were defined but never passed to
            # generate_response, so the Settings panel had no effect.
            submit_btn.click(
                fn=self.generate_response,
                inputs=[user_input, chatbot, temperature, max_length],
                outputs=chatbot,
            ).then(
                fn=lambda: "",
                outputs=user_input,
            )

            # Pressing Enter in the input box behaves like the Send button.
            user_input.submit(
                fn=self.generate_response,
                inputs=[user_input, chatbot, temperature, max_length],
                outputs=chatbot,
            ).then(
                fn=lambda: "",
                outputs=user_input,
            )

            clear_btn.click(
                fn=lambda: ("", ""),
                outputs=[user_input, chatbot],
            )

        return interface

    def get_custom_css(self):
        """Return the CSS for chat bubbles, timestamps, and code blocks."""
        return """
        .message {
            padding: 10px;
            margin: 5px;
            border-radius: 8px;
        }
        .user {
            background-color: #f0f0f0;
            margin-left: 20px;
        }
        .assistant {
            background-color: #e3f2fd;
            margin-right: 20px;
        }
        .time {
            color: #666;
            font-size: 0.8em;
            margin-right: 10px;
        }
        pre {
            background-color: #2b2b2b;
            color: #ffffff;
            padding: 10px;
            border-radius: 5px;
            overflow-x: auto;
        }
        code {
            font-family: 'Courier New', monospace;
        }
        """
146
 
147
# Script entry point: build the chatbot UI and serve it.
if __name__ == "__main__":
    bot = BrightLlamaChatbot()
    demo = bot.create_interface()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (Spaces/containers)
        server_port=7860,
        share=True,
    )