Craig Pretzinger committed on
Commit
1ee7467
1 Parent(s): b1f5115

Updated files for enhanced PubMedBERT and GPT-4o-mini integration

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +57 -146
.gitignore CHANGED
@@ -1 +1,2 @@
1
  venv/
 
 
1
  venv/
2
+ .env
app.py CHANGED
@@ -1,177 +1,88 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
- from transformers import BertTokenizer, BertForSequenceClassification
4
  import openai
5
  import os
 
 
 
 
6
  import faiss
7
  import numpy as np
8
- import requests
9
- from datasets import load_dataset
10
 
11
- # Load OpenAI API key and organization ID from environment variables
 
 
 
12
  openai.api_key = os.getenv("OPENAI_API_KEY")
13
- openai.Organization = os.getenv("OPENAI_ORG_ID")
 
14
 
15
  # Load PubMedBERT tokenizer and model
16
  tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
17
  model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
18
 
19
- # FAISS setup for vector search (embedding-based memory)
20
- dimension = 768 # PubMedBERT embedding size
21
  index = faiss.IndexFlatL2(dimension)
22
 
23
- # Embed text using PubMedBERT
24
  def embed_text(text):
25
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
26
  outputs = model(**inputs, output_hidden_states=True)
27
  hidden_state = outputs.hidden_states[-1]
28
  return hidden_state.mean(dim=1).detach().numpy()
29
 
30
- # Add past conversation embedding to FAISS index
31
- past_conversation = "FDA approval for companion diagnostics requires careful documentation."
32
- past_embedding = embed_text(past_conversation)
33
- past_embedding = np.array(past_embedding) # Convert to numpy array
34
-
35
- # Reshape if necessary (e.g., (1, 768) for PubMedBERT)
36
- past_embedding = past_embedding.reshape(1, -1)
37
-
38
- index.add(past_embedding)
39
-
40
- # Search past conversations/memory using FAISS
41
- def search_memory(query):
42
- query_embedding = embed_text(query)
43
- D, I = index.search(query_embedding, k=1)
44
- return I
45
-
46
- # Handle FDA-specific queries with PubMedBERT
47
  def handle_fda_query(query):
48
- inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True)
49
  outputs = model(**inputs)
50
  logits = outputs.logits
51
- response = "Processed FDA-related query via PubMedBERT"
52
- return response
53
-
54
- # Handle general queries using GPT-4O
55
- def handle_openai_query(prompt):
56
- response = openai.Chat.create(
57
- model="gpt-4-0314-16k-512",
58
- messages=[
59
- {"role": "user", "content": prompt}
60
- ],
61
- temperature=0.7,
62
- max_tokens=100
63
- )
64
- return response.choices[0].message.content
65
-
66
- # Web search with Serper API
67
- def web_search(query):
68
- url = f"https://google.serper.dev/search"
69
- headers = {
70
- "X-API-KEY": os.getenv("SERPER_API_KEY")
71
- }
72
- params = {
73
- "q": query
74
- }
75
- response = requests.get(url, headers=headers, params=params)
76
- return response.json()
77
-
78
- # Contextual Short-Term Memory (CSTM)
79
- cstm = []
80
-
81
- # Long-Term Memory (LTM)
82
- ltm = [] # Load knowledge base articles or FAQs
83
-
84
- # Semantic search function
85
- def semantic_search(query, cstm, ltm):
86
- # Generate embeddings for query and CSTM/LTM
87
- query_embedding = embed_text(query)
88
- cstm_embeddings = [embed_text(text) for text in cstm]
89
- ltm_embeddings = [embed_text(text) for text in ltm]
90
-
91
- # Calculate similarity scores
92
- cstm_scores = calculate_similarity(query_embedding, cstm_embeddings)
93
- ltm_scores = calculate_similarity(query_embedding, ltm_embeddings)
94
-
95
- # Retrieve top relevant results from CSTM and LTM
96
- top_cstm = np.argmax(cstm_scores)
97
- top_ltm = np.argmax(ltm_scores)
98
-
99
- return top_cstm, top_ltm
100
-
101
- # Calculate similarity between embeddings
102
- def calculate_similarity(query_embedding, embeddings):
103
- similarity_scores = []
104
- for embedding in embeddings:
105
- score = cosine_similarity(query_embedding, embedding)
106
- similarity_scores.append(score)
107
- return similarity_scores
108
-
109
- # Cosine similarity function
110
- def cosine_similarity(a, b):
111
- dot_product = np.dot(a, b)
112
- magnitude_a = np.linalg.norm(a)
113
- magnitude_b = np.linalg.norm(b)
114
- return dot_product / (magnitude_a * magnitude_b)
115
-
116
- # Main assistant function
117
- def respond(
118
- message,
119
- history: list[tuple[str, str]],
120
- system_message,
121
- max_tokens,
122
- temperature,
123
- top_p,
124
- ):
125
- # Prepare context for OpenAI and PubMedBERT
126
- messages = [{"role": "system", "content": system_message}]
127
-
128
- for val in history:
129
- if val[0]:
130
- messages.append({"role": "user", "content": val[0]})
131
- if val[1]:
132
- messages.append({"role": "assistant", "content": val[1]})
133
-
134
- messages.append({"role": "user", "content": message})
135
-
136
- # Check if query is FDA-related
137
- openai_response = handle_openai_query(f"Is this query FDA-related: {message}")
138
-
139
- if "FDA" in openai_response or "regulatory" in openai_response:
140
- # Search past conversations/memory using FAISS
141
- memory_index = search_memory(message)
142
- if memory_index:
143
- return f"Found relevant past memory: {past_conversation}"
144
-
145
- # If no memory match, proceed with PubMedBERT
146
- return handle_fda_query(message)
147
-
148
- # If query asks for web search, perform web search
149
- if "search the web" in message.lower():
150
- return web_search(message)
151
-
152
- # Perform semantic search on CSTM and LTM
153
- top_cstm, top_ltm = semantic_search(message, cstm, ltm)
154
- if top_cstm:
155
- return f"Found relevant context: {cstm[top_cstm]}"
156
- elif top_ltm:
157
- return f"Found relevant knowledge: {ltm[top_ltm]}"
158
-
159
- # General conversational handling with GPT-4O
160
- response = handle_openai_query(message)
161
- return response
162
-
163
-
164
- # Create Gradio ChatInterface for interaction
165
- demo = gr.ChatInterface(
166
- respond,
167
- additional_inputs=[
168
- gr.Textbox(value="You are Ferris2.0, an FDA expert.", label="System message"),
169
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
170
  gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
171
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
172
  ],
 
173
  )
174
 
175
-
176
  if __name__ == "__main__":
177
  demo.launch()
 
1
  import gradio as gr
 
 
2
  import openai
3
  import os
4
+ from dotenv import load_dotenv
5
+ import requests
6
+ from transformers import BertTokenizer, BertForSequenceClassification
7
+ import torch
8
  import faiss
9
  import numpy as np
 
 
10
 
11
# Pull secrets from the local .env file so keys never live in the repo.
load_dotenv()

# Credentials are read from the environment.
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.organization = os.getenv("OPENAI_ORG_ID")
serper_api_key = os.getenv("SERPER_API_KEY")  # reserved for Serper web search

# PubMedBERT tokenizer plus a 2-label sequence-classification head.
tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
model = BertForSequenceClassification.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2
)

# FAISS flat L2 index sized to PubMedBERT's hidden dimension.
dimension = 768
index = faiss.IndexFlatL2(dimension)
27
# Function to embed text (PubMedBERT)
def embed_text(text):
    """Embed *text* with PubMedBERT.

    Returns a numpy array of shape (1, 768): the last hidden layer
    averaged over all token positions.

    NOTE(review): because of ``padding="max_length"`` the mean also
    includes [PAD] positions — masking with ``inputs["attention_mask"]``
    would give a cleaner sentence embedding; confirm intent.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    # Inference only: no_grad skips autograd bookkeeping (saves memory/compute).
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    hidden_state = outputs.hidden_states[-1]
    return hidden_state.mean(dim=1).detach().numpy()
33
 
34
# Function to retrieve info from PubMedBERT
def handle_fda_query(query):
    """Classify *query* with the PubMedBERT head and return a status string.

    Label 1 is reported as regulatory-heavy, label 0 as general.

    NOTE(review): the 2-label classification head is not fine-tuned by
    this file — predictions from a freshly loaded head are effectively
    arbitrary; confirm a trained checkpoint is used in production.
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=1).item()

    # Simulate a meaningful FDA-related response
    if prediction == 1:
        return f"FDA Query Processed: '{query}' contains important regulatory information."
    else:
        return f"FDA Query Processed: '{query}' seems to be general and not regulatory-heavy."
46
+
47
# Function to enhance info via GPT-4o-mini
def enhance_with_gpt4o(fda_response):
    """Have GPT-4o-mini expand on the PubMedBERT result.

    Best-effort: any failure (API error, malformed response) is folded
    into an "Error: ..." string so the caller always gets text back.
    """
    chat_messages = [
        {"role": "system", "content": "You are an expert FDA assistant."},
        {"role": "user", "content": f"Enhance this FDA info: {fda_response}"},
    ]
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",  # Correct model
            messages=chat_messages,
            max_tokens=150,
        )
        return completion['choices'][0]['message']['content']
    except Exception as e:
        return f"Error: {str(e)}"
58
+
59
# Main function that gets PubMedBERT output and enhances it using GPT-4o-mini
def respond(message, system_message, max_tokens, temperature, top_p):
    """Gradio entry point: PubMedBERT classification, then GPT-4o-mini polish.

    Returns one formatted string containing both the raw PubMedBERT
    result and the GPT-enhanced version; any failure is returned as an
    "Error: ..." string instead of raising.

    NOTE(review): system_message, max_tokens, temperature and top_p
    mirror the Gradio inputs but are currently unused — confirm whether
    they should be forwarded to the OpenAI call.
    """
    try:
        # Step 1: local PubMedBERT pass over the query.
        bert_info = handle_fda_query(message)
        # Step 2: remote GPT-4o-mini enhancement of that result.
        gpt_info = enhance_with_gpt4o(bert_info)
        return (
            f"Original Info from PubMedBERT: {bert_info}"
            f"\n\nEnhanced Info via GPT-4o-mini: {gpt_info}"
        )
    except Exception as exc:
        return f"Error: {str(exc)}"
73
+
74
# Gradio front-end: query box + generation controls in, plain text out.
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Enter your FDA query", placeholder="Ask Ferris2.0 anything FDA-related."),
        gr.Textbox(value="You are Ferris2.0, the most advanced FDA Regulatory Assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    demo.launch()