"""Gradio chat front-end for the ShieldX/manovyadh-1.1B-v1 counselling model.

The tokenizer and model are loaded once at import time; each chat request is
answered by streaming tokens from a background ``model.generate`` call.
"""

from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
)

MODEL_ID = "ShieldX/manovyadh-1.1B-v1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# Use the GPU when one is available. The model must be on the same device as
# the tokenized inputs (which predict() moves to `device`), otherwise PyTorch
# raises a device-mismatch error inside generate().
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

title = "🌱 ManoVyadh 🌱"
description = "Mental Health Counselling Chatbot"
examples = [
    "I have been feeling more and more down for over a month. I have started having trouble sleeping due to panic attacks, but they are almost never triggered by something that I know of.",
    "I self-harm, and I stop for a while. Then when I see something sad or depressing, I automatically want to self-harm.",
    "I am feeling sad for my friend's divorce",
]

# System header placed once at the start of every prompt.
SYS_MSG = (
    "###SYSTEM: You are an AI assistant that helps people cope with stress "
    "and improve their mental health. User will tell you about their feelings "
    "and challenges. Your task is to listen empathetically and offer helpful "
    "suggestions. \n"
    "While responding, think about the user’s needs and goals and show "
    "compassion and support"
)


class StopOnTokens(StoppingCriteria):
    """Stop generation once the most recent token is one of ``stop_ids``.

    The default ids (1, 2) are the values this script has always used;
    presumably the BOS/EOS ids of the model's tokenizer — TODO confirm.
    """

    def __init__(self, stop_ids=(1, 2)):
        self.stop_ids = tuple(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the first sequence is checked; predict() always generates a batch of one.
        return int(input_ids[0][-1]) in self.stop_ids


def _build_prompt(message, history):
    """Render the chat as one ###SYSTEM / ###USER / ###ASSISTANT prompt string.

    The system header appears once. Each earlier turn follows on its own line,
    and the prompt ends with an open ``###ASSISTANT:`` tag for the model to
    complete. ``history`` is assumed to be Gradio's tuple-style list of
    ``[user, bot]`` pairs — confirm if the ChatInterface history type changes.
    """
    parts = [SYS_MSG]
    for turn in history:
        user_msg, bot_msg = turn[0], turn[1]
        # A pending bot slot may be None; render it as empty text.
        parts.append(f"\n###USER:{user_msg}\n###ASSISTANT:{bot_msg or ''}")
    parts.append(f"\n###USER:{message}\n###ASSISTANT:")
    return "".join(parts)


def predict(message, history):
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Yields the reply accumulated so far after every decoded text chunk, which
    is how ``gr.ChatInterface`` renders a streaming response.
    """
    prompt = _build_prompt(message, history)
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # skip_special_tokens=True keeps special tokens such as EOS out of the
    # visible reply text.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.2,
        num_beams=1,
        eos_token_id=[tokenizer.eos_token_id],
        pad_token_id=tokenizer.eos_token_id,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
    )

    # generate() blocks until finished, so run it on a worker thread and
    # consume decoded text from the streamer on this one.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_text in streamer:
        partial_message += new_text
        yield partial_message

    # The streamer is exhausted only after generate() returns, so this join
    # does not wait long; it just avoids leaving the thread unreaped.
    worker.join()


demo = gr.ChatInterface(
    predict,
    title=title,
    description=description,
    examples=examples,
    theme="finlaymacklon/boxy_violet",
)
demo.launch(debug=True)