"""MedChat: a Gradio app comparing a local PEFT model with hosted Inference API models."""

import os
from textwrap import fill, wrap

import gradio as gr
import requests
import torch
from peft import PeftConfig, PeftModel
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    MistralForCausalLM,
)

DEFAULT_SYSTEM_PROMPT = (
    "You are an expert medical analyst that helps users with any medical "
    "related information."
)

# Hugging Face Inference API credentials. The token is optional and read from
# the environment; without it requests are sent unauthenticated.
HF_TOKEN = os.environ.get("HF_TOKEN")
HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else {}
REQUEST_TIMEOUT = 60  # seconds; keeps a stalled API call from hanging the UI

## using Mistral
Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
## using Falcon (same model family as the local base model below)
Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"


def _hf_query(api_url, payload):
    """POST *payload* to an Inference API endpoint and return the decoded JSON.

    Network failures and non-JSON bodies are converted into an
    ``{"error": ...}`` dict, so callers handle a single failure shape.
    """
    try:
        response = requests.post(
            api_url, headers=HEADERS, json=payload, timeout=REQUEST_TIMEOUT
        )
    except requests.RequestException as err:
        return {"error": f"Request failed: {err}"}
    try:
        return response.json()
    except ValueError:
        return {"error": f"HTTP {response.status_code}: {response.text[:200]}"}


def _build_payload(input_text, parameters=None):
    """Build an Inference API payload, adding ``parameters`` only when given."""
    payload = {"inputs": input_text}
    if parameters:
        payload["parameters"] = parameters
    return payload


def mistral_query(payload):
    """Send *payload* to the hosted Mistral-7B endpoint and return its JSON."""
    return _hf_query(Mistral_API_URL, payload)


def mistral_inference(input_text, parameters=None):
    """Run text generation on Mistral-7B for *input_text*.

    *parameters* is an optional dict of generation parameters forwarded to the API.
    """
    return mistral_query(_build_payload(input_text, parameters))


def falcon_query(payload):
    """Send *payload* to the hosted Falcon-7B-instruct endpoint and return its JSON."""
    return _hf_query(Falcon_API_URL, payload)


def falcon_inference(input_text, parameters=None):
    """Run text generation on Falcon-7B-instruct for *input_text*.

    *parameters* is an optional dict of generation parameters forwarded to the API.
    """
    return falcon_query(_build_payload(input_text, parameters))


def _extract_generated_text(api_response):
    """Return the generated text from an Inference API response.

    A successful call returns ``[{"generated_text": ...}]``. Anything else
    (an ``{"error": ...}`` dict or an unexpected shape) is turned into a
    readable message instead of raising inside the Gradio callback.
    """
    if isinstance(api_response, list) and api_response:
        first = api_response[0]
        if isinstance(first, dict) and "generated_text" in first:
            return first["generated_text"]
    if isinstance(api_response, dict) and "error" in api_response:
        return f"Inference API error: {api_response['error']}"
    return f"Unexpected Inference API response: {api_response!r}"


def _api_reply(inference_fn, msg):
    """Query *inference_fn* with *msg* and return only the model's continuation."""
    # return_full_text=False stops the text-generation endpoint from echoing the prompt.
    return _extract_generated_text(inference_fn(msg, {"return_full_text": False}))


# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    """Wrap every line of *text* to *width* columns, preserving existing line breaks."""
    return "\n".join(fill(line, width=width) for line in text.split("\n"))


class ChatbotInterface:
    """One chat panel: history display, message box, Submit and Clear buttons.

    Must be instantiated inside a ``gr.Blocks`` context. Subclasses implement
    :meth:`respond`, which receives the textbox value and the current chat
    history (a list of ``[user, bot]`` pairs) and returns
    ``("", updated_history)`` to clear the textbox and redraw the chat.
    """

    def __init__(self, name, system_prompt=DEFAULT_SYSTEM_PROMPT):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        with gr.Row():
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
            # History is read back from self.chatbot on every turn, so clearing
            # the component is enough; no stale server-side copy survives.
            self.clear = gr.ClearButton([self.msg, self.chatbot])
        self.submit.click(
            self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot]
        )

    def respond(self, msg, chatbot):
        raise NotImplementedError

    @staticmethod
    def _append_turn(history, msg, reply):
        """Return a new history list with the ``[msg, reply]`` turn appended."""
        return [*(history or []), [msg, reply]]


class GaiaMinimed(ChatbotInterface):
    """Chat panel backed by the locally loaded PEFT model.

    Uses the module-level ``tokenizer`` and ``peft_model`` created in the
    ``__main__`` block before the UI is built.
    """

    def respond(self, msg, history):
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input, return_tensors="pt", add_special_tokens=False
        ).to(peft_model.device)
        output_ids = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True,
        )
        # generate() returns prompt + continuation; decode only the new tokens
        # so the reply does not echo the system prompt and the question.
        reply = tokenizer.decode(
            output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True
        ).strip()
        return "", self._append_turn(history, msg, reply)


class FalconBot(ChatbotInterface):
    """Chat panel backed by the hosted Falcon-7B-instruct Inference API."""

    def respond(self, msg, chatbot):
        reply = _api_reply(falcon_inference, msg)
        return "", self._append_turn(chatbot, msg, reply)


class MistralBot(ChatbotInterface):
    """Chat panel backed by the hosted Mistral-7B Inference API."""

    def respond(self, msg, chatbot):
        reply = _api_reply(mistral_inference, msg)
        return "", self._append_turn(chatbot, msg, reply)


INTRO_MARKDOWN = """
# MedChat: Your Medical Assistant Chatbot

Welcome to MedChat, your friendly medical assistant chatbot! 🩺

Dive into a world of medical expertise where you can interact with three specialized chatbots, all trained on the latest and most comprehensive medical dataset. Whether you have health-related questions, need medical advice, or just want to learn more about your well-being, MedChat is here to help!

## How it Works
Simply type your medical query or concern, and let MedChat's advanced algorithms provide you with accurate and reliable responses.

## Explore and Compare
Feel like experimenting? Click the **Submit to All** button and witness the magic as all three chatbots compete to provide you with the best possible answer! It's a unique opportunity to compare the insights from different models and choose the one that suits your needs the best.

_Ready to get started? Type your question and let's begin!_
"""


def _mirror(text):
    """Duplicate a textbox value so it can be copied into the two other bots."""
    return text, text


if __name__ == "__main__":
    # Use the base model's ID
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    # Instantiate the Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, trust_remote_code=True, padding_side="left"
    )
    # Specify the configuration class for the model
    model_config = AutoConfig.from_pretrained(base_model_id)
    # Load the PEFT model with the specified configuration.
    # NOTE(review): model weights are loaded from the adapter repo and the
    # adapter is then applied from the same repo; confirm that repo holds full
    # weights, otherwise load base_model_id here instead.
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(INTRO_MARKDOWN)
        with gr.Row():
            with gr.Column():
                with gr.Tab("GaiaMinimed"):
                    gaia_bot = GaiaMinimed("GaiaMinimed")
            with gr.Column():
                with gr.Tab("MistralMed"):
                    mistral_bot = MistralBot("MistralMed")
                with gr.Tab("Falcon-7B"):
                    falcon_bot = FalconBot("Falcon-7B")

        # Typing in any bot's textbox mirrors the text into the other two.
        # .input fires only on user edits, so the programmatic clear done by
        # respond() does not cascade between boxes.
        gaia_bot.msg.input(fn=_mirror, inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=_mirror, inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=_mirror, inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])

        # The intro promises a "Submit to All" button: send each bot its
        # (mirrored) message in one click.
        submit_all = gr.Button("Submit to All")
        for bot in (gaia_bot, mistral_bot, falcon_bot):
            submit_all.click(bot.respond, [bot.msg, bot.chatbot], [bot.msg, bot.chatbot])

    demo.launch()