ezcz committed on
Commit
c3e099a
1 Parent(s): f91fde9

Update space

Files changed (1)
  1. app.py +30 -31
app.py CHANGED
@@ -1,71 +1,70 @@
 import torch
 import gradio as gr
-from transformers import pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from peft import PeftModel
+import os
 import logging
-import warnings
-from threading import Lock
-
-# Suppress non-critical warnings
-warnings.filterwarnings("ignore", message=".*Torch was not compiled with flash attention.*")
 
 # Setup logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Model configuration
-MODEL_ID = "ezcz/bright-llama-3b-chat"
-MAX_NEW_TOKENS = 256
-TEMPERATURE = 0.1
-TOP_P = 0.9
-TOP_K = 60
-REPETITION_PENALTY = 1.0
+BASE_MODEL_ID = "unsloth/Llama-3.2-3B-instruct"  # The base model you fine-tuned
+ADAPTER_MODEL_ID = "ezcz/bright-llama-3b-chat"
 
 # Check for GPU availability
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 logger.info(f"Using device: {device}")
 
-# Load the pipeline
-logger.info("Loading model pipeline...")
+# Load the base model and apply the adapter
+logger.info("Loading base model...")
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
+base_model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL_ID,
+    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+    device_map="auto"
+)
+
+# Load the adapter model on top of the base model
+logger.info("Loading adapter weights...")
+model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
+
+# Create the pipeline with the combined model
 pipe = pipeline(
     "text-generation",
-    model=MODEL_ID,
+    model=model,
+    tokenizer=tokenizer,
     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-    device_map="auto",
+    device_map="auto"
 )
 
-# Define the chat interface
 def chat_interface(user_input, history=None):
     if history is None:
         history = []
 
     messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
 
     try:
         outputs = pipe(
             messages,
-            max_new_tokens=MAX_NEW_TOKENS,
-            temperature=TEMPERATURE,
-            top_p=TOP_P,
-            top_k=TOP_K,
-            repetition_penalty=REPETITION_PENALTY,
+            max_new_tokens=256,
+            temperature=0.1,
+            top_p=0.9,
+            top_k=60,
+            repetition_penalty=1.0
         )
         response = outputs[0]["generated_text"]
         history.append((user_input, response))
         return "", history
     except Exception as e:
-        logger.error(f"Error during response generation: {e}")
+        logger.error(f"Error generating response: {e}")
         return "Error generating response.", history
 
-# Define the Gradio interface
 with gr.Blocks() as demo:
     chatbot = gr.Chatbot()
     user_input = gr.Textbox(placeholder="Type your message...")
-    clear_button = gr.Button("Clear Chat")
     submit_button = gr.Button("Send")
-
     submit_button.click(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
-    user_input.submit(chat_interface, inputs=[user_input, chatbot], outputs=[user_input, chatbot])
-    clear_button.click(lambda: ([], ""), inputs=[], outputs=[chatbot, user_input])
 
-# Launch the UI
-demo.queue().launch(debug=True, share=True)
+demo.launch(debug=True, share=True)
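
Note on the updated generation call: when a list of chat messages is passed to a transformers text-generation pipeline, recent releases return the whole conversation (including the new assistant turn) under "generated_text" rather than a plain string, so the reply usually has to be pulled out of the last message before it is appended to the Gradio history. A minimal sketch, assuming the pipe and messages defined in app.py above; the helper name extract_reply and the plain-string fallback are illustrative, not part of the space:

# Sketch: pull the assistant reply out of a chat-format pipeline result.
# Assumes passing `messages` (role/content dicts) makes the pipeline return the
# full conversation under "generated_text"; the isinstance check covers the
# plain-string case seen with string prompts or older transformers versions.
def extract_reply(outputs):
    generated = outputs[0]["generated_text"]
    if isinstance(generated, list):
        # The last entry is the newly generated assistant message.
        return generated[-1]["content"]
    return generated

# Usage inside chat_interface:
#     outputs = pipe(messages, max_new_tokens=256, temperature=0.1, top_p=0.9, top_k=60)
#     response = extract_reply(outputs)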
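
If the adapter at ADAPTER_MODEL_ID is a LoRA adapter, an alternative to serving the PeftModel wrapper is to fold the adapter into the base weights once at startup. A sketch under that assumption, reusing base_model and tokenizer from app.py; the output directory name is hypothetical:

# Sketch: merge the LoRA adapter into the base model, then serve plain weights.
# merge_and_unload() is only defined for mergeable adapter types such as LoRA.
from peft import PeftModel

model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)
model = model.merge_and_unload()                      # fold LoRA deltas into the base weights
model.save_pretrained("bright-llama-3b-merged")       # hypothetical local path
tokenizer.save_pretrained("bright-llama-3b-merged")

The merged model (or the saved directory) can then be handed to pipeline() directly, which avoids the peft dependency at inference time.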