alvinhenrick committed • Commit 63f14b9 • Parent(s): b101c5c

Add streaming support

Files changed:
- app.py +40 -11
- medirag/rag/wf.py +11 -11
app.py CHANGED
@@ -1,12 +1,14 @@
 from pathlib import Path
-
-import dspy
 import gradio as gr
 from dotenv import load_dotenv
-
 from medirag.cache.local import SemanticCaching
 from medirag.index.local import DailyMedIndexer
 from medirag.rag.qa import RAG, DailyMedRetrieve
+from medirag.rag.wf import RAGWorkflow
+from llama_index.llms.openai import OpenAI
+from llama_index.core import Settings
+
+import dspy
 
 load_dotenv()
 
@@ -19,19 +21,43 @@ rm = DailyMedRetrieve(daily_med_indexer=indexer)
 
 turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=4000)
 dspy.settings.configure(lm=turbo, rm=rm)
+# Set the LLM model
+Settings.llm = OpenAI(model='gpt-3.5-turbo')
 
-rag = RAG(k=5)
 sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
                      json_file='rag_test_cache.json', cosine_threshold=.90)
 sm.load_cache()
 
+# Initialize RAGWorkflow with indexer
+rag = RAG(k=5)
+rag_workflow = RAGWorkflow(indexer=indexer, timeout=60, top_k=10, top_n=5)
+
 
-def ask_med_question(query):
+async def ask_med_question(query, enable_stream):
+    # Check the cache first
     response = sm.lookup(question=query)
-    if ...
-        response ...
-    ...
-    ...
+    if response:
+        # Return cached response if found
+        yield response
+    else:
+        if enable_stream:
+            # Stream response using RAGWorkflow
+            result = await rag_workflow.run(query=query)
+            accumulated_response = ""
+
+            async for chunk in result.async_response_gen():
+                accumulated_response += chunk
+                yield accumulated_response  # Accumulate and yield the updated response
+
+            # Save the accumulated response to the cache after streaming is complete
+            sm.save(query, accumulated_response)
+        else:
+            # Use RAG without streaming
+            response = rag(query).answer
+            yield response
+
+            # Save the response in the cache
+            sm.save(query, response)
 
 
 css = """
@@ -41,8 +67,8 @@ h1 {
 }
 #md {margin-top: 70px}
 """
-# Set up the Gradio interface
 
+# Set up the Gradio interface with a checkbox for enabling streaming
 with gr.Blocks(css=css) as app:
     gr.Markdown("# DailyMed RAG")
     with gr.Row():
@@ -54,9 +80,12 @@ with gr.Blocks(css=css) as app:
        gr.Markdown("### Ask any question about medication usage and get answers based on DailyMed data.",
                    elem_id="md")
 
+    enable_stream = gr.Checkbox(label="Enable Streaming", value=False)
     input_text = gr.Textbox(lines=2, label="Question", placeholder="Enter your question about a drug...")
     output_text = gr.Textbox(interactive=False, label="Response", lines=10)
     button = gr.Button("Submit")
-    ...
+
+    # Update the button click function to include the checkbox value
+    button.click(fn=ask_med_question, inputs=[input_text, enable_stream], outputs=output_text)
 
 app.launch()
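Note on the new handler: Gradio treats a generator (or async generator) event handler as a streaming update, so each value yielded by ask_med_question replaces the current contents of the Response textbox. That is why the streaming branch yields the accumulated text rather than individual chunks. Below is a minimal, self-contained sketch of the same pattern; fake_token_stream is a stand-in for rag_workflow.run()/async_response_gen() and, like the other names in this snippet, is illustrative only, not part of this repo.

```python
import asyncio

import gradio as gr


async def fake_token_stream(query):
    # Stand-in for result.async_response_gen(): yields an answer in chunks.
    for token in f"Echoing: {query}".split():
        await asyncio.sleep(0.05)
        yield token + " "


async def answer(query):
    accumulated = ""
    async for chunk in fake_token_stream(query):
        accumulated += chunk
        # Each yield replaces the Textbox contents, so yielding the growing
        # string produces a progressively "typed out" response.
        yield accumulated


with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    response = gr.Textbox(label="Response")
    gr.Button("Submit").click(fn=answer, inputs=question, outputs=response)

if __name__ == "__main__":
    demo.launch()
```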
medirag/rag/wf.py CHANGED
@@ -1,22 +1,16 @@
 import asyncio
 from pathlib import Path
-
-from llama_index.core import PromptTemplate
-from llama_index.core.response_synthesizers import CompactAndRefine, TreeSummarize
+
+from llama_index.core import PromptTemplate
 from llama_index.core.postprocessor.llm_rerank import LLMRerank
+from llama_index.core.response_synthesizers import CompactAndRefine, TreeSummarize
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.workflow import Context, Workflow, StartEvent, StopEvent, step
-from llama_index.llms.openai import OpenAI
 from llama_index.core.workflow import Event
-from llama_index.core.schema import NodeWithScore
 from pydantic import BaseModel
 
 from medirag.index.local import DailyMedIndexer
 
-load_dotenv()
-
-# Set the LLM model
-Settings.llm = OpenAI(model='gpt-3.5-turbo')
-
 
 # Event classes
 class RetrieverEvent(Event):
@@ -88,7 +82,6 @@ class RAGWorkflow(Workflow):
     @step
     async def retrieve(self, ctx: Context, ev: QueryEvent) -> RetrieverEvent | None:
         query = ctx.data["query"]
-
         print(f"Query the database with: {query}")
 
         if not self.indexer:
@@ -115,6 +108,13 @@ class RAGWorkflow(Workflow):
 
 # Main function
 async def main():
+    from llama_index.llms.openai import OpenAI
+    from llama_index.core import Settings
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    Settings.llm = OpenAI(model='gpt-3.5-turbo')
+
     data_dir = Path("../../data")
     index_path = data_dir.joinpath("dm_spl_release_human_rx_part1")
 
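With load_dotenv() and Settings.llm moved out of module scope, wf.py no longer configures an LLM at import time; app.py does it once at startup and the standalone main() does it for itself. For readers unfamiliar with the llama_index workflow API that RAGWorkflow builds on, here is a minimal, self-contained sketch in the same style (custom Event classes plus @step methods chained from StartEvent to StopEvent). ToyWorkflow, RetrievedEvent, and their logic are illustrative only and not part of medirag.

```python
import asyncio

from llama_index.core.workflow import (
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
)


class RetrievedEvent(Event):
    # Custom events are Pydantic models; fields carry data between steps.
    text: str


class ToyWorkflow(Workflow):
    @step
    async def retrieve(self, ev: StartEvent) -> RetrievedEvent:
        # RAGWorkflow queries the DailyMedIndexer at this point; we fake it.
        query = ev.get("query")
        return RetrievedEvent(text=f"documents matching '{query}'")

    @step
    async def synthesize(self, ev: RetrievedEvent) -> StopEvent:
        # RAGWorkflow hands retrieved nodes to a response synthesizer here.
        return StopEvent(result=f"answer based on {ev.text}")


async def main():
    wf = ToyWorkflow(timeout=30)
    print(await wf.run(query="What is aspirin used for?"))


if __name__ == "__main__":
    asyncio.run(main())
```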