ghuman7 commited on
Commit
7bbe0a9
1 Parent(s): a12a4d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -1,27 +1,29 @@
1
import streamlit as st
from transformers import RagRetriever, RagTokenizer, RagTokenForGeneration


@st.cache_resource
def load_rag_model():
    """Download and cache the RAG components once per Streamlit session.

    Returns:
        tuple: (tokenizer, retriever, model) for "facebook/rag-token-nq".
    """
    retriever = RagRetriever.from_pretrained("facebook/rag-token-nq")
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
    # The generator must be wired to the retriever so generate() can fetch passages.
    generator = RagTokenForGeneration.from_pretrained(
        "facebook/rag-token-nq", retriever=retriever
    )
    return tokenizer, retriever, generator


def main():
    """Render the chatbot page and answer a single free-text question."""
    st.title("Mental Health Chatbot")

    tokenizer, retriever, generator = load_rag_model()

    question = st.text_input("Ask me something about mental health:")
    if not question:
        return  # nothing to do until the user types something

    encoded = tokenizer(question, return_tensors="pt").input_ids
    generated = generator.generate(encoded)
    answer = tokenizer.decode(generated[0], skip_special_tokens=True)
    st.write(f"Response: {answer}")


if __name__ == "__main__":
    main()
 
1
import os

import streamlit as st
import requests

# Hugging Face Inference API details
API_URL = "https://api-inference.huggingface.co/models/facebook/blenderbot-400M-distill"
# BUG FIX: the original referenced `api_key` without ever defining it, which
# raises NameError the moment the script runs.  Read the token from the
# environment instead of hard-coding it (never commit secrets to the repo).
api_key = os.environ.get("HF_API_KEY", "")
headers = {"Authorization": f"Bearer {api_key}"}

# Function to query the model
def query(payload):
    """POST `payload` to the Inference API and return the decoded JSON.

    Args:
        payload: dict in the API's request format, e.g. {"inputs": "..."}.

    Returns:
        The JSON-decoded response (a list of dicts on success, a dict with
        an "error" key on failure).

    Raises:
        requests.HTTPError: if the API responds with an error status, so
        failures surface here rather than as a confusing KeyError later.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    return response.json()

# Streamlit UI for Mental Health Chatbot
st.title("Mental Health Chatbot")
st.write("""
This chatbot provides responses to mental health-related queries.
Please note that this is an AI-based tool and is not a substitute for professional mental health support.
""")

# User input
user_input = st.text_input("How can I help you today?")

if st.button("Get Response"):
    if user_input:
        # Query the model
        output = query({"inputs": user_input})
        # BUG FIX: the text-generation endpoint returns a LIST of dicts,
        # e.g. [{"generated_text": "..."}], not a bare dict, so the original
        # `output['generated_text']` raised TypeError/KeyError on every reply.
        if isinstance(output, list) and output:
            reply = output[0].get("generated_text", "")
        else:
            # Error payloads come back as a dict such as {"error": "..."}.
            reply = output.get("generated_text") or output.get("error", "No response.")
        st.write(f"**Response:** {reply}")
    else:
        st.write("Please enter a query to get a response.")