Spaces: Runtime error

Commit a5774b4
rogerkoranteng committed
Parent(s): c510c65

Upload folder using huggingface_hub

Changed files:
- fined-tuned-model.lora.h5 +3 -0
- flagged/log.csv +4 -0
- main.py +14 -27
- main.py.save +52 -0
- requirements.txt +2 -1
fined-tuned-model.lora.h5
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6fc9f1de53fe3d4eee5c536a0d566dafaf1d11d0167c526506bc9d89c7c3ebe3
+size 5560280
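The three lines above are a Git LFS pointer, not the LoRA weights themselves; the actual 5,560,280-byte .h5 file is stored in LFS under the listed oid. A minimal sketch (not part of the commit) of verifying a downloaded copy against the pointer, assuming the file sits in the working directory:

import hashlib
import os

weights_path = "fined-tuned-model.lora.h5"  # local copy pulled from the Space
expected_oid = "6fc9f1de53fe3d4eee5c536a0d566dafaf1d11d0167c526506bc9d89c7c3ebe3"
expected_size = 5560280

digest = hashlib.sha256()
with open(weights_path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert os.path.getsize(weights_path) == expected_size, "size does not match the LFS pointer"
assert digest.hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("Local weights match the LFS pointer")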
flagged/log.csv
ADDED
@@ -0,0 +1,4 @@
+Input,history,Response,history,flag,username,timestamp
+"am sad
+",,,,,,2024-09-04 18:19:14.583976
+,,"I'm sorry to you and your family. I'm sure this is very upsetting for you. I'm not sure I can really offer much help. I can only imagine how you feel. I wish I could offer you a hug. I'm glad you're considering counseling. That's a good sign. It's a good sign because you're here, reading this. It's a good sign you're here asking this question. It's a good sign you're looking for answers. It's a good thing you're here on this site. You're doing something good for yourself. You're taking care of you. You're looking for answers. You're asking questions. You're looking for help. You're looking for support. You're looking for friends. You're looking for someone to talk to. You're looking for someone to talk to. You're looking for someone to talk with. You're looking for someone to talk about it all. You're looking for someone to listen. You're looking for someone to talk to. You're looking for someone to talk with. You'",,,,2024-09-04 18:25:01.121662
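These rows were written by Gradio's flagging feature, which appends each flagged interaction to flagged/log.csv. A minimal sketch (not part of the commit) for inspecting the log, assuming the default column layout shown above:

import csv

with open("flagged/log.csv", newline="") as f:
    for row in csv.DictReader(f):
        # The csv module handles quoted multi-line fields such as the "am sad" input.
        print(repr(row["Input"]), "->", row["Response"][:60], row["timestamp"])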
main.py
CHANGED
@@ -3,58 +3,45 @@ import os
 import keras_nlp
 from transformers import AutoModelForCausalLM
 
-
 # Set Kaggle API credentials
-
 os.environ["KAGGLE_USERNAME"] = "rogerkorantenng"
 os.environ["KAGGLE_KEY"] = "9a33b6e88bcb6058b1281d777fa6808d"
 
 # Load LoRA weights if you have them
-LoRA_weights_path = "fined-tuned.lora.h5"
+LoRA_weights_path = "fined-tuned-model.lora.h5"
 gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
-
 gemma_lm.backbone.enable_lora(rank=4)  # Enable LoRA with rank 4
 gemma_lm.preprocessor.sequence_length = 512  # Limit sequence length
 gemma_lm.backbone.load_lora_weights(LoRA_weights_path)  # Load LoRA weights
 
 # Define the response generation function
-def generate_response(message
+def generate_response(message):
     # Create a prompt template
     template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
 
-    #
-    formatted_history = ""
-    for user_msg, bot_msg in history:
-        formatted_history += template.format(instruction=user_msg, response=bot_msg)
-
-    # Add the latest message from the user
+    # Create the prompt with the current message
     prompt = template.format(instruction=message, response="")
-    print(prompt)
-
-    # Combine history with the latest prompt
-    final_prompt = formatted_history + prompt
-    print(final_prompt)
+    print("Prompt:\n", prompt)
 
     # Generate response from the model
-    response = gemma_lm.generate(
+    response = gemma_lm.generate(prompt, max_length=256)
     # Only keep the generated response
-    response = response.split("Response:")[1].strip()
+    response = response.split("Response:")[-1].strip()
 
-    print(response)
+    print("Generated Response:\n", response)
 
     # Extract and return the generated response text
     return response  # Adjust this if your model's output structure differs
 
 # Create the Gradio chat interface
-interface = gr.
+interface = gr.Interface(
     fn=generate_response,  # Function that generates responses
-
-
-    title="
-
-
-    clear_btn="Clear"  # Enable clear button
+    inputs=gr.Textbox(placeholder="Hello, I am Sage, your mental health advisor", lines=2, scale=7),
+    outputs=gr.Textbox(),
+    title="Welcome to Sage, your dedicated mental health advisor.",
+    # description="Chat with Sage, your mental health advisor.",
+    # live=True
 )
 
 # Launch the Gradio app
-interface.launch(share=True)
+interface.launch(share=True, share_server_address="hopegivers.tech:7000")
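Compared to the previous revision, generate_response no longer threads chat history into the prompt; it formats a single Instruction/Response block and keeps only the text after the last "Response:" marker. A minimal sketch (not part of the commit) of that prompt/parse round trip, using a hard-coded stand-in for the Gemma output:

template = "Instruction:\n{instruction}\n\nResponse:\n{response}"

# Build the prompt exactly as the updated main.py does.
prompt = template.format(instruction="am sad", response="")

# Stand-in for gemma_lm.generate(prompt, max_length=256): the model echoes the
# prompt and continues after the "Response:" marker.
generated = prompt + "I'm sorry you're feeling low. Talking it through with someone you trust can help."

# split(...)[-1] keeps only the model's completion, even if the user's
# instruction itself happened to contain the word "Response:".
reply = generated.split("Response:")[-1].strip()
print(reply)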
main.py.save
ADDED
@@ -0,0 +1,52 @@
+import gradio as grimport os
+import keras
+import keras_nlp
+
+import os
+
+os.environ["KERAS_BACKEND"] = "jax"
+# Avoid memory fragmentation on JAX backend.
+os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
+
+import os
+
+# Set Kaggle API credentials
+os.environ["KAGGLE_USERNAME"] = "rogerkorantenng"
+os.environ["KAGGLE_KEY"] = "9a33b6e88bcb6058b1281d777fa6808d"
+
+# Load environment variables
+load_dotenv()
+
+# Replace this with the path or method to load your local model
+gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
+
+def generate_response(message, history):
+    # Format the conversation history for the local model
+    formatted_history = []
+    for user, assistant in history:
+        formatted_history.append(f"Instruction:\n{user}\n\nResponse:\n{assistant}")
+
+    # Add the latest user message to the history
+    formatted_history.append(f"Instruction:\n{message}\n\nResponse:\n")
+
+    # Join formatted history into a single string for input
+    input_text = "\n".join(formatted_history)
+
+    # Generate response from the local model
+    # Make sure to adjust this part according to your model's API
+    response = gemma_lm.generate(input_text, max_length=256)
+
+    # Extract the response text
+    # Adjust the response extraction based on the actual structure of your model's output
+    return response[0]  # Change this line if necessary
+
+# Create the Gradio interface
+gr.ChatInterface(
+    generate_response,
+    chatbot=gr.Chatbot(height=300),
+    textbox=gr.Textbox(placeholder="You can ask me anything", container=False, scale=7),
+    title="Local Model Chat Bot",
+    retry_btn=None,
+    undo_btn="Delete Previous",
+    clear_btn="Clear"
+).launch(share=True)
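main.py.save is an older editor backup; unlike the updated main.py, it still receives the chat history from gr.ChatInterface (a list of user/assistant pairs) and folds it into the prompt. A minimal sketch (not part of the commit) of what its formatting loop produces for a one-turn history plus a new message, with the model call left out:

history = [("hi", "Hello! How can I support you today?")]  # hypothetical example history
message = "am sad"

formatted_history = []
for user, assistant in history:
    formatted_history.append(f"Instruction:\n{user}\n\nResponse:\n{assistant}")

# The newest user turn ends with an empty response for the model to complete.
formatted_history.append(f"Instruction:\n{message}\n\nResponse:\n")

input_text = "\n".join(formatted_history)
print(input_text)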
requirements.txt
CHANGED
@@ -92,4 +92,5 @@ urllib3==2.2.2
 uvicorn==0.30.6
 websockets==12.0
 Werkzeug==3.0.4
-wrapt==1.16.0
+wrapt==1.16.0
+transformers