alvinhenrick committed 03ddbba (1 parent: fedacb2)

add logging
Files changed:
- app.py +2 -2
- medirag/cache/local.py +39 -35
- poetry.lock +33 -1
- pyproject.toml +1 -0
- requirements.txt +6 -4
- tests/rag/test_rag.py +1 -1
app.py
CHANGED
@@ -26,7 +26,7 @@ dspy.settings.configure(lm=turbo, rm=rm)
 Settings.llm = OpenAI(model='gpt-3.5-turbo')
 
 sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
-                     json_file='rag_test_cache.json'
+                     json_file='rag_test_cache.json')
 sm.load_cache()
 
 # Initialize RAGWorkflow with indexer
@@ -41,7 +41,7 @@ def clear_cache():
 
 async def ask_med_question(query, enable_stream):
     # Check the cache first
-    response = sm.lookup(question=query)
+    response = sm.lookup(question=query, cosine_threshold=0.9)
     if response:
         # Return cached response if found
         yield response
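
For context, a brief sketch of what the raised threshold compares (illustrative, not part of the commit): the embeddings are L2-normalized before the FAISS inner-product search, so the returned score is cosine similarity, and cosine_threshold=0.9 accepts only near-duplicate questions.

import faiss
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
index = faiss.IndexFlatIP(768)  # inner product over unit vectors == cosine similarity

cached = encoder.encode(["What are the side effects of ibuprofen?"])
faiss.normalize_L2(cached)
index.add(cached)

query = encoder.encode(["Ibuprofen side effects?"])
faiss.normalize_L2(query)
D, I = index.search(query, 1)
print(D[0][0] >= 0.9)  # a cache hit only if the best match clears the stricter threshold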
medirag/cache/local.py
CHANGED
@@ -1,75 +1,79 @@
 import faiss
 import json
 import numpy as np
+from pydantic import BaseModel, ValidationError
 from sentence_transformers import SentenceTransformer
+from loguru import logger
+
+
+class SemanticCache(BaseModel):
+    questions: list[str] = []
+    embeddings: list[list[float]] = []
+    response_text: list[str] = []
 
 
 class SemanticCaching:
     def __init__(self,
-                 model_name='sentence-transformers/all-mpnet-base-v2',
-                 dimension=768,
-                 json_file='cache.json',
-                 cosine_threshold=0.7):
+                 model_name: str = 'sentence-transformers/all-mpnet-base-v2',
+                 dimension: int = 768,
+                 json_file: str = 'cache.json'):
 
         self.model_name = model_name
         self.dimension = dimension
-        self.cosine_threshold = cosine_threshold
+        self.json_file = json_file
         self.vector_index = faiss.IndexFlatIP(self.dimension)
         self.encoder = SentenceTransformer(self.model_name)
-        self.json_file = json_file
-        self.cache = self.load_cache()
+        self.load_cache()
 
-    def load_cache(self):
+    def load_cache(self) -> None:
         """Load cache from a JSON file."""
-        local_cache = {'questions': [], 'embeddings': [], 'response_text': []}
+        self.cache = SemanticCache()
         try:
-            with open(self.json_file, 'r') as file:
-                cached_data = json.load(file)
-                if cached_data:
-                    local_cache = cached_data
-                    for emb in local_cache['embeddings']:
-                        _embedding = np.array(emb, dtype=np.float32)
-                        self.vector_index.add(_embedding)
-                    return local_cache
-                else:
-                    return local_cache
+            with open(self.json_file, 'r') as file:
+                data = json.load(file)
+                data['embeddings'] = [np.array(e, dtype=np.float32) for e in data.get('embeddings', [])]
+                for emb in data['embeddings']:
+                    self.vector_index.add(emb)
+                self.cache = SemanticCache(**data)
         except FileNotFoundError:
-            return local_cache
+            logger.info("Cache file not found, initializing new cache.")
+        except ValidationError as e:
+            logger.error(f"Error in cache data structure: {e}")
         except Exception as e:
-            print(f"Failed to load or process cache: {e}")
-            return local_cache
+            logger.error(f"Failed to load or process cache: {e}")
 
     def save_cache(self):
         """Save the current cache to a JSON file."""
+        data = self.cache.model_dump_json()
         with open(self.json_file, 'w') as file:
-            json.dump(self.cache, file)
+            json.dump(data, file)
+        logger.info("Cache saved successfully.")
 
-    def lookup(self, question: str) -> str | None:
+    def lookup(self, question: str, cosine_threshold: float = 0.7) -> str | None:
         """Check if a question is in the cache and return the cached response if it exists."""
         embedding = self.encoder.encode([question], show_progress_bar=False)
         faiss.normalize_L2(embedding)
-
-        # Search in the index
         D, I = self.vector_index.search(embedding, 1)
 
-        if D[0][0] >= self.cosine_threshold:
+        if D[0][0] >= cosine_threshold:
             row_id = I[0][0]
-            return self.cache['response_text'][row_id]
-
-        return None
+            return self.cache.response_text[row_id]
+        return None
 
     def save(self, question: str, response: str):
         """Save a response to the cache."""
         embedding = self.encoder.encode([question], show_progress_bar=False)
         faiss.normalize_L2(embedding)
-
-        self.cache['questions'].append(question)
-        self.cache['embeddings'].append(embedding.tolist())
-        self.cache['response_text'].append(response)
+        self.cache.questions.append(question)
+        self.cache.embeddings.append(embedding.tolist())
+        self.cache.response_text.append(response)
        self.vector_index.add(embedding)
         self.save_cache()
+        logger.info("New response saved to cache.")
 
     def clear(self):
-        self.cache = {'questions': [], 'embeddings': [], 'response_text': []}
+        """Clear the cache."""
+        self.cache = SemanticCache()
         self.vector_index.reset()
         self.save_cache()
+        logger.info("Cache cleared.")
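
A minimal usage sketch of the revised class (the cache file name and the question/answer strings below are illustrative, not from the repo):

from medirag.cache.local import SemanticCaching

sm = SemanticCaching(json_file='demo_cache.json')  # __init__ now calls load_cache()

sm.save("What is ibuprofen used for?", "Relief of pain, fever, and inflammation.")

# A repeat of the question clears the threshold and is served from the cache...
print(sm.lookup("What is ibuprofen used for?", cosine_threshold=0.9))
# ...while an unrelated question falls below it and returns None.
print(sm.lookup("How do vaccines work?", cosine_threshold=0.9))

Moving cosine_threshold from the constructor to lookup() lets each caller pick its own strictness; app.py and the test now pass 0.9 per call instead of fixing one value at construction time.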
poetry.lock
CHANGED
@@ -1940,6 +1940,24 @@ llama-index-core = ">=0.11.0,<0.12.0"
 pandas = "*"
 pykx = ">=2.1.1,<3.0.0"
 
+[[package]]
+name = "loguru"
+version = "0.7.2"
+description = "Python logging made (stupidly) simple"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
+    {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
+]
+
+[package.dependencies]
+colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
+win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
+
+[package.extras]
+dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
+
 [[package]]
 name = "lxml"
 version = "5.3.0"
@@ -5053,6 +5071,20 @@ files = [
 {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"},
 ]
 
+[[package]]
+name = "win32-setctime"
+version = "1.1.0"
+description = "A small Python utility to set file creation time on Windows"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
+    {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
+]
+
+[package.extras]
+dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
+
 [[package]]
 name = "wrapt"
 version = "1.16.0"
@@ -5370,4 +5402,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "
+content-hash = "e2a76129035f221f383481d3a02d57ba0a52337b56f684ce0b4e847e1262def5"
pyproject.toml
CHANGED
@@ -31,6 +31,7 @@ accelerate = ">=0.33.0"
 gradio = ">=4.42.0"
 pydantic = ">=2.8.2"
 kdbai-client = ">=1.2.4"
+loguru = "^0.7.2"
 
 
 
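
The new dependency backs the logger calls in medirag/cache/local.py. A minimal illustration of the API this pulls in (the file sink below is loguru's documented feature, not something this commit configures):

from loguru import logger

logger.info("Cache saved successfully.")     # works with zero configuration
logger.add("medirag.log", rotation="10 MB")  # optional: also write to a rotating file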
requirements.txt
CHANGED
@@ -11,7 +11,7 @@ async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
 attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.10" and python_version < "3.13"
 beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "3.13"
-certifi==2024.
+certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.10" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (platform_system == "Windows" or sys_platform == "win32")
@@ -46,7 +46,7 @@ huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
 huggingface-hub[inference]==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
 idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
 importlib-resources==6.4.4 ; python_version >= "3.10" and python_version < "3.13"
-ipython==8.
+ipython==8.27.0 ; python_version >= "3.10" and python_version < "3.13"
 jedi==0.19.1 ; python_version >= "3.10" and python_version < "3.13"
 jinja2==3.1.4 ; python_version >= "3.10" and python_version < "3.13"
 jiter==0.5.0 ; python_version >= "3.10" and python_version < "3.13"
@@ -59,9 +59,9 @@ kiwisolver==1.4.5 ; python_version >= "3.10" and python_version < "3.13"
 langchain-core==0.2.36 ; python_version >= "3.10" and python_version < "3.13"
 langchain-text-splitters==0.2.2 ; python_version >= "3.10" and python_version < "3.13"
 langchain==0.2.15 ; python_version >= "3.10" and python_version < "3.13"
-langsmith==0.1.
+langsmith==0.1.107 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-agent-openai==0.3.0 ; python_version >= "3.10" and python_version < "3.13"
-llama-index-core==0.11.
+llama-index-core==0.11.3 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-embeddings-huggingface==0.3.1 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-embeddings-openai==0.2.3 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-llms-openai==0.2.0 ; python_version >= "3.10" and python_version < "3.13"
@@ -69,6 +69,7 @@ llama-index-readers-file==0.2.0 ; python_version >= "3.10" and python_version <
 llama-index-utils-workflow==0.2.0 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-vector-stores-faiss==0.2.1 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-vector-stores-kdbai==0.3.1 ; python_version >= "3.10" and python_version < "3.13"
+loguru==0.7.2 ; python_version >= "3.10" and python_version < "3.13"
 lxml==5.3.0 ; python_version >= "3.10" and python_version < "3.13"
 mako==1.3.5 ; python_version >= "3.10" and python_version < "3.13"
 markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "emscripten"
@@ -164,6 +165,7 @@ urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
 uvicorn==0.30.6 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "emscripten"
 wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "3.13"
 websockets==12.0 ; python_version >= "3.10" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.10" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.10" and python_version < "3.13"
 yarl==1.9.4 ; python_version >= "3.10" and python_version < "3.13"
tests/rag/test_rag.py
CHANGED
@@ -9,7 +9,7 @@ load_dotenv()  # take environment variables from .env.
 
 
 def ask_med_question(sm, rag, query):
-    response = sm.lookup(question=query)
+    response = sm.lookup(question=query, cosine_threshold=0.9)
     if not response:
         response = rag(query).answer
         sm.save(query, response)
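
A hypothetical test built on this helper (the sm and rag fixtures, and the assumption that ask_med_question returns the response text, are mine, not the repo's):

def test_repeat_question_is_served_from_cache(sm, rag):
    # Hypothetical sketch: the second call repeats the question verbatim,
    # so its cosine similarity is ~1.0 and clears the 0.9 threshold.
    q = "What are common side effects of metformin?"
    first = ask_med_question(sm, rag, q)
    second = ask_med_question(sm, rag, q)  # should be answered from the semantic cache
    assert first == second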