File size: 1,832 Bytes
cd4405e
5ccf2c1
 
 
cd4405e
cf7b404
 
 
cd4405e
5ccf2c1
a5774b4
5ccf2c1
 
 
 
cd4405e
5ccf2c1
a5774b4
5ccf2c1
 
 
a5774b4
5ccf2c1
a5774b4
5ccf2c1
 
a5774b4
5ccf2c1
a5774b4
5ccf2c1
a5774b4
5ccf2c1
 
 
 
 
a5774b4
5ccf2c1
a5774b4
 
46d4ec4
a5774b4
 
5ccf2c1
46d4ec4
5ccf2c1
46d4ec4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import gradio as gr
import os
import keras_nlp
from transformers import AutoModelForCausalLM

# --- Kaggle credentials -----------------------------------------------------
# SECURITY: a real Kaggle API key was committed to source control in this
# file. Rotate that key and provide KAGGLE_USERNAME / KAGGLE_KEY via the
# environment (or a secrets manager) instead of editing this file.
# setdefault() keeps any credentials already set in the environment instead
# of silently overwriting them.
os.environ.setdefault("KAGGLE_USERNAME", "rogerkorantenng")
os.environ.setdefault("KAGGLE_KEY", "9a33b6e88bcb6058b1281d777fa6808d")

# --- Model setup ------------------------------------------------------------
# Load the base Gemma 2B model, enable low-rank adaptation, and apply the
# fine-tuned LoRA weights produced during training.
LORA_WEIGHTS_PATH = "fined-tuned-model.lora.h5"  # NOTE(review): "fined" looks like a typo — confirm it matches the artifact on disk
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.backbone.enable_lora(rank=4)  # rank must match the rank used at fine-tune time
gemma_lm.preprocessor.sequence_length = 512  # cap input length to bound memory/latency
gemma_lm.backbone.load_lora_weights(LORA_WEIGHTS_PATH)

# Turn a user message into a model prompt and return only the generated reply.
def generate_response(message):
    """Generate a reply to *message* with the LoRA-tuned Gemma model.

    The message is wrapped in the same Instruction/Response template used
    during fine-tuning; everything before the final "Response:" marker is
    stripped from the model output before returning.
    """
    prompt = "Instruction:\n{instruction}\n\nResponse:\n{response}".format(
        instruction=message, response=""
    )
    print("Prompt:\n", prompt)

    # Run generation, then keep only the text after the last "Response:" marker.
    raw_output = gemma_lm.generate(prompt, max_length=256)
    reply = raw_output.split("Response:")[-1].strip()

    print("Generated Response:\n", reply)
    return reply

# --- Gradio UI --------------------------------------------------------------
# Wire the generation function into a simple text-in / text-out web interface.
# (Removed commented-out `description`/`live` arguments that were dead code.)
interface = gr.Interface(
    fn=generate_response,  # called with the textbox contents on submit
    inputs=gr.Textbox(placeholder="Hello, I am Sage, your mental health advisor", lines=2, scale=7),
    outputs=gr.Textbox(),
    title="Sage, your Mental Health Advisor",
)

# Honor a reverse-proxy path prefix when deployed behind one (None otherwise).
proxy_prefix = os.environ.get("PROXY_PREFIX")

# Launch the app. Binds to all interfaces on port 8080; share=True also opens
# a public Gradio tunnel — disable it if external exposure is not intended.
interface.launch(server_name="0.0.0.0", server_port=8080, root_path=proxy_prefix, share=True)