alvinhenrick committed on
Commit 03ddbba
1 Parent(s): fedacb2

add logging

Files changed (6)
  1. app.py +2 -2
  2. medirag/cache/local.py +39 -35
  3. poetry.lock +33 -1
  4. pyproject.toml +1 -0
  5. requirements.txt +6 -4
  6. tests/rag/test_rag.py +1 -1
app.py CHANGED
@@ -26,7 +26,7 @@ dspy.settings.configure(lm=turbo, rm=rm)
 Settings.llm = OpenAI(model='gpt-3.5-turbo')
 
 sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2', dimension=768,
-                     json_file='rag_test_cache.json', cosine_threshold=.90)
+                     json_file='rag_test_cache.json')
 sm.load_cache()
 
 # Initialize RAGWorkflow with indexer
@@ -41,7 +41,7 @@ def clear_cache():
 
 async def ask_med_question(query, enable_stream):
     # Check the cache first
-    response = sm.lookup(question=query)
+    response = sm.lookup(question=query, cosine_threshold=0.9)
     if response:
         # Return cached response if found
         yield response
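
Note: the similarity threshold now travels with each lookup call instead of being fixed at construction time, so call sites can choose their own precision/recall trade-off. A minimal sketch of the new call pattern (values mirror the diff above; the RAG fallback is elided):

from medirag.cache.local import SemanticCaching

sm = SemanticCaching(model_name='sentence-transformers/all-mpnet-base-v2',
                     dimension=768,
                     json_file='rag_test_cache.json')
sm.load_cache()

# A strict 0.9 threshold returns only near-duplicate questions; lowering
# it trades answer fidelity for a higher cache hit rate.
cached = sm.lookup(question="What are the side effects of ibuprofen?",
                   cosine_threshold=0.9)
if cached is None:
    ...  # fall through to the RAG workflow, then sm.save(query, answer)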
medirag/cache/local.py CHANGED
@@ -1,75 +1,79 @@
 import faiss
 import json
 import numpy as np
+from pydantic import BaseModel, ValidationError
 from sentence_transformers import SentenceTransformer
+from loguru import logger
+
+
+class SemanticCache(BaseModel):
+    questions: list[str] = []
+    embeddings: list[list[float]] = []
+    response_text: list[str] = []
 
 
 class SemanticCaching:
     def __init__(self,
-                 model_name='sentence-transformers/all-mpnet-base-v2',
-                 dimension=768,
-                 json_file='cache.json',
-                 cosine_threshold=0.7):
+                 model_name: str = 'sentence-transformers/all-mpnet-base-v2',
+                 dimension: int = 768,
+                 json_file: str = 'cache.json'):
 
         self.model_name = model_name
         self.dimension = dimension
-        self.cosine_threshold = cosine_threshold
+        self.json_file = json_file
         self.vector_index = faiss.IndexFlatIP(self.dimension)
         self.encoder = SentenceTransformer(self.model_name)
-        self.json_file = json_file
-        self.cache = self.load_cache()
+        self.load_cache()
 
-    def load_cache(self):
+    def load_cache(self) -> None:
         """Load cache from a JSON file."""
-        local_cache = {'questions': [], 'embeddings': [], 'response_text': []}
+        self.cache = SemanticCache()  # fall back to an empty cache on any failure
         try:
-            if self.json_file:
-                with open(self.json_file, 'r') as file:
-                    local_cache = json.load(file)
-                    if 'embeddings' in local_cache and len(local_cache['embeddings']) > 0:
-                        for embedding in local_cache['embeddings']:
-                            _embedding = np.array(embedding, dtype=np.float32)
-                            self.vector_index.add(_embedding)
-                return local_cache
-            else:
-                return local_cache
+            with open(self.json_file, 'r') as file:
+                data = json.load(file)
+            self.cache = SemanticCache(**data)
+            for embedding in self.cache.embeddings:
+                self.vector_index.add(np.array([embedding], dtype=np.float32))
         except FileNotFoundError:
-            return local_cache
+            logger.info("Cache file not found, initializing new cache.")
+        except ValidationError as e:
+            logger.error(f"Error in cache data structure: {e}")
         except Exception as e:
-            print(f"Failed to load or process cache: {e}")
-            return local_cache
+            logger.error(f"Failed to load or process cache: {e}")
 
     def save_cache(self):
         """Save the current cache to a JSON file."""
+        data = self.cache.model_dump_json()
         with open(self.json_file, 'w') as file:
-            json.dump(self.cache, file)
+            file.write(data)
+        logger.info("Cache saved successfully.")
 
-    def lookup(self, question: str) -> str | None:
+    def lookup(self, question: str, cosine_threshold: float = 0.7) -> str | None:
         """Check if a question is in the cache and return the cached response if it exists."""
         embedding = self.encoder.encode([question], show_progress_bar=False)
         faiss.normalize_L2(embedding)
-
-        # Search in the index
         D, I = self.vector_index.search(embedding, 1)
 
-        if D[0][0] >= self.cosine_threshold:
+        if D[0][0] >= cosine_threshold:
             row_id = I[0][0]
-            return self.cache['response_text'][row_id]
-        else:
-            return None
+            return self.cache.response_text[row_id]
+        return None
 
     def save(self, question: str, response: str):
         """Save a response to the cache."""
         embedding = self.encoder.encode([question], show_progress_bar=False)
         faiss.normalize_L2(embedding)
-
-        self.cache['questions'].append(question)
-        self.cache['embeddings'].append(embedding.tolist())
-        self.cache['response_text'].append(response)
+        self.cache.questions.append(question)
+        self.cache.embeddings.append(embedding[0].tolist())
+        self.cache.response_text.append(response)
         self.vector_index.add(embedding)
         self.save_cache()
+        logger.info("New response saved to cache.")
 
     def clear(self):
-        self.cache = {'questions': [], 'embeddings': [], 'response_text': []}
+        """Clear the cache."""
+        self.cache = SemanticCache()
         self.vector_index.reset()
         self.save_cache()
+        logger.info("Cache cleared.")
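
For a quick sanity check of the refactored class, a round-trip sketch like the following exercises miss, save, hit, and clear (the file name and questions are illustrative; the encoder model downloads on first use):

from medirag.cache.local import SemanticCaching

sm = SemanticCaching(json_file='demo_cache.json')  # starts empty if the file is missing

# Miss on an empty index, store an answer, then hit with a close rephrasing.
assert sm.lookup(question="What is aspirin used for?", cosine_threshold=0.9) is None
sm.save("What is aspirin used for?", "Pain relief and fever reduction.")
print(sm.lookup(question="What is aspirin for?", cosine_threshold=0.9))

sm.clear()  # resets both the pydantic cache model and the FAISS index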
poetry.lock CHANGED
@@ -1940,6 +1940,24 @@ llama-index-core = ">=0.11.0,<0.12.0"
 pandas = "*"
 pykx = ">=2.1.1,<3.0.0"
 
+[[package]]
+name = "loguru"
+version = "0.7.2"
+description = "Python logging made (stupidly) simple"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
+    {file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
+]
+
+[package.dependencies]
+colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
+win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
+
+[package.extras]
+dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
+
 [[package]]
 name = "lxml"
 version = "5.3.0"
@@ -5053,6 +5071,20 @@ files = [
     {file = "websockets-12.0.tar.gz", hash = "sha256:81df9cbcbb6c260de1e007e58c011bfebe2dafc8435107b0537f393dd38c8b1b"},
 ]
 
+[[package]]
+name = "win32-setctime"
+version = "1.1.0"
+description = "A small Python utility to set file creation time on Windows"
+optional = false
+python-versions = ">=3.5"
+files = [
+    {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
+    {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
+]
+
+[package.extras]
+dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
+
 [[package]]
 name = "wrapt"
 version = "1.16.0"
@@ -5370,4 +5402,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<3.13"
-content-hash = "f03d8ab60928b71fb7bdb7dd2adfefc967249fdf18ff2660f9ba4182df1f0de4"
+content-hash = "e2a76129035f221f383481d3a02d57ba0a52337b56f684ce0b4e847e1262def5"
pyproject.toml CHANGED
@@ -31,6 +31,7 @@ accelerate = ">=0.33.0"
 gradio = ">=4.42.0"
 pydantic = ">=2.8.2"
 kdbai-client = ">=1.2.4"
+loguru = "^0.7.2"
 
 
 
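Adding loguru needs no further setup, since it logs to stderr out of the box. If a file sink or a different level is wanted later, a standard configuration sketch (not part of this commit) would be:

import sys
from loguru import logger

logger.remove()                              # drop the default stderr sink
logger.add(sys.stderr, level="INFO")         # re-add it at INFO level
logger.add("medirag.log", rotation="10 MB")  # optional rotating file sink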
requirements.txt CHANGED
@@ -11,7 +11,7 @@ async-timeout==4.0.3 ; python_version >= "3.10" and python_version < "3.11"
 attrs==24.2.0 ; python_version >= "3.10" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.10" and python_version < "3.13"
 beautifulsoup4==4.12.3 ; python_version >= "3.10" and python_version < "3.13"
-certifi==2024.7.4 ; python_version >= "3.10" and python_version < "3.13"
+certifi==2024.8.30 ; python_version >= "3.10" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.10" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.10" and python_version < "3.13"
 colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.13" and (platform_system == "Windows" or sys_platform == "win32")
@@ -46,7 +46,7 @@ huggingface-hub==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
 huggingface-hub[inference]==0.24.6 ; python_version >= "3.10" and python_version < "3.13"
 idna==3.8 ; python_version >= "3.10" and python_version < "3.13"
 importlib-resources==6.4.4 ; python_version >= "3.10" and python_version < "3.13"
-ipython==8.26.0 ; python_version >= "3.10" and python_version < "3.13"
+ipython==8.27.0 ; python_version >= "3.10" and python_version < "3.13"
 jedi==0.19.1 ; python_version >= "3.10" and python_version < "3.13"
 jinja2==3.1.4 ; python_version >= "3.10" and python_version < "3.13"
 jiter==0.5.0 ; python_version >= "3.10" and python_version < "3.13"
@@ -59,9 +59,9 @@ kiwisolver==1.4.5 ; python_version >= "3.10" and python_version < "3.13"
 langchain-core==0.2.36 ; python_version >= "3.10" and python_version < "3.13"
 langchain-text-splitters==0.2.2 ; python_version >= "3.10" and python_version < "3.13"
 langchain==0.2.15 ; python_version >= "3.10" and python_version < "3.13"
-langsmith==0.1.106 ; python_version >= "3.10" and python_version < "3.13"
+langsmith==0.1.107 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-agent-openai==0.3.0 ; python_version >= "3.10" and python_version < "3.13"
-llama-index-core==0.11.2 ; python_version >= "3.10" and python_version < "3.13"
+llama-index-core==0.11.3 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-embeddings-huggingface==0.3.1 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-embeddings-openai==0.2.3 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-llms-openai==0.2.0 ; python_version >= "3.10" and python_version < "3.13"
@@ -69,6 +69,7 @@ llama-index-readers-file==0.2.0 ; python_version >= "3.10" and python_version <
 llama-index-utils-workflow==0.2.0 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-vector-stores-faiss==0.2.1 ; python_version >= "3.10" and python_version < "3.13"
 llama-index-vector-stores-kdbai==0.3.1 ; python_version >= "3.10" and python_version < "3.13"
+loguru==0.7.2 ; python_version >= "3.10" and python_version < "3.13"
 lxml==5.3.0 ; python_version >= "3.10" and python_version < "3.13"
 mako==1.3.5 ; python_version >= "3.10" and python_version < "3.13"
 markdown-it-py==3.0.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "emscripten"
@@ -164,6 +165,7 @@ urllib3==2.2.2 ; python_version >= "3.10" and python_version < "3.13"
 uvicorn==0.30.6 ; python_version >= "3.10" and python_version < "3.13" and sys_platform != "emscripten"
 wcwidth==0.2.13 ; python_version >= "3.10" and python_version < "3.13"
 websockets==12.0 ; python_version >= "3.10" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.10" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.10" and python_version < "3.13"
 xxhash==3.5.0 ; python_version >= "3.10" and python_version < "3.13"
 yarl==1.9.4 ; python_version >= "3.10" and python_version < "3.13"
tests/rag/test_rag.py CHANGED
@@ -9,7 +9,7 @@ load_dotenv() # take environment variables from .env.
 
 
 def ask_med_question(sm, rag, query):
-    response = sm.lookup(question=query)
+    response = sm.lookup(question=query, cosine_threshold=0.9)
     if not response:
         response = rag(query).answer
         sm.save(query, response)
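
The helper above still drives a live RAG pipeline; a cache-only unit test of the new parameter (hypothetical, not part of this commit) could look like:

from medirag.cache.local import SemanticCaching


def test_lookup_respects_threshold(tmp_path):
    sm = SemanticCaching(json_file=str(tmp_path / "cache.json"))
    sm.save("What is ibuprofen?", "An NSAID used for pain and inflammation.")

    # An identical question scores ~1.0 and must clear the 0.9 threshold,
    # while an unrelated one should fall below it and miss.
    assert sm.lookup(question="What is ibuprofen?", cosine_threshold=0.9) is not None
    assert sm.lookup(question="How do vaccines work?", cosine_threshold=0.9) is None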