ankush13r committed on
Commit
858a66e
1 Parent(s): 02212c9

modify from RAG to vector store (vs) only

Browse files
Files changed (7) hide show
  1. .evn +1 -0
  2. app.py +29 -72
  3. handler.py +0 -14
  4. input_reader.py +0 -22
  5. requirements.txt +1 -1
  6. utils.py +0 -20
  7. rag.py → vectorstore.py +4 -27
.evn ADDED
@@ -0,0 +1 @@
 
 
1
+ EMBEDDINGS="BAAI/bge-m3"
app.py CHANGED
@@ -2,107 +2,64 @@ import os
2
  import gradio as gr
3
  from gradio.components import Textbox, Button, Slider, Checkbox
4
  from AinaTheme import theme
5
- from urllib.error import HTTPError
6
 
7
- from rag import RAG
8
- from utils import setup
9
 
10
  MAX_NEW_TOKENS = 700
11
- SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
12
 
13
- setup()
14
-
15
-
16
- rag = RAG(embeddings_model=os.getenv("EMBEDDINGS"))
17
 
 
18
 
19
  def eadop_rag(prompt, num_chunks):
20
- model_parameters = {"NUM_CHUNKS": num_chunks}
21
- try:
22
- return rag.get_context(prompt, model_parameters)
23
- except HTTPError as err:
24
- if err.code == 400:
25
- gr.Warning(
26
- "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
27
- )
28
- return None, None, None
29
- except:
30
  gr.Warning(
31
- "Inference endpoint is not available right now. Please try again later."
32
  )
33
- return None, None, None
34
-
35
-
36
  def clear():
37
  return (
38
- None,
39
- None,
40
  None,
41
  None,
42
  gr.Slider(value=2.0),
43
  )
44
 
45
-
46
  def gradio_app():
47
  with gr.Blocks(theme=theme) as demo:
48
  with gr.Row(equal_height=True):
49
- with gr.Column(variant="panel"):
50
- input_ = Textbox(
51
- lines=11,
52
- label="Input",
53
- placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
54
- )
55
- with gr.Row(variant="panel"):
56
- clear_btn = Button(
57
- "Clear",
58
- )
59
- submit_btn = Button("Submit", variant="primary", interactive=False)
60
 
61
- with gr.Row(variant="panel"):
62
- with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
63
- num_chunks = Slider(
 
 
 
 
 
64
  minimum=1,
65
  maximum=6,
66
  step=1,
67
  value=2,
68
  label="Number of chunks"
69
  )
70
-
71
- with gr.Column(variant="panel"):
72
- output = Textbox(
73
- lines=10,
74
- label="Context",
75
- interactive=False,
76
- show_copy_button=True
77
- )
78
- with gr.Accordion("Sources and context:", open=False):
79
- source_context = gr.Markdown(
80
- label="Sources",
81
- show_label=False,
82
- )
83
- with gr.Accordion("See full context evaluation:", open=False):
84
- context_evaluation = gr.Markdown(
85
- label="Full context",
86
- show_label=False,
87
- )
88
-
89
-
90
-
91
-
92
- input_.change(
93
- fn=None,
94
- inputs=[input_],
95
- api_name=False,
96
- js="""(i, m) => {
97
- document.getElementById('inputlength').textContent = i.length + ' '
98
- document.getElementById('inputlength').style.color = (i.length > m) ? "#ef4444" : "";
99
- }""",
100
- )
101
 
102
  clear_btn.click(
103
  fn=clear,
104
  inputs=[],
105
- outputs=[input_, output, source_context, context_evaluation, num_chunks],
106
  queue=False,
107
  api_name=False
108
  )
@@ -110,7 +67,7 @@ def gradio_app():
110
  submit_btn.click(
111
  fn=eadop_rag,
112
  inputs=[input_, num_chunks],
113
- outputs=[output, source_context, context_evaluation],
114
  api_name="get-eadop-rag"
115
  )
116
 
 
2
  import gradio as gr
3
  from gradio.components import Textbox, Button, Slider, Checkbox
4
  from AinaTheme import theme
 
5
 
6
+ from vectorstore import VectorStore
 
7
 
8
  MAX_NEW_TOKENS = 700
 
9
 
 
 
 
 
10
 
11
+ vectorStore = VectorStore(embeddings_model=os.getenv("EMBEDDINGS", "BAAI/bge-m3"))
12
 
13
  def eadop_rag(prompt, num_chunks):
14
+ prompt = prompt.strip()
15
+ if prompt == "":
 
 
 
 
 
 
 
 
16
  gr.Warning(
17
+ "Prompt can't be empty!"
18
  )
19
+ raise ValueError("Prompt can't be empty!")
20
+ return vectorStore.get_context(prompt, num_chunks)
21
+
22
  def clear():
23
  return (
 
 
24
  None,
25
  None,
26
  gr.Slider(value=2.0),
27
  )
28
 
 
29
  def gradio_app():
30
  with gr.Blocks(theme=theme) as demo:
31
  with gr.Row(equal_height=True):
32
+ output = Textbox(
33
+ lines=10,
34
+ label="Context",
35
+ interactive=False,
36
+ show_copy_button=True
37
+ )
 
 
 
 
 
38
 
39
+ with gr.Row(equal_height=True):
40
+ input_ = Textbox(
41
+ label="Input",
42
+ placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
43
+ )
44
+
45
+ with gr.Row(equal_height=True):
46
+ num_chunks = Slider(
47
  minimum=1,
48
  maximum=6,
49
  step=1,
50
  value=2,
51
  label="Number of chunks"
52
  )
53
+ with gr.Row(equal_height=True):
54
+ clear_btn = Button("Clear")
55
+ with gr.Row(equal_height=True):
56
+ submit_btn = Button("Submit", variant="primary")
57
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  clear_btn.click(
60
  fn=clear,
61
  inputs=[],
62
+ outputs=[input_, output, num_chunks],
63
  queue=False,
64
  api_name=False
65
  )
 
67
  submit_btn.click(
68
  fn=eadop_rag,
69
  inputs=[input_, num_chunks],
70
+ outputs=[output],
71
  api_name="get-eadop-rag"
72
  )
73
 
handler.py DELETED
@@ -1,14 +0,0 @@
1
- import json
2
-
3
- class ContentHandler():
4
- content_type = "application/json"
5
- accepts = "application/json"
6
-
7
- def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
8
- input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
9
- return input_str.encode('utf-8')
10
-
11
- def transform_output(self, output: bytes) -> str:
12
- response_json = json.loads(output.read().decode("utf-8"))
13
- return response_json[0]["generated_text"]
14
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
input_reader.py DELETED
@@ -1,22 +0,0 @@
1
- from typing import List
2
-
3
- from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
4
- from llama_index.core.readers import SimpleDirectoryReader
5
- from llama_index.core.schema import Document
6
- from llama_index.core import Settings
7
-
8
-
9
- class InputReader:
10
- def __init__(self, input_dir: str) -> None:
11
- self.reader = SimpleDirectoryReader(input_dir=input_dir)
12
-
13
- def parse_documents(
14
- self,
15
- show_progress: bool = True,
16
- chunk_size: int = DEFAULT_CHUNK_SIZE,
17
- chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
18
- ) -> List[Document]:
19
- Settings.chunk_size = chunk_size
20
- Settings.chunk_overlap = chunk_overlap
21
- documents = self.reader.load_data(show_progress=show_progress)
22
- return documents
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio==4.29.0
2
  huggingface-hub==0.23.4
3
  openai==1.35.13
4
  python-dotenv==1.0.0
 
1
+ gradio==4.44.1
2
  huggingface-hub==0.23.4
3
  openai==1.35.13
4
  python-dotenv==1.0.0
utils.py CHANGED
@@ -3,12 +3,7 @@ import warnings
3
 
4
  from dotenv import load_dotenv
5
 
6
-
7
- from rag import RAG
8
-
9
  USER_INPUT = 100
10
-
11
-
12
  def setup():
13
  load_dotenv()
14
  warnings.filterwarnings("ignore")
@@ -16,18 +11,3 @@ def setup():
16
  logging.addLevelName(USER_INPUT, "USER_INPUT")
17
  logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
18
 
19
-
20
- def interactive(model: RAG):
21
- logging.info("Write `exit` when you want to stop the model.")
22
- print()
23
-
24
- query = ""
25
- while query.lower() != "exit":
26
- logging.log(USER_INPUT, "Write the query or `exit`:")
27
- query = input()
28
-
29
- if query.lower() == "exit":
30
- break
31
-
32
- response = model.get_response(query)
33
- print(response, end="\n\n")
 
3
 
4
  from dotenv import load_dotenv
5
 
 
 
 
6
  USER_INPUT = 100
 
 
7
  def setup():
8
  load_dotenv()
9
  warnings.filterwarnings("ignore")
 
11
  logging.addLevelName(USER_INPUT, "USER_INPUT")
12
  logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rag.py → vectorstore.py RENAMED
@@ -2,39 +2,24 @@ import logging
2
  import os
3
  import requests
4
 
5
-
6
-
7
  from langchain_community.vectorstores import FAISS
8
  from langchain_community.embeddings import HuggingFaceEmbeddings
9
 
10
-
11
- class RAG:
12
- NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
13
-
14
- #vectorstore = "index-intfloat_multilingual-e5-small-500-100-CA-ES" # mixed
15
- #vectorstore = "vectorestore" # CA only
16
  vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"
17
 
18
  def __init__(self, embeddings_model):
19
-
20
-
21
  # load vectore store
22
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
23
- self.vectore_store = FAISS.load_local(self.vectorstore, embeddings, allow_dangerous_deserialization=True)#, allow_dangerous_deserialization=True)
24
-
25
  logging.info("RAG loaded!")
26
 
27
  def get_context(self, instruction, number_of_contexts=2):
28
-
29
  documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
30
-
31
- return documentos
32
 
33
-
34
  def beautiful_context(self, docs):
35
-
36
  text_context = ""
37
-
38
  full_context = ""
39
  source_context = []
40
  for doc in docs:
@@ -44,12 +29,4 @@ class RAG:
44
  full_context += doc[0].page_content + "\n"
45
  source_context.append(doc[0].metadata["url"])
46
 
47
- return text_context, full_context, source_context
48
-
49
- def get_context(self, prompt: str, model_parameters: dict) -> str:
50
- try:
51
- docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
52
- return self.beautiful_context(docs)
53
- except Exception as err:
54
- print(err)
55
- return None, None, None
 
2
  import os
3
  import requests
4
 
 
 
5
  from langchain_community.vectorstores import FAISS
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
 
8
+ class VectorStore:
 
 
 
 
 
9
  vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"
10
 
11
  def __init__(self, embeddings_model):
 
 
12
  # load vectore store
13
  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
14
+ self.vectore_store = FAISS.load_local(self.vectorstore, embeddings, allow_dangerous_deserialization=True)
 
15
  logging.info("RAG loaded!")
16
 
17
  def get_context(self, instruction, number_of_contexts=2):
 
18
  documentos = self.vectore_store.similarity_search_with_score(instruction, k=number_of_contexts)
19
+ return self.beautiful_context(documentos)
 
20
 
 
21
  def beautiful_context(self, docs):
 
22
  text_context = ""
 
23
  full_context = ""
24
  source_context = []
25
  for doc in docs:
 
29
  full_context += doc[0].page_content + "\n"
30
  source_context.append(doc[0].metadata["url"])
31
 
32
+ return full_context