Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer | |
from threading import Thread | |
# Hub checkpoint shared by the tokenizer and the causal-LM weights.
model_id = "ShieldX/manovyadh-1.1B-v1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Place the weights on the chosen device; model inputs must be moved there too.
model.to(device)
# Header text for the Gradio chat interface.
title = "🌱 ManoVyadh 🌱"
description = "Mental Health Counselling Chatbot"

# Sample user messages handed to gr.ChatInterface as `examples`.
examples = [
    "I have been feeling more and more down for over a month. I have started having trouble sleeping due to panic attacks, but they are almost never triggered by something that I know of.",
    "I self-harm, and I stop for a while. Then when I see something sad or depressing, I automatically want to self-harm.",
    "I am feeling sad for my friend's divorce",
]
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on specific token ids.

    Generation stops as soon as the most recently produced token of the
    first sequence in the batch is one of ``stop_ids``.

    Args:
        stop_ids: Token ids that end generation. Defaults to ``(1, 2)``,
            the ids this app has always used. NOTE(review): presumably the
            BOS/EOS ids of the model's Llama-style tokenizer; confirm
            against ``tokenizer``.
    """

    def __init__(self, stop_ids=(1, 2)):
        super().__init__()
        # Built once so each generation step is a single set lookup.
        self.stop_ids = frozenset(int(token_id) for token_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when row 0's last token is a stop id.

        Only row 0 is inspected, so this assumes a single-prompt batch.
        """
        last_token = int(input_ids[0, -1])
        return last_token in self.stop_ids
def _build_prompt(message, history):
    """Render the system prompt, earlier turns and the new message as one string.

    The system message appears exactly once, at the top. Each turn follows it
    as a ``###USER:`` line and an ``###ASSISTANT:`` line. The final turn carries
    ``message`` with an empty reply, so the model continues after
    ``###ASSISTANT:``.
    """
    sys_msg = """###SYSTEM: You are an AI assistant that helps people cope with stress and improve their mental health. User will tell you about their feelings and challenges. Your task is to listen empathetically and offer helpful suggestions. While responding, think about the user’s needs and goals and show compassion and support"""
    turns = list(history) + [[message, ""]]
    # `or ""` tolerates a missing (None) assistant reply in the history.
    return sys_msg + "".join(
        "\n###USER:" + turn[0] + "\n###ASSISTANT:" + (turn[1] or "")
        for turn in turns
    )


def predict(message, history):
    """Stream the model's reply to ``message`` for ``gr.ChatInterface``.

    Args:
        message: The user's newest chat message.
        history: Earlier turns as ``[user_message, assistant_reply]`` pairs.

    Yields:
        The assistant reply accumulated so far, growing with each streamed chunk.
    """
    prompt = _build_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    # skip_prompt: stream only newly generated text.
    # skip_special_tokens: keep the final stop token (e.g. "</s>") out of the UI.
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.2,
        num_beams=1,
        eos_token_id=[tokenizer.eos_token_id],
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )
    # model.generate blocks until it finishes, so run it in a worker thread
    # and consume its output incrementally through the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    partial_message = ""
    for new_text in streamer:
        # Guard kept from the original app: drop chunks that are a lone "<".
        if new_text != "<":
            partial_message += new_text
            yield partial_message
    # The streamer is exhausted only after generate() signals completion.
    thread.join()
# Wire the streaming `predict` generator into a Gradio chat UI, then start
# the server with debug=True.
demo = gr.ChatInterface(
    fn=predict,
    title=title,
    description=description,
    examples=examples,
    theme="finlaymacklon/boxy_violet",
)
demo.launch(debug=True)