import gradio as gr
from huggingface_hub import InferenceClient
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
import openai
import os
import faiss
import numpy as np
import requests
import torch
from datasets import load_dataset

# Clinical practice guidelines corpus (loaded here, not yet wired into the retrieval logic below)
ds = load_dataset("epfl-llm/guidelines")
# Load OpenAI and Serper API keys from Hugging Face secrets
openai.api_key = os.getenv("OPENAI_API_KEY")
serper_api_key = os.getenv("SERPER_API_KEY")

# Hugging Face Inference client for streamed chat completions. The original
# script imports InferenceClient but never instantiates it; the model below is
# an assumed placeholder and can be swapped for any hosted chat model.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

# Load PubMedBERT tokenizer and model for FDA-related processing
tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
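# The num_labels=2 classification head is randomly initialized until fine-tuned.
# A minimal fine-tuning sketch with the Trainer API imported above; the dataset
# preparation (a tokenized split with a "labels" column) is an assumption, not
# part of this script:
#
# training_args = TrainingArguments(output_dir="./pubmedbert-fda", num_train_epochs=1,
#                                   per_device_train_batch_size=8)
# trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_train)
# trainer.train()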
# FAISS setup for vector search (embedding-based memory)
dimension = 768  # PubMedBERT hidden size
index = faiss.IndexFlatL2(dimension)
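# IndexFlatL2 performs exact nearest-neighbor search over squared L2 distance
# and expects float32 arrays of shape (n, dimension). A quick sanity check,
# left commented out so it does not pollute the live index:
#
# probe = np.zeros((1, dimension), dtype="float32")
# index.add(probe)
# D, I = index.search(probe, 1)  # D[0][0] == 0.0, I[0][0] == 0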
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs, output_hidden_states=True)  # ensure hidden states are returned
    hidden_state = outputs.hidden_states[-1]  # last hidden layer, shape (1, seq_len, 768)
    return hidden_state.mean(dim=1).numpy()  # mean-pool across the sequence
# Example: embed a past conversation and store it in FAISS
past_conversation = "FDA approval for companion diagnostics requires careful documentation."
past_embedding = embed_text(past_conversation)
index.add(past_embedding)
# Embed the incoming query and search for the most similar past conversation
def search_memory(query, k=1):
    query_embedding = embed_text(query)
    distances, indices = index.search(query_embedding, k)
    if indices[0][0] == -1:  # FAISS returns -1 when the index holds no match
        return None
    return indices[0][0]
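# Once the index is non-empty, a raw nearest-neighbor lookup always "finds"
# something, however dissimilar. A thresholded variant is the usual remedy;
# the 50.0 cutoff below is an illustrative assumption, not a tuned value:
def search_memory_thresholded(query, max_distance=50.0):
    distances, indices = index.search(embed_text(query), 1)
    if indices[0][0] == -1 or distances[0][0] > max_distance:
        return None  # nothing stored, or the nearest memory is too far away
    return indices[0][0]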
# Function to handle FDA-specific queries with PubMedBERT
def handle_fda_query(query):
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class = outputs.logits.argmax(dim=-1).item()
    # The logits come from an untrained head (see the fine-tuning note above),
    # so the predicted class is a placeholder until the model is fine-tuned
    return f"Processed FDA-related query via PubMedBERT (predicted class: {predicted_class})"
# Function to handle general queries using GPT-4o. GPT-4o is a chat model, so
# it must go through the chat completions endpoint rather than the legacy
# Completion API with an `engine` argument.
def handle_openai_query(prompt):
    response = openai.ChatCompletion.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
    )
    return response.choices[0].message["content"].strip()
# Web search with the Serper API, which expects a POST request with a JSON body
def web_search(query):
    url = "https://google.serper.dev/search"
    headers = {
        "X-API-KEY": serper_api_key,
        "Content-Type": "application/json",
    }
    payload = {"q": query}
    response = requests.post(url, headers=headers, json=payload)
    return response.json()
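# Raw JSON is awkward in a chat window, so `respond` below runs results through
# this small formatter. It assumes Serper's usual response shape (an "organic"
# list of dicts with "title", "link", and "snippet" fields); .get() keeps it
# safe if the schema differs.
def format_search_results(results, max_results=3):
    lines = []
    for item in results.get("organic", [])[:max_results]:
        lines.append(f"{item.get('title', '')} ({item.get('link', '')})\n{item.get('snippet', '')}")
    return "\n\n".join(lines) if lines else "No web results found."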
# Main assistant function that delegates to OpenAI, PubMedBERT, or Serper (web search)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Prepare the conversation context for the chat model
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Check whether the query is FDA-related. Note that `respond` is a
    # generator (it streams below), so results must be yielded, not returned;
    # a value returned from a generator is silently discarded.
    openai_response = handle_openai_query(f"Is this query FDA-related: {message}")
    if "FDA" in openai_response or "regulatory" in openai_response:
        # Search past conversations/memory using FAISS
        memory_index = search_memory(message)
        if memory_index is not None:
            yield f"Found relevant past memory: {past_conversation}"
            return
        # No memory match, so fall back to PubMedBERT
        yield handle_fda_query(message)
        return

    # If the query asks for a web search, perform one and format the results
    if "search the web" in message.lower():
        yield format_search_results(web_search(message))
        return
    # General conversational handling, streamed token by token
    response = ""
    for chunk in client.chat_completion(  # loop variable renamed to avoid shadowing the `message` parameter
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final streamed chunk's delta content can be None
            response += token
            yield response
# Create the Gradio ChatInterface for interaction
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are Ferris2.0, an FDA expert.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch(share=True)