ezcz committed on
Commit 333aeab
1 Parent(s): 790457e

Update space

Files changed (1)
  1. app.py +31 -30
app.py CHANGED
@@ -1,70 +1,71 @@
 import torch
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-from peft import PeftModel
-import os
+from transformers import pipeline
 import logging
+import warnings
+from threading import Lock
+
+# Suppress non-critical warnings
+warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Model configuration
-BASE_MODEL_ID = "unsloth/Llama-3.2-3B-instruct"  # The base model you fine-tuned
-ADAPTER_MODEL_ID = "ezcz/bright-llama-3b-chat"
+MODEL_ID = "ezcz/bright-llama-3b-chat"
+MAX_NEW_TOKENS = 256
+TEMPERATURE = 0.1
+TOP_P = 0.9
+TOP_K = 60
+REPETITION_PENALTY = 1.0
 
 # Check for GPU availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 logger.info(f"Using device: {device}")
 
-# Load the base model and apply the adapter
-logger.info("Loading base model...")
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
-base_model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL_ID,
-    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto"
-)
-
-# Load the adapter model on top of the base model
-logger.info("Loading adapter weights...")
-model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
-
-# Create the pipeline with the combined model
+# Load the pipeline
+logger.info("Loading model pipeline...")
 pipe = pipeline(
     "text-generation",
-    model=model,
-    tokenizer=tokenizer,
+    model=MODEL_ID,
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto"
+    device_map="auto",
 )
 
+# Define the chat interface
 def chat_interface(user_input, history=None):
     if history is None:
         history = []
-
+
     messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
 
     try:
         outputs = pipe(
             messages,
-            max_new_tokens=256,
-            temperature=0.1,
-            top_p=0.9,
-            top_k=60,
-            repetition_penalty=1.0
+            max_new_tokens=MAX_NEW_TOKENS,
+            temperature=TEMPERATURE,
+            top_p=TOP_P,
+            top_k=TOP_K,
+            repetition_penalty=REPETITION_PENALTY,
         )
         response = outputs[0]["generated_text"]
         history.append((user_input, response))
         return "", history
     except Exception as e:
-        logger.error(f"Error generating response: {e}")
+        logger.error(f"Error during response generation: {e}")
         return "Error generating response.", history
 
+# Define the Gradio interface
 with gr.Blocks() as demo:
     chatbot = gr.Chatbot()
     user_input = gr.Textbox(placeholder="Type your message...")
+    clear_button = gr.Button("Clear Chat")
     submit_button = gr.Button("Send")
+
     submit_button.click(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
+    user_input.submit(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
+    clear_button.click(lambda: ([], ""), inputs=[], outputs=[chatbot, user_input])
 
-demo.launch(debug=True, share=True)
+# Launch the UI
+demo.queue().launch(debug=True, share=True)
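
Note on the change: the previous app.py stacked the ezcz/bright-llama-3b-chat PEFT adapter onto the unsloth/Llama-3.2-3B-instruct base model at startup, while the new version passes the repo ID straight to transformers' pipeline(), which only works if the repository hosts standalone (merged) weights. A minimal sketch of how such a merge is typically produced with peft; the output directory name below is illustrative, not taken from this repo:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the base model and apply the adapter on top of it.
base = AutoModelForCausalLM.from_pretrained(
    "unsloth/Llama-3.2-3B-instruct", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-3B-instruct")
model = PeftModel.from_pretrained(base, "ezcz/bright-llama-3b-chat")

# Fold the adapter weights into the base model so it loads without peft.
merged = model.merge_and_unload()

# Save (or push_to_hub) the standalone model; "bright-llama-3b-merged"
# is a hypothetical local directory name.
merged.save_pretrained("bright-llama-3b-merged")
tokenizer.save_pretrained("bright-llama-3b-merged")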