alvinhenrick committed
Commit 1366ab3
Parent: 6ecf1c3

Refactor semantic cache to support other tools

app.py CHANGED
@@ -2,7 +2,7 @@ import dspy
 import gradio as gr
 from dotenv import load_dotenv
 
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 from medirag.index.kdbai import KDBAIDailyMedIndexer
 from medirag.rag.qa import RAG, DailyMedRetrieve
 from medirag.rag.wf import RAGWorkflow
@@ -21,7 +21,7 @@ dspy.settings.configure(lm=turbo, rm=rm)
 # Set the LLM model
 Settings.llm = OpenAI(model="gpt-3.5-turbo")
 
-sm = SemanticCaching(
+sm = LocalSemanticCache(
     model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="rag_test_cache.json"
 )
medirag/cache/abc.py ADDED
@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class SemanticCache(ABC):
+    """
+    Abstract base class for semantic caching mechanisms.
+    """
+
+    @abstractmethod
+    def lookup(self, question: str, cosine_threshold: float):
+        """
+        Retrieve a response from the cache based on the question and cosine similarity threshold.
+        """
+        pass
+
+    @abstractmethod
+    def save(self, question: str, answer: str):
+        """
+        Save a question-answer pair to the cache.
+        """
+        pass
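The point of this new interface is that the caching backend can now vary independently of the RAG code. As a minimal sketch (not part of this commit; the class name and exact-match behavior are hypothetical), any backend that implements lookup and save can stand in for the local FAISS cache:

from medirag.cache.abc import SemanticCache


class InMemorySemanticCache(SemanticCache):
    """Hypothetical exact-match backend, shown only to illustrate the interface."""

    def __init__(self):
        self._store: dict[str, str] = {}

    def lookup(self, question: str, cosine_threshold: float = 0.9):
        # Exact string match; a real semantic backend would embed the
        # question and compare cosine similarity against the threshold.
        return self._store.get(question)

    def save(self, question: str, answer: str):
        self._store[question] = answer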
medirag/cache/local.py CHANGED
@@ -5,14 +5,16 @@ from pydantic import BaseModel, ValidationError
 from sentence_transformers import SentenceTransformer
 from loguru import logger
 
+from medirag.cache.abc import SemanticCache
 
-class SemanticCache(BaseModel):
+
+class SemanticCacheModel(BaseModel):
     questions: list[str] = []
     embeddings: list[list[float]] = []
     response_text: list[str] = []
 
 
-class SemanticCaching:
+class LocalSemanticCache(SemanticCache):
     def __init__(
         self,
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
@@ -24,14 +26,14 @@ class SemanticCaching:
         self.json_file = json_file
         self.vector_index = faiss.IndexFlatIP(self.dimension)
         self.encoder = SentenceTransformer(model_name)
-        self._cache = SemanticCache()  # Initialize with a default SemanticCache to avoid NoneType issues
+        self._cache = SemanticCacheModel()  # Initialize with a default SemanticCacheModel to avoid NoneType issues
         self.load_cache()
 
     def load_cache(self) -> None:
         try:
             with open(self.json_file, "r") as file:
                 data = json.load(file)
-                self._cache = SemanticCache(**data)  # Use unpacking to handle Pydantic validation
+                self._cache = SemanticCacheModel(**data)  # Use unpacking to handle Pydantic validation
                 for emb in self._cache.embeddings:
                     np_emb = np.array(emb, dtype=np.float32)
                     faiss.normalize_L2(np_emb.reshape(1, -1))
@@ -71,7 +73,7 @@ class SemanticCaching:
         logger.info("New response saved to cache.")
 
     def clear(self):
-        self._cache = SemanticCache()
+        self._cache = SemanticCacheModel()
         self.vector_index.reset()
         self.save_cache()
         logger.info("Cache cleared.")
tests/cache/test_semantic_cache.py CHANGED
@@ -1,31 +1,31 @@
 import pytest
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 
 
 # Fixture to initialize the SemanticCaching object
 @pytest.fixture(scope="module")
-def semantic_caching():
+def semantic_cache():
     # Initialize the SemanticCaching class with a test cache file
-    return SemanticCaching(
+    return LocalSemanticCache(
         model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="real_test_cache.json"
     )
 
 
-def test_save_and_lookup_in_cache(semantic_caching):
+def test_save_and_lookup_in_cache(semantic_cache):
     # Clear any existing cache data
-    semantic_caching.clear()
+    semantic_cache.clear()
 
     # Step 1: Lookup should return None for a question not in the cache
-    initial_lookup = semantic_caching.lookup("What is the capital of France?")
+    initial_lookup = semantic_cache.lookup("What is the capital of France?")
     assert initial_lookup is None
 
     # Step 2: Save a response to the cache
-    semantic_caching.save("What is the capital of France?", "Paris")
+    semantic_cache.save("What is the capital of France?", "Paris")
 
     # Step 3: Lookup the same question; it should now return the cached response
-    cached_response = semantic_caching.lookup("What is the capital of France?")
+    cached_response = semantic_cache.lookup("What is the capital of France?")
     assert cached_response is not None
     assert cached_response == "Paris"
 
     # Cleanup: Clear the cache after test
-    semantic_caching.clear()
+    semantic_cache.clear()
tests/rag/test_rag.py CHANGED
@@ -1,4 +1,4 @@
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 from medirag.index.local import LocalIndexer
 
 # from medirag.index.kdbai import KDBAIDailyMedIndexer
@@ -36,10 +36,9 @@ def test_rag_with_example(data_dir):
 
     rag = RAG(k=3)
 
-    sm = SemanticCaching(
+    sm = LocalSemanticCache(
         model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="rag_test_cache.json"
     )
-    # sm.load_cache()
 
     result1 = ask_med_question(sm, rag, query)
     print(result1)
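Because ask_med_question now only depends on the SemanticCache interface, any implementation can be injected in its place. A hedged sketch of the presumed cache-first flow (the wrapper name and the rag(query).answer access are assumptions for illustration, not code from this repo):

from medirag.cache.abc import SemanticCache


def ask_with_cache(sm: SemanticCache, rag, query: str) -> str:
    # Presumed flow: consult the cache, fall back to the RAG
    # pipeline on a miss, then persist the new answer.
    cached = sm.lookup(query, cosine_threshold=0.9)
    if cached is not None:
        return cached
    answer = rag(query).answer  # assumes RAG returns a dspy Prediction
    sm.save(query, answer)
    return answer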