Spaces:
Runtime error
Runtime error
File size: 1,832 Bytes
cd4405e 5ccf2c1 cd4405e cf7b404 cd4405e 5ccf2c1 a5774b4 5ccf2c1 cd4405e 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 5ccf2c1 a5774b4 46d4ec4 a5774b4 5ccf2c1 46d4ec4 5ccf2c1 46d4ec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import gradio as gr
import os
import keras_nlp
from transformers import AutoModelForCausalLM
# --- Kaggle credentials & model setup ----------------------------------------
# SECURITY FIX: a Kaggle API key was previously hardcoded here. Never commit
# credentials to source control — the exposed key must be revoked on Kaggle.
# Supply KAGGLE_USERNAME / KAGGLE_KEY via the environment (e.g. Space secrets).
if "KAGGLE_USERNAME" not in os.environ or "KAGGLE_KEY" not in os.environ:
    raise RuntimeError(
        "KAGGLE_USERNAME and KAGGLE_KEY must be set in the environment "
        "before launching this app."
    )

# Path to the fine-tuned LoRA adapter weights (relative to the working dir).
LoRA_weights_path = "fined-tuned-model.lora.h5"

# Load the base Gemma 2B causal LM, then attach the fine-tuned LoRA adapter.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.backbone.enable_lora(rank=4)  # rank must match the training config
gemma_lm.preprocessor.sequence_length = 512  # cap input length to bound memory
gemma_lm.backbone.load_lora_weights(LoRA_weights_path)  # apply LoRA weights
# Response generation callback used by the Gradio interface.
def generate_response(message):
    """Generate a reply from the fine-tuned Gemma model for one user message.

    Args:
        message: The user's input text, inserted as the instruction.

    Returns:
        The generated text with the prompt scaffolding stripped off.
    """
    # Build the instruction-tuned prompt; the response slot is left empty
    # so the model fills it in.
    template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    prompt = template.format(instruction=message, response="")
    print("Prompt:\n", prompt)

    # Run generation; max_length bounds prompt + completion tokens.
    raw_output = gemma_lm.generate(prompt, max_length=256)

    # Keep only the text after the final "Response:" marker.
    reply = raw_output.split("Response:")[-1].strip()
    print("Generated Response:\n", reply)
    return reply
# --- Gradio UI ---------------------------------------------------------------
# Single-turn text interface: one textbox in, one textbox out.
# Create the Gradio chat interface
interface = gr.Interface(
    fn=generate_response,  # Function that generates responses
    inputs=gr.Textbox(placeholder="Hello, I am Sage, your mental health advisor", lines=2, scale=7),
    outputs=gr.Textbox(),
    title="Sage, your Mental Health Advisor",
    # description="Chat with Sage, your mental health advisor.",
    # live=True
)

# Optional path prefix when served behind a reverse proxy (None when unset).
proxy_prefix = os.environ.get("PROXY_PREFIX")

# Launch the Gradio app
# Binds to all interfaces on port 8080; share=True additionally requests a
# public gradio.live tunnel URL.
interface.launch(server_name="0.0.0.0", server_port=8080, root_path=proxy_prefix, share=True)
|