alvinhenrick committed
Commit: 1366ab3 • Parent(s): 6ecf1c3
semantic cache to support other tools
Files changed:
- app.py (+2 -2)
- medirag/cache/abc.py (+21 -0)
- medirag/cache/local.py (+7 -5)
- tests/cache/test_semantic_cache.py (+9 -9)
- tests/rag/test_rag.py (+2 -3)
app.py CHANGED

@@ -2,7 +2,7 @@ import dspy
 import gradio as gr
 from dotenv import load_dotenv
 
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 from medirag.index.kdbai import KDBAIDailyMedIndexer
 from medirag.rag.qa import RAG, DailyMedRetrieve
 from medirag.rag.wf import RAGWorkflow
@@ -21,7 +21,7 @@ dspy.settings.configure(lm=turbo, rm=rm)
 # Set the LLM model
 Settings.llm = OpenAI(model="gpt-3.5-turbo")
 
-sm = SemanticCaching(
+sm = LocalSemanticCache(
     model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="rag_test_cache.json"
 )
 
medirag/cache/abc.py ADDED

@@ -0,0 +1,21 @@
+from abc import ABC, abstractmethod
+
+
+class SemanticCache(ABC):
+    """
+    Abstract base class for semantic caching mechanisms.
+    """
+
+    @abstractmethod
+    def lookup(self, question: str, cosine_threshold: float):
+        """
+        Retrieve a response from the cache based on the question and cosine similarity threshold.
+        """
+        pass
+
+    @abstractmethod
+    def save(self, question: str, answer: str):
+        """
+        Save a question-answer pair to the cache.
+        """
+        pass
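The commit message ("semantic cache to support other tools") is realized by this abstract base: any backend that implements lookup and save can now stand in for the local FAISS cache. As a minimal sketch of an alternative backend (not part of this commit; exact-match only, with a hypothetical class name):

from medirag.cache.abc import SemanticCache


class InMemorySemanticCache(SemanticCache):
    # Hypothetical example backend: exact string matching in a dict.
    # A real semantic backend would embed the question and compare
    # similarity against cosine_threshold, as LocalSemanticCache does.

    def __init__(self):
        self._store: dict[str, str] = {}

    def lookup(self, question: str, cosine_threshold: float = 0.9):
        return self._store.get(question)

    def save(self, question: str, answer: str):
        self._store[question] = answer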
medirag/cache/local.py CHANGED

@@ -5,14 +5,16 @@ from pydantic import BaseModel, ValidationError
 from sentence_transformers import SentenceTransformer
 from loguru import logger
 
+from medirag.cache.abc import SemanticCache
 
-class SemanticCache(BaseModel):
+
+class SemanticCacheModel(BaseModel):
     questions: list[str] = []
     embeddings: list[list[float]] = []
     response_text: list[str] = []
 
 
-class SemanticCaching:
+class LocalSemanticCache(SemanticCache):
     def __init__(
         self,
         model_name: str = "sentence-transformers/all-mpnet-base-v2",
@@ -24,14 +26,14 @@ class SemanticCaching:
         self.json_file = json_file
         self.vector_index = faiss.IndexFlatIP(self.dimension)
         self.encoder = SentenceTransformer(model_name)
-        self._cache = SemanticCache()
+        self._cache = SemanticCacheModel()  # Initialize with a default SemanticCacheModel to avoid NoneType issues
         self.load_cache()
 
     def load_cache(self) -> None:
         try:
             with open(self.json_file, "r") as file:
                 data = json.load(file)
-                self._cache = SemanticCache(**data)
+                self._cache = SemanticCacheModel(**data)  # Use unpacking to handle Pydantic validation
             for emb in self._cache.embeddings:
                 np_emb = np.array(emb, dtype=np.float32)
                 faiss.normalize_L2(np_emb.reshape(1, -1))
@@ -71,7 +73,7 @@ class SemanticCaching:
         logger.info("New response saved to cache.")
 
     def clear(self):
-        self._cache = SemanticCache()
+        self._cache = SemanticCacheModel()
         self.vector_index.reset()
         self.save_cache()
         logger.info("Cache cleared.")
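LocalSemanticCache keeps the original retrieval mechanics: each embedding is L2-normalized before going into the IndexFlatIP index, so the inner product computed at search time equals the cosine similarity that lookup's threshold is compared against. A self-contained sketch of that equivalence (the dimension and random vectors are illustrative only):

import faiss
import numpy as np

dim = 768
vecs = np.random.rand(4, dim).astype(np.float32)
faiss.normalize_L2(vecs)  # in-place, row-wise L2 normalization

index = faiss.IndexFlatIP(dim)  # inner-product index, as in local.py
index.add(vecs)

# Searching with a normalized vector returns cosine similarities:
scores, ids = index.search(vecs[:1], 2)
print(scores)  # first hit is the query vector itself with score ~1.0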
tests/cache/test_semantic_cache.py CHANGED

@@ -1,31 +1,31 @@
 import pytest
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 
 
 # Fixture to initialize the SemanticCaching object
 @pytest.fixture(scope="module")
-def semantic_caching():
+def semantic_cache():
     # Initialize the SemanticCaching class with a test cache file
-    return SemanticCaching(
+    return LocalSemanticCache(
         model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="real_test_cache.json"
     )
 
 
-def test_save_and_lookup_in_cache(semantic_caching):
+def test_save_and_lookup_in_cache(semantic_cache):
     # Clear any existing cache data
-    semantic_caching.clear()
+    semantic_cache.clear()
 
     # Step 1: Lookup should return None for a question not in the cache
-    initial_lookup = semantic_caching.lookup("What is the capital of France?")
+    initial_lookup = semantic_cache.lookup("What is the capital of France?")
     assert initial_lookup is None
 
     # Step 2: Save a response to the cache
-    semantic_caching.save("What is the capital of France?", "Paris")
+    semantic_cache.save("What is the capital of France?", "Paris")
 
     # Step 3: Lookup the same question; it should now return the cached response
-    cached_response = semantic_caching.lookup("What is the capital of France?")
+    cached_response = semantic_cache.lookup("What is the capital of France?")
     assert cached_response is not None
    assert cached_response == "Paris"
 
     # Cleanup: Clear the cache after test
-    semantic_caching.clear()
+    semantic_cache.clear()
tests/rag/test_rag.py CHANGED

@@ -1,4 +1,4 @@
-from medirag.cache.local import SemanticCaching
+from medirag.cache.local import LocalSemanticCache
 from medirag.index.local import LocalIndexer
 
 # from medirag.index.kdbai import KDBAIDailyMedIndexer
@@ -36,10 +36,9 @@ def test_rag_with_example(data_dir):
 
     rag = RAG(k=3)
 
-    sm = SemanticCaching(
+    sm = LocalSemanticCache(
         model_name="sentence-transformers/all-mpnet-base-v2", dimension=768, json_file="rag_test_cache.json"
     )
-    # sm.load_cache()
 
     result1 = ask_med_question(sm, rag, query)
     print(result1)
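Both tests drive the cache through ask_med_question(sm, rag, query), whose body is outside this diff. A speculative sketch of the cache-first flow such a helper typically implements; the rag(query).answer call shape is an assumption, not confirmed by this commit:

def ask_med_question(sm, rag, query: str) -> str:
    # Speculative reconstruction, not from this commit.
    cached = sm.lookup(query)  # single-arg call, as in the tests above
    if cached is not None:
        return cached  # semantic-cache hit: skip the RAG pipeline
    answer = rag(query).answer  # assumed dspy-style call returning .answer
    sm.save(query, answer)  # cache the fresh answer for next time
    return answer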