alvinhenrick committed on
Commit 63f14b9
1 Parent(s): b101c5c

Add streaming support

Files changed (2):
  1. app.py +40 -11
  2. medirag/rag/wf.py +11 -11
app.py CHANGED

@@ -1,12 +1,14 @@
 from pathlib import Path
-
-import dspy
 import gradio as gr
 from dotenv import load_dotenv
-
 from medirag.cache.local import SemanticCaching
 from medirag.index.local import DailyMedIndexer
 from medirag.rag.qa import RAG, DailyMedRetrieve
+from medirag.rag.wf import RAGWorkflow
+from llama_index.llms.openai import OpenAI
+from llama_index.core import Settings
+
+import dspy
 
 load_dotenv()
 
@@ -19,19 +21,43 @@ rm = DailyMedRetrieve(daily_med_indexer=indexer)
 
 turbo = dspy.OpenAI(model='gpt-3.5-turbo', max_tokens=4000)
 dspy.settings.configure(lm=turbo, rm=rm)
+# Set the LLM model
+Settings.llm = OpenAI(model='gpt-3.5-turbo')
 
-rag = RAG(k=5)
 sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
                      json_file='rag_test_cache.json', cosine_threshold=.90)
 sm.load_cache()
 
+# Initialize RAGWorkflow with indexer
+rag = RAG(k=5)
+rag_workflow = RAGWorkflow(indexer=indexer, timeout=60, top_k=10, top_n=5)
+
 
-def ask_med_question(query):
+async def ask_med_question(query, enable_stream):
+    # Check the cache first
     response = sm.lookup(question=query)
-    if not response:
-        response = rag(query).answer
-        sm.save(query, response)
-    return response
+    if response:
+        # Return cached response if found
+        yield response
+    else:
+        if enable_stream:
+            # Stream response using RAGWorkflow
+            result = await rag_workflow.run(query=query)
+            accumulated_response = ""
+
+            async for chunk in result.async_response_gen():
+                accumulated_response += chunk
+                yield accumulated_response  # Accumulate and yield the updated response
+
+            # Save the accumulated response to the cache after streaming is complete
+            sm.save(query, accumulated_response)
+        else:
+            # Use RAG without streaming
+            response = rag(query).answer
+            yield response
+
+            # Save the response in the cache
+            sm.save(query, response)
 
 
 css = """
@@ -41,8 +67,8 @@ h1 {
 }
 #md {margin-top: 70px}
 """
-# Set up the Gradio interface
 
+# Set up the Gradio interface with a checkbox for enabling streaming
 with gr.Blocks(css=css) as app:
     gr.Markdown("# DailyMed RAG")
     with gr.Row():
@@ -54,9 +80,12 @@ with gr.Blocks(css=css) as app:
         gr.Markdown("### Ask any question about medication usage and get answers based on DailyMed data.",
                     elem_id="md")
 
+        enable_stream = gr.Checkbox(label="Enable Streaming", value=False)
         input_text = gr.Textbox(lines=2, label="Question", placeholder="Enter your question about a drug...")
         output_text = gr.Textbox(interactive=False, label="Response", lines=10)
         button = gr.Button("Submit")
+
+        # Update the button click function to include the checkbox value
+        button.click(fn=ask_med_question, inputs=[input_text, enable_stream], outputs=output_text)
 
 app.launch()
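For context, the streaming pattern the new handler relies on is sketched below: when a Gradio click handler is an async generator, Gradio re-renders the output Textbox with each yielded value, which is why ask_med_question yields the accumulated text rather than individual chunks (and only writes to the semantic cache once the stream is complete). This is a minimal sketch assuming a Gradio version that supports async-generator event handlers; the handler name and the chunk contents are made up for illustration, and only the wiring mirrors the app.py change.

import asyncio
import gradio as gr


async def fake_stream(query, enable_stream):
    # Stand-in for ask_med_question: yield the accumulated text so the
    # output Textbox updates in place as chunks arrive.
    chunks = ["Take", " this", " medication", " with", " food."]
    if not enable_stream:
        yield "".join(chunks)      # single, non-streamed answer
        return
    text = ""
    for chunk in chunks:
        await asyncio.sleep(0.2)   # simulate token latency
        text += chunk
        yield text


with gr.Blocks() as demo:
    question = gr.Textbox(lines=2, label="Question")
    enable = gr.Checkbox(label="Enable Streaming", value=False)
    answer = gr.Textbox(label="Response", lines=10)
    gr.Button("Submit").click(fn=fake_stream, inputs=[question, enable], outputs=answer)

if __name__ == "__main__":
    demo.launch()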
medirag/rag/wf.py CHANGED

@@ -1,22 +1,16 @@
 import asyncio
 from pathlib import Path
-from dotenv import load_dotenv
-from llama_index.core import PromptTemplate, Settings
-from llama_index.core.response_synthesizers import CompactAndRefine, TreeSummarize
+
+from llama_index.core import PromptTemplate
 from llama_index.core.postprocessor.llm_rerank import LLMRerank
+from llama_index.core.response_synthesizers import CompactAndRefine, TreeSummarize
+from llama_index.core.schema import NodeWithScore
 from llama_index.core.workflow import Context, Workflow, StartEvent, StopEvent, step
-from llama_index.llms.openai import OpenAI
 from llama_index.core.workflow import Event
-from llama_index.core.schema import NodeWithScore
 from pydantic import BaseModel
 
 from medirag.index.local import DailyMedIndexer
 
-load_dotenv()
-
-# Set the LLM model
-Settings.llm = OpenAI(model='gpt-3.5-turbo')
-
 
 # Event classes
 class RetrieverEvent(Event):
@@ -88,7 +82,6 @@ class RAGWorkflow(Workflow):
     @step
     async def retrieve(self, ctx: Context, ev: QueryEvent) -> RetrieverEvent | None:
         query = ctx.data["query"]
-
         print(f"Query the database with: {query}")
 
         if not self.indexer:
@@ -115,6 +108,13 @@
 
 # Main function
 async def main():
+    from llama_index.llms.openai import OpenAI
+    from llama_index.core import Settings
+    from dotenv import load_dotenv
+
+    load_dotenv()
+    Settings.llm = OpenAI(model='gpt-3.5-turbo')
+
     data_dir = Path("../../data")
     index_path = data_dir.joinpath("dm_spl_release_human_rx_part1")
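Because load_dotenv() and Settings.llm are no longer set at import time in medirag/rag/wf.py, any caller of RAGWorkflow now has to perform that setup itself, as app.py and the module's main() do after this commit. The sketch below shows that calling pattern end to end; the RAGWorkflow arguments and the result.async_response_gen() streaming loop come straight from the diff, while the DailyMedIndexer construction is an assumption (its real arguments come from the index path used in main()).

import asyncio
from pathlib import Path

from dotenv import load_dotenv
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI

from medirag.index.local import DailyMedIndexer
from medirag.rag.wf import RAGWorkflow


async def run_once() -> None:
    load_dotenv()                                   # expects OPENAI_API_KEY in the environment
    Settings.llm = OpenAI(model='gpt-3.5-turbo')    # same model the app configures

    # Assumed constructor call -- the real index location and any loading step
    # live in main() (data_dir / "dm_spl_release_human_rx_part1").
    indexer = DailyMedIndexer(index_path=Path("data/dm_spl_release_human_rx_part1"))

    rag_workflow = RAGWorkflow(indexer=indexer, timeout=60, top_k=10, top_n=5)
    result = await rag_workflow.run(query="What are the usage instructions for aspirin?")

    # Stream the answer chunk by chunk, as app.py does for the Gradio Textbox.
    async for chunk in result.async_response_gen():
        print(chunk, end="", flush=True)


if __name__ == "__main__":
    asyncio.run(run_once())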