|
import os |
|
import torch |
|
import gradio as gr |
|
import requests |
|
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM |
|
from peft import PeftModel, PeftConfig |
|
from textwrap import wrap, fill |
|
|
|
|
|
Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"

# Hugging Face Inference API credentials. The token is read from the HF_TOKEN
# environment variable; when it is unset, requests are sent without an
# Authorization header.
_HF_TOKEN = os.getenv("HF_TOKEN")
HEADERS = {"Authorization": f"Bearer {_HF_TOKEN}"} if _HF_TOKEN else {}


def mistral_query(payload, timeout=60):
    """POST *payload* to the Mistral Inference API endpoint and return the parsed JSON.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "..."}``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    response = requests.post(Mistral_API_URL, headers=HEADERS, json=payload, timeout=timeout)
    # Surface HTTP failures here instead of handing callers an error payload
    # that they would then try to index like a successful result.
    response.raise_for_status()
    return response.json()
|
def mistral_inference(input_text):
    """Wrap *input_text* in an ``{"inputs": ...}`` payload and forward it to ``mistral_query``.

    Returns whatever ``mistral_query`` returns for that payload.
    """
    return mistral_query({"inputs": input_text})
|
|
|
|
|
def wrap_text(text, width=90):
    """Re-wrap each newline-separated line of *text* to *width* columns.

    Existing line breaks are preserved: every original line is filled on its
    own with ``textwrap.fill`` and the results are re-joined with ``\\n``, so
    blank lines stay blank.
    """
    return "\n".join(fill(segment, width=width) for segment in text.split("\n"))
|
|
|
class ChatbotInterface():
    """Base Gradio chat panel: a Chatbot display, a message box, Submit and Clear buttons.

    Subclasses implement :meth:`respond`, which is wired to the Submit button
    with ``[self.msg, self.chatbot]`` as inputs and outputs, so it must return
    a ``(new_textbox_value, new_chatbot_value)`` pair.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        # NOTE(review): this list is stored on the object rather than in
        # per-session gr.State, so it is presumably shared by every user of the
        # running app — confirm if multi-user use matters.
        self.chat_history = []

        with gr.Row() as row:
            # NOTE(review): `justify` does not appear to be a gr.Row constructor
            # option; this assignment likely has no visual effect — confirm.
            row.justify = "end"
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)

        clear = gr.ClearButton([self.msg, self.chatbot])
        # ClearButton only resets the UI components. Also drop the stored
        # history, otherwise the cleared turns reappear on the next submit.
        clear.click(self._clear_history)

        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def _clear_history(self):
        """Forget every stored (user message, bot reply) pair, in place."""
        self.chat_history.clear()

    def respond(self, msg, chatbot):
        """Handle one submitted message; must be overridden by subclasses.

        Args:
            msg: Current textbox value.
            chatbot: Current Chatbot component value.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
|
|
|
class GaiaMinimed(ChatbotInterface):
    """Chat panel backed by the locally loaded GaiaMiniMed PEFT model.

    ``respond`` reads the module-level ``tokenizer`` and ``peft_model`` globals,
    which must exist before the Submit button is clicked.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, history):
        """Generate a reply to *msg* and append the turn to ``self.chat_history``.

        Args:
            msg: The user's message from the textbox.
            history: Current Chatbot value supplied by Gradio. It is unused;
                ``self.chat_history`` is used instead.

        Returns:
            ``("", self.chat_history)``, which clears the textbox and refreshes
            the chatbot display.
        """
        # Prompt template: "{{ <system prompt> }}\nUser: <msg>\n<bot name>:"
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False,
        )
        response = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True,
        )
        # For causal LMs, generate() returns prompt + continuation. Decode only
        # the newly generated tokens so the reply does not echo the prompt.
        new_tokens = response[0][input_ids.shape[-1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # Display the user's own message, not the internal prompt template.
        self.chat_history.append([msg, response_text])
        return "", self.chat_history
|
|
|
# Inference API endpoint for the Falcon tab. It targets tiiuae/falcon-7b-instruct,
# the same base model id used in the __main__ section.
Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"


def falcon_inference(input_text, timeout=60):
    """Send *input_text* to the Falcon Inference API endpoint and return the parsed JSON.

    The Hugging Face token is read from the HF_TOKEN environment variable. If it
    is unset, the request is sent without an Authorization header.

    Raises:
        requests.HTTPError: if the API answers with a 4xx/5xx status.
        requests.Timeout: if no response arrives within *timeout* seconds.
    """
    token = os.getenv("HF_TOKEN")
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    response = requests.post(
        Falcon_API_URL,
        headers=headers,
        json={"inputs": input_text},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()


class FalconBot(ChatbotInterface):
    """Chat panel that answers via the hosted Falcon-7B-instruct Inference API."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        """Query Falcon with *msg*, record the turn, and return the updated history.

        Returns:
            ``("", self.chat_history)``, which clears the textbox and gives the
            Chatbot the list of ``[user, bot]`` pairs it displays.

        Raises:
            gr.Error: if the API returns an ``{"error": ...}`` payload.
        """
        falcon_response = falcon_inference(msg)
        if isinstance(falcon_response, dict) and "error" in falcon_response:
            raise gr.Error(f"Falcon API error: {falcon_response['error']}")
        falcon_output = falcon_response[0]["generated_text"]
        # The text-generation API echoes the prompt by default
        # (return_full_text=True), so drop it from the displayed reply.
        if falcon_output.startswith(msg):
            falcon_output = falcon_output[len(msg):].lstrip()
        self.chat_history.append([msg, falcon_output])
        return "", self.chat_history
|
|
|
class MistralBot(ChatbotInterface):
    """Chat panel that answers via the hosted Mistral-7B Inference API."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, chatbot):
        """Query Mistral with *msg*, record the turn, and return the updated history.

        Returns:
            ``("", self.chat_history)``, which clears the textbox and gives the
            Chatbot the list of ``[user, bot]`` pairs it displays.

        Raises:
            gr.Error: if the API returns an ``{"error": ...}`` payload.
        """
        mistral_response = mistral_inference(msg)
        if isinstance(mistral_response, dict) and "error" in mistral_response:
            raise gr.Error(f"Mistral API error: {mistral_response['error']}")
        mistral_output = mistral_response[0]["generated_text"]
        # The text-generation API echoes the prompt by default
        # (return_full_text=True), so drop it from the displayed reply.
        if mistral_output.startswith(msg):
            mistral_output = mistral_output[len(msg):].lstrip()
        self.chat_history.append([msg, mistral_output])
        return "", self.chat_history
|
|
|
if __name__ == "__main__":
    # GaiaMinimed tab: Falcon-7B-instruct tokenizer/config plus the GaiaMiniMed
    # weights/adapter. `tokenizer` and `peft_model` must remain module-level
    # names because GaiaMinimed.respond reads them as globals.
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
    model_config = AutoConfig.from_pretrained(base_model_id)
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)

    def _mirror(text):
        """Return *text* twice, to copy it into the two other bots' textboxes."""
        return text, text

    with gr.Blocks() as demo:
        with gr.Row():
            gr.Markdown(
                """
# MedChat: Your Medical Assistant Chatbot

Welcome to MedChat, your friendly medical assistant chatbot! 🩺

Dive into a world of medical expertise where you can interact with three specialized chatbots, all trained on the latest and most comprehensive medical dataset. Whether you have health-related questions, need medical advice, or just want to learn more about your well-being, MedChat is here to help!

## How it Works
Simply type your medical query or concern, and let MedChat's advanced algorithms provide you with accurate and reliable responses.

## Explore and Compare
Feel like experimenting? Click the **Submit to All** button and witness the magic as all three chatbots compete to provide you with the best possible answer! It's a unique opportunity to compare the insights from different models and choose the one that suits your needs the best.

_Ready to get started? Type your question and let's begin!_
"""
            )

        with gr.Row():
            with gr.Column():
                with gr.Tab("GaiaMinimed"):
                    gaia_bot = GaiaMinimed("GaiaMinimed")
            with gr.Column():
                with gr.Tab("MistralMed"):
                    mistral_bot = MistralBot("MistralMed")
                with gr.Tab("Falcon-7B"):
                    falcon_bot = FalconBot("Falcon-7B")

        # The intro text promises a "Submit to All" button; wire it to every
        # bot's respond handler. The textboxes are kept in sync below, so each
        # bot receives the same message.
        submit_all = gr.Button("Submit to All")
        for bot in (gaia_bot, mistral_bot, falcon_bot):
            submit_all.click(bot.respond, [bot.msg, bot.chatbot], [bot.msg, bot.chatbot])

        # Keep all three message boxes in sync while the user types.
        gaia_bot.msg.input(fn=_mirror, inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=_mirror, inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=_mirror, inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])

    demo.launch()