# Ferris2dotOh / app.py
# Last commit 4d2f914 by Craig Pretzinger: "Committed local changes to app.py"
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
import openai
import os
import faiss
import numpy as np
import requests
from datasets import load_dataset
# NOTE(review): `ds` is not referenced anywhere else in this file; loading it
# only adds start-up time. Kept in case external code relies on the name.
ds = load_dataset("epfl-llm/guidelines")

# Load OpenAI and Serper API keys from Hugging Face secrets (environment
# variables). Either value is None when the corresponding secret is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")
serper_api_key = os.getenv("SERPER_API_KEY")

# Load the PubMedBERT tokenizer and model. The same model object is used for
# embeddings (via its hidden states) and for the 2-label classification head.
# NOTE(review): the checkpoint is a base encoder, so the num_labels=2 head is
# presumably freshly initialised (untrained) - confirm before relying on logits.
tokenizer = BertTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
model = BertForSequenceClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2
)

# FAISS setup for vector search (embedding-based memory). The index dimension
# must equal the model's hidden size (768 for PubMedBERT-base), so read it from
# the model config rather than hard-coding a value that could drift.
dimension = model.config.hidden_size
index = faiss.IndexFlatL2(dimension)
def embed_text(text):
    """Return a PubMedBERT embedding of ``text`` for the FAISS memory index.

    The embedding is the mean of the model's last hidden layer taken over the
    real (non-padding) tokens only; input is truncated to 512 tokens.

    Args:
        text: A string (or a list of strings, one embedding row each).

    Returns:
        A NumPy array of shape ``(batch, hidden_size)`` - ``(1, hidden_size)``
        for a single string - ready for ``index.add`` / ``index.search``.
    """
    import torch  # local import: only needed to disable autograd below

    # Pad only to the longest input in the batch (a single string needs no
    # padding) instead of always padding to 512 tokens.
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    )
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(**inputs, output_hidden_states=True)
    hidden_state = outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)

    # Masked mean: average only the positions the attention mask marks as real
    # tokens, so padding never dilutes the embedding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden_state.dtype)
    summed = (hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return (summed / counts).numpy()
# Seed the memory index with one example "past conversation" so that
# search_memory() has a stored vector to retrieve.
past_conversation = (
    "FDA approval for companion diagnostics requires careful documentation."
)
past_embedding = embed_text(past_conversation)
index.add(past_embedding)
# Embed the incoming query and search for related memory
def search_memory(query, k=1):
    """Return the FAISS ids of the stored memories nearest to ``query``.

    Args:
        query: Text to look up.
        k: Number of neighbours to retrieve (default 1, the original behaviour).

    Returns:
        The id array from ``index.search``, shape ``(1, k)``, nearest first.
        FAISS fills a slot with ``-1`` when the index holds fewer than ``k``
        vectors. Note that id ``0`` is a valid hit, so callers must compare
        against ``-1`` rather than test the array's truthiness.
    """
    query_embedding = embed_text(query)
    _distances, ids = index.search(query_embedding, k=k)
    return ids
# Function to handle FDA-specific queries with PubMedBERT
def handle_fda_query(query):
    """Run ``query`` through the PubMedBERT sequence classifier.

    Args:
        query: The user's FDA-related question.

    Returns:
        A fixed status string; the classifier output does not (yet) influence
        the reply.
    """
    import torch  # local import: only needed to disable autograd below

    # Pad only as far as needed rather than always to the 512-token maximum.
    inputs = tokenizer(
        query, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    with torch.no_grad():  # inference only: no autograd graph needed
        # TODO(review): the returned logits (num_labels=2 per input) are
        # currently discarded. The head is presumably untrained on this base
        # checkpoint; fine-tune it before using these scores.
        model(**inputs)
    return "Processed FDA-related query via PubMedBERT"
# Function to handle general queries using GPT-4o
def handle_openai_query(prompt, max_tokens=100):
    """Send a single-turn ``prompt`` to GPT-4o and return the reply text.

    GPT-4o is a chat model, so it is called through the Chat Completions API;
    the legacy text-completions endpoint (``openai.Completion`` with
    ``engine="gpt-4o"``) does not serve it, and it no longer exists at all in
    openai>=1.0.

    Args:
        prompt: User message to send.
        max_tokens: Upper bound on reply length (default 100, as before).

    Returns:
        The assistant's reply with surrounding whitespace stripped, or ``""``
        if the API returned no text.
    """
    messages = [{"role": "user", "content": prompt}]
    chat_api = getattr(getattr(openai, "chat", None), "completions", None)
    if chat_api is not None:
        # openai>=1.0: module-level client configured from openai.api_key.
        response = chat_api.create(
            model="gpt-4o", messages=messages, max_tokens=max_tokens
        )
    else:
        # openai<1.0: legacy resource class.
        response = openai.ChatCompletion.create(
            model="gpt-4o", messages=messages, max_tokens=max_tokens
        )
    return (response.choices[0].message.content or "").strip()
# Web search with Serper API
def web_search(query, timeout=10):
    """Search Google through the Serper API and return the decoded JSON reply.

    Args:
        query: Search string.
        timeout: Seconds to wait for Serper before giving up. Without a
            timeout, a stalled connection would block the worker indefinitely.

    Returns:
        The parsed JSON response body.

    Raises:
        requests.HTTPError: Serper answered with a non-2xx status (for
            example a missing or invalid API key).
        requests.RequestException: network failure or timeout.
    """
    url = "https://google.serper.dev/search"
    headers = {"X-API-KEY": serper_api_key}
    params = {"q": query}
    response = requests.get(url, headers=headers, params=params, timeout=timeout)
    # Surface HTTP errors instead of returning the error body as a "result".
    response.raise_for_status()
    return response.json()
def _history_to_messages(history):
    """Convert Gradio chat history (tuples or messages format) to OpenAI messages."""
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: role/content dicts; skip non-text content.
            content = turn.get("content")
            if isinstance(content, str) and content:
                messages.append({"role": turn["role"], "content": content})
            continue
        # "tuples" format: (user_text, assistant_text); either may be empty.
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})
    return messages


def _stream_gpt4o(messages, max_tokens, temperature, top_p):
    """Yield GPT-4o reply text fragments for ``messages`` as they stream in."""
    chat_api = getattr(getattr(openai, "chat", None), "completions", None)
    if chat_api is not None:
        create = chat_api.create  # openai>=1.0
    else:
        create = openai.ChatCompletion.create  # openai<1.0 legacy interface
    stream = create(
        model="gpt-4o",
        messages=messages,
        max_tokens=int(max_tokens),  # slider values may arrive as floats
        temperature=temperature,
        top_p=top_p,
        stream=True,
    )
    for chunk in stream:
        if not chunk.choices:
            continue
        # The closing chunk carries no text (content is None or absent).
        token = getattr(chunk.choices[0].delta, "content", None)
        if token:
            yield token


def _format_web_results(results, limit=5):
    """Render a Serper JSON reply as readable chat text."""
    import json

    # NOTE(review): assumes Serper's usual schema, where "organic" is a list of
    # dicts with "title"/"link"/"snippet"; anything else falls back to JSON.
    organic = results.get("organic") if isinstance(results, dict) else None
    if not organic:
        return json.dumps(results, indent=2, ensure_ascii=False)
    lines = []
    for item in organic[:limit]:
        title = item.get("title", "")
        link = item.get("link", "")
        snippet = item.get("snippet", "")
        lines.append(f"- {title} ({link})\n  {snippet}".rstrip())
    return "\n".join(lines)


# Main assistant function that delegates to either OpenAI, PubMedBERT, or Serper (web search)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Gradio ChatInterface callback: route ``message`` and stream a reply.

    Routing, in order:
      1. Messages containing "search the web" get formatted Serper results.
      2. Messages GPT-4o classifies as FDA/regulatory get the stored FAISS
         memory when one is found, otherwise the PubMedBERT handler's reply.
      3. Everything else is answered by GPT-4o, streamed token by token.

    This is a generator, so every branch must ``yield`` its text: a
    ``return value`` inside a generator is silently discarded by Gradio.

    Args:
        message: The new user message.
        history: Prior turns as ``(user, assistant)`` pairs (Gradio "tuples"
            format) or as ``{"role", "content"}`` dicts ("messages" format).
        system_message: System prompt for GPT-4o.
        max_tokens: Maximum reply length for GPT-4o.
        temperature: Sampling temperature for GPT-4o.
        top_p: Nucleus-sampling cutoff for GPT-4o.

    Yields:
        Reply text; for the streamed branch, the accumulated text so far.
    """
    # 1. An explicit web-search request wins and needs no classifier call.
    if "search the web" in message.lower():
        try:
            results = web_search(message)
        except requests.RequestException as err:
            yield f"Web search failed: {err}"
            return
        yield _format_web_results(results)
        return

    # 2. Ask GPT-4o for a bare yes/no verdict. Searching a free-text reply for
    # "FDA" also matched negative answers ("No, this is not FDA-related").
    verdict = handle_openai_query(
        "Is the following query related to the FDA or regulatory affairs? "
        f"Answer only 'yes' or 'no'.\n\nQuery: {message}"
    )
    if verdict.strip().lower().startswith("yes"):
        memory_ids = search_memory(message)
        # FAISS reports "no neighbour" as -1, and id 0 is a real hit, so
        # compare against -1 instead of testing the array's truthiness.
        # NOTE(review): with no distance threshold, any non-empty index always
        # "matches"; only one memory is stored, hence past_conversation.
        if memory_ids.size and memory_ids[0][0] != -1:
            yield f"Found relevant past memory: {past_conversation}"
            return
        yield handle_fda_query(message)
        return

    # 3. General conversational handling with GPT-4o, streamed.
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    response = ""
    for token in _stream_gpt4o(messages, max_tokens, temperature, top_p):
        response += token
        yield response
# Create the Gradio ChatInterface. The additional inputs are passed to
# respond() after (message, history), in this order: system message,
# max tokens, temperature, top-p.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are Ferris2.0, an FDA expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        # OpenAI chat models (GPT-4o) accept temperature only in [0, 2];
        # larger values are rejected by the API, so cap the slider at 2.0.
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch(share=True)