# ManoVyadh / app.py
# Hugging Face Space by ShieldX (commit 22da7c8, verified: "Update app.py").
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
# Hugging Face Hub id of the fine-tuned causal LM served by this app.
MODEL_ID = "ShieldX/manovyadh-1.1B-v1"

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Only the model weights are moved here; tokenized inputs must be moved to
# `device` separately before calling generate().
model.to(device)
# Header text for the chat UI.
title = "🌱 ManoVyadh 🌱"
description = "Mental Health Counselling Chatbot"

# Sample user prompts handed to the chat UI as clickable examples.
examples = [
    (
        "I have been feeling more and more down for over a month. "
        "I have started having trouble sleeping due to panic attacks, "
        "but they are almost never triggered by something that I know of."
    ),
    (
        "I self-harm, and I stop for a while. Then when I see something "
        "sad or depressing, I automatically want to self-harm."
    ),
    "I am feeling sad for my friend's divorce",
]
class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the newest token is one of ``stop_ids``.

    Only the first sequence in the batch is inspected, so this criterion is
    intended for single-sequence generation.
    """

    def __init__(self, stop_ids=(1, 2)):
        """
        Args:
            stop_ids: token ids that end generation. The default ``(1, 2)``
                keeps the ids that were previously hard-coded here
                (presumably the tokenizer's BOS/EOS ids -- TODO confirm
                against the tokenizer in use).
        """
        super().__init__()
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of the first sequence is a stop id."""
        return int(input_ids[0, -1]) in self.stop_ids
_SYS_MSG = """###SYSTEM: You are an AI assistant that helps people cope with stress and improve their mental health. User will tell you about their feelings and challenges. Your task is to listen empathetically and offer helpful suggestions. While responding, think about the user’s needs and goals and show compassion and support"""


def _build_prompt(message, history):
    """Render a conversation into the ###SYSTEM/###USER/###ASSISTANT prompt.

    The system message appears exactly once, at the top. It is followed by
    one "###USER:" / "###ASSISTANT:" pair per past exchange, then the new
    user message and an open "###ASSISTANT:" tag for the model to complete.
    Parts are joined with newlines. With an empty history this yields the
    system message, a newline, "###USER:" + message, a newline and
    "###ASSISTANT:".

    Args:
        message: the new user message.
        history: earlier exchanges, each indexed as [user_text, assistant_text]
            (assumes Gradio's pair-style chat history -- confirm if the
            ChatInterface history format changes). A None assistant text is
            treated as an empty string.

    Returns:
        The full prompt string.
    """
    parts = [_SYS_MSG]
    for turn in history:
        parts.append("###USER:" + turn[0])
        parts.append("###ASSISTANT:" + (turn[1] or ""))
    parts.append("###USER:" + message)
    parts.append("###ASSISTANT:")
    return "\n".join(parts)


def predict(message, history):
    """Stream the model's reply to ``message`` for ``gr.ChatInterface``.

    Generation runs in a background thread, and its output is read through a
    TextIteratorStreamer. After each streamed chunk the reply accumulated so
    far is yielded, so the UI can render it progressively.

    Args:
        message: the new user message.
        history: earlier exchanges as [user_text, assistant_text] pairs.

    Yields:
        The partial assistant reply, growing with every streamed chunk.
    """
    prompt = _build_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # skip_prompt=True streams only newly generated text.
    # skip_special_tokens=True keeps special-token text (e.g. the EOS marker)
    # out of the reply shown to the user.
    # timeout=10. makes iteration raise if no new text arrives for 10 s.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.2,
        num_beams=1,
        eos_token_id=[tokenizer.eos_token_id],
        # The EOS id doubles as the padding id.
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # generate() blocks until it finishes, so run it in a worker thread and
    # consume its text from the streamer as it is produced.
    generation = Thread(target=model.generate, kwargs=generate_kwargs)
    generation.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The streamer is exhausted once generation has ended, so this join is
    # expected to return promptly; it reclaims the worker thread.
    generation.join()
# Build the chat UI around the streaming `predict` generator, then start the server.
demo = gr.ChatInterface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    theme="finlaymacklon/boxy_violet",
)
demo.launch(debug=True)