Tonic committed on
Commit a6d437d
1 Parent(s): 458fb6a

add jina embeddings and reranker

Files changed (5)
  1. README.md +5 -5
  2. app.py +0 -235
  3. globalvars.py +25 -54
  4. langchainapp.py +0 -243
  5. yijinaembed.py +231 -0
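
This commit swaps the Space's embedder from nvidia/NV-Embed-v1 to jinaai/jina-embeddings-v3 and adds a Jina reranker stage. For orientation, here is a minimal sketch of task-specific embeddings with jina-embeddings-v3; it assumes the model's remote-code `encode` helper and the task names used below, as documented on the model card:

```python
# Minimal sketch of task-specific embeddings with jina-embeddings-v3.
# Assumes the remote-code `encode` helper accepts a `task` name; check the
# model card for the exact signature before relying on this.
from transformers import AutoModel

model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", trust_remote_code=True)

# Asymmetric retrieval uses different LoRA adapters for queries and passages
query_vecs = model.encode(["what is a reranker?"], task="retrieval.query")
passage_vecs = model.encode(["A reranker re-scores retrieved passages."], task="retrieval.passage")
```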
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
 title: 01aiYi NvidiaEmbed
-emoji: 💬
-colorFrom: yellow
-colorTo: purple
+emoji: ☯️🧠🛌🏻🥟🧩
+colorFrom: blue
+colorTo: red
 sdk: gradio
 sdk_version: 4.36.1
-app_file: langchainapp.py
-pinned: false
+app_file: yijinaembed.py
+pinned: true
 license: mit
 ---
app.py DELETED
@@ -1,235 +0,0 @@
-# app.py
-import spaces
-from torch.nn import DataParallel
-from torch import Tensor
-from transformers import AutoTokenizer, AutoModel
-from huggingface_hub import InferenceClient
-from openai import OpenAI
-from langchain_community.document_loaders import UnstructuredFileLoader
-from langchain_chroma import Chroma
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from chromadb.config import Settings
-import chromadb  # import HttpClient
-import os
-import tempfile
-import re
-import uuid
-import gradio as gr
-import torch
-import torch.nn.functional as F
-from dotenv import load_dotenv
-from utils import load_env_variables, parse_and_route, escape_special_characters
-from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt
-from sentence_transformers import SentenceTransformer
-
-
-load_dotenv()
-
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-os.environ['CUDA_CACHE_DISABLE'] = '1'
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Ensure the temporary directory exists
-temp_dir = '/tmp/gradio/'
-os.makedirs(temp_dir, exist_ok=True)
-
-# Set Gradio cache directory
-gr.components.file.GRADIO_CACHE = temp_dir
-
-### Utils
-hf_token, yi_token = load_env_variables()
-
-def clear_cuda_cache():
-    torch.cuda.empty_cache()
-
-client = OpenAI(api_key=yi_token, base_url=API_BASE)
-
-chroma_client = chromadb.Client(Settings())
-
-# Create a collection
-chroma_collection = chroma_client.create_collection("all-my-documents")
-
-class EmbeddingGenerator:
-    def __init__(self, model_name: str, token: str, intention_client):
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
-        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
-        self.intention_client = intention_client
-
-    def clear_cuda_cache(self):
-        torch.cuda.empty_cache()
-
-    @spaces.GPU
-    def compute_embeddings(self, input_text: str):
-        escaped_input_text = escape_special_characters(input_text)
-        intention_completion = self.intention_client.chat.completions.create(
-            model="yi-large",
-            messages=[
-                {"role": "system", "content": escape_special_characters(intention_prompt)},
-                {"role": "user", "content": escaped_input_text}
-            ]
-        )
-        intention_output = intention_completion.choices[0].message.content
-        # Parse and route the intention
-        parsed_task = parse_and_route(intention_output)
-        selected_task = parsed_task
-        # Construct the prompt
-        if selected_task in tasks:
-            task_description = tasks[selected_task]
-        else:
-            task_description = tasks["DEFAULT"]
-            print(f"Selected task not found: {selected_task}")
-
-        query_prefix = f"Instruct: {task_description}\nQuery: "
-        queries = [escaped_input_text]
-
-        # Get the metadata
-        metadata_completion = self.intention_client.chat.completions.create(
-            model="yi-large",
-            messages=[
-                {"role": "system", "content": escape_special_characters(metadata_prompt)},
-                {"role": "user", "content": escaped_input_text}
-            ]
-        )
-        metadata_output = metadata_completion.choices[0].message.content
-        metadata = self.extract_metadata(metadata_output)
-
-        # Get the embeddings
-        with torch.no_grad():
-            inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
-            outputs = self.model(**inputs)
-            query_embeddings = outputs["sentence_embeddings"].mean(dim=1)
-            query_embeddings = outputs.last_hidden_state.mean(dim=1)
-
-        # Normalize embeddings
-        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
-        embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
-
-        self.clear_cuda_cache()
-        return embeddings_list, metadata
-
-    def extract_metadata(self, metadata_output: str):
-        # Regex pattern to extract key-value pairs
-        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
-        matches = pattern.findall(metadata_output)
-        metadata = {key: value for key, value in matches}
-        return metadata
-
-class MyEmbeddingFunction(EmbeddingFunction):
-    def __init__(self, model_name: str, token: str, intention_client):
-        self.model_name = model_name
-        self.token = token
-        self.intention_client = intention_client
-
-    def create_embedding_generator(self):
-        return EmbeddingGenerator(self.model_name, self.token, self.intention_client)
-
-    def __call__(self, input: Documents) -> (Embeddings, list):
-        embedding_generator = self.create_embedding_generator()
-        embeddings_with_metadata = [embedding_generator.compute_embeddings(doc.page_content) for doc in input]
-        embeddings = [item[0] for item in embeddings_with_metadata]
-        metadata = [item[1] for item in embeddings_with_metadata]
-        embeddings_flattened = [emb for sublist in embeddings for emb in sublist]
-        metadata_flattened = [meta for sublist in metadata for meta in sublist]
-        return embeddings_flattened, metadata_flattened
-
-def load_documents(file_path: str, mode: str = "elements"):
-    loader = UnstructuredFileLoader(file_path, mode=mode)
-    docs = loader.load()
-    return [doc.page_content for doc in docs]
-
-def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
-    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
-    return db
-
-def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
-    for doc in documents:
-        embeddings, metadata = embedding_function.create_embedding_generator().compute_embeddings(doc)
-        for embedding, meta in zip(embeddings, metadata):
-            chroma_collection.add(
-                ids=[str(uuid.uuid1())],
-                documents=[doc],
-                embeddings=[embedding],
-                metadatas=[meta]
-            )
-
-def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction):
-    query_embeddings, query_metadata = embedding_function.create_embedding_generator().compute_embeddings(query_text)
-    result_docs = chroma_collection.query(
-        query_texts=[query_text],
-        n_results=2
-    )
-    return result_docs
-
-# Initialize clients
-intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
-embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
-embedding_function = MyEmbeddingFunction(model_name=model_name, token=hf_token, intention_client=intention_client)
-chroma_db = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    retrieved_text = query_documents(message)
-    messages = [{"role": "system", "content": escape_special_characters(system_message)}]
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{escape_special_characters(message)}"})
-    response = ""
-    for message in intention_client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-        response += token
-        yield response
-
-def upload_documents(files):
-    for file in files:
-        loader = UnstructuredFileLoader(file.name)
-        documents = loader.load()
-        add_documents_to_chroma(documents, embedding_function)
-    return "Documents uploaded and processed successfully!"
-
-def query_documents(query):
-    results = query_chroma(query, embedding_function)
-    return "\n\n".join([result.content for result in results])
-
-with gr.Blocks() as demo:
-    with gr.Tab("Upload Documents"):
-        document_upload = gr.File(file_count="multiple", file_types=["document"])
-        upload_button = gr.Button("Upload and Process")
-        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
-
-    with gr.Tab("Ask Questions"):
-        with gr.Row():
-            chat_interface = gr.ChatInterface(
-                respond,
-                additional_inputs=[
-                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-                ],
-            )
-        query_input = gr.Textbox(label="Query")
-        query_button = gr.Button("Query")
-        query_output = gr.Textbox()
-        query_button.click(query_documents, inputs=query_input, outputs=query_output)
-
-if __name__ == "__main__":
-    # os.system("chroma run --host localhost --port 8000 &")
-    demo.launch()
globalvars.py CHANGED
@@ -3,7 +3,7 @@
 API_BASE = "https://api.01.ai/v1"
 API_KEY = "your key"
 
-model_name = 'nvidia/NV-Embed-v1'
+model_name = "jinaai/jina-embeddings-v3"
 
 title = """
 # 👋🏻Welcome to 🙋🏻‍♂️Tonic's 📽️Nvidia 🛌🏻Embed V-1 !"""
@@ -15,76 +15,47 @@ Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder
 """
 
 tasks = {
-    'ClimateFEVER': 'Given a claim about climate change, retrieve documents that support or refute the claim',
-    'DBPedia': 'Given a query, retrieve relevant entity descriptions from DBPedia',
-    'FEVER': 'Given a claim, retrieve documents that support or refute the claim',
-    'FiQA2018': 'Given a financial question, retrieve user replies that best answer the question',
-    'HotpotQA': 'Given a multi-hop question, retrieve documents that can help answer the question',
-    'MSMARCO': 'Given a web search query, retrieve relevant passages that answer the query',
-    'NFCorpus': 'Given a question, retrieve relevant documents that best answer the question',
-    'NQ': 'Given a question, retrieve Wikipedia passages that answer the question',
-    'QuoraRetrieval': 'Given a question, retrieve questions that are semantically equivalent to the given question',
-    'SCIDOCS': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
-    'DEFAULT': 'Given a query, retrieve relevant entity descriptions from DBPedia',
+    'retrieval.query': 'Used for query embeddings in asymmetric retrieval tasks',
+    'retrieval.passage': 'Used for passage embeddings in asymmetric retrieval tasks',
+    'separation': 'Used for embeddings in clustering and re-ranking applications',
+    'classification': 'Used for embeddings in classification tasks',
+    'text-matching': 'Used for embeddings in tasks that quantify similarity between two texts, such as STS or symmetric retrieval tasks',
+    'DEFAULT': 'Used for general-purpose embeddings when no specific task is specified'
 }
 
-intention_prompt= """
+intention_prompt = """
+{
 "type": "object",
 "properties": {
-    "ClimateFEVER": {
+    "retrieval.query": {
         "type": "boolean",
-        "description" : "select this for climate science related text"
+        "description": "Select this for query embeddings in asymmetric retrieval tasks"
     },
-    "DBPedia": {
+    "retrieval.passage": {
         "type": "boolean",
-        "description" : "select this for encyclopedic related knowledge"
+        "description": "Select this for passage embeddings in asymmetric retrieval tasks"
    },
-    "FEVER": {
+    "separation": {
         "type": "boolean",
-        "description": "select this to verify a claim or embed a claim"
+        "description": "Select this for embeddings in clustering and re-ranking applications"
     },
-    "FiQA2018": {
+    "classification": {
         "type": "boolean",
-        "description" : "select this for financial questions or topics"
+        "description": "Select this for embeddings in classification tasks"
     },
-    "HotpotQA": {
+    "text-matching": {
         "type": "boolean",
-        "description" : "select this for a multi-hop question or for texts that provide multihop claims"
-    },
-    "MSMARCO": {
-        "type": "boolean",
-        "description": "Given a web search query, retrieve relevant passages that answer the query"
-    },
-    "NFCorpus": {
-        "type": "boolean",
-        "description" : "Given a question, retrieve relevant documents that best answer the question"
-    },
-    "NQ": {
-        "type": "boolean",
-        "description" : "Given a question, retrieve Wikipedia passages that answer the question"
-    },
-    "QuoraRetrieval": {
-        "type": "boolean",
-        "description": "Given a question, retrieve questions that are semantically equivalent to the given question"
-    },
-    "SCIDOCS": {
-        "type": "boolean",
-        "description": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper"
+        "description": "Select this for embeddings in tasks that quantify similarity between two texts, such as STS or symmetric retrieval tasks"
     }
 },
 "required": [
-    "ClimateFEVER",
-    "DBPedia",
-    "FEVER",
-    "FiQA2018",
-    "HotpotQA",
-    "MSMARCO",
-    "NFCorpus",
-    "NQ",
-    "QuoraRetrieval",
-    "SCIDOCS",
+    "retrieval.query",
+    "retrieval.passage",
+    "separation",
+    "classification",
+    "text-matching"
 ]
-produce a complete json schema."
+}
 
 you will recieve a text , classify the text according to the schema above. ONLY PROVIDE THE FINAL JSON , DO NOT PRODUCE ANY ADDITION INSTRUCTION :"""
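
The intention prompt above asks yi-large for a JSON object of boolean flags, one per jina-embeddings-v3 task name. `parse_and_route` (defined in utils.py, which is not part of this diff) reduces that JSON to a single task key; below is a plausible sketch with a hypothetical helper name, assuming the model returns the schema's booleans:

```python
import json

def parse_and_route_sketch(intention_output: str) -> str:
    """Hypothetical stand-in for utils.parse_and_route: return the first
    task whose flag is true, falling back to DEFAULT otherwise."""
    try:
        flags = json.loads(intention_output)
    except json.JSONDecodeError:
        return "DEFAULT"
    for task_name, selected in flags.items():
        if selected is True:
            return task_name
    return "DEFAULT"
```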
langchainapp.py DELETED
@@ -1,243 +0,0 @@
-# app.py
-import spaces
-from torch.nn import DataParallel
-from torch import Tensor
-from transformers import AutoTokenizer, AutoModel
-from huggingface_hub import InferenceClient
-from openai import OpenAI
-from langchain_community.embeddings import HuggingFaceInstructEmbeddings
-from langchain_community.document_loaders import UnstructuredFileLoader
-from langchain_chroma import Chroma
-from chromadb import Documents, EmbeddingFunction, Embeddings
-from chromadb.config import Settings
-import chromadb  # import HttpClient
-from typing import List, Tuple, Dict, Any
-import os
-import re
-import uuid
-import gradio as gr
-import torch
-import torch.nn.functional as F
-from dotenv import load_dotenv
-from utils import load_env_variables, parse_and_route, escape_special_characters
-from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt
-# import time
-# import httpx
-
-from langchain_community.chat_models import ChatOpenAI
-from langchain.retrievers.document_compressors import LLMChainExtractor
-from langchain.retrievers.multi_query import MultiQueryRetriever
-from langchain.retrievers import ContextualCompressionRetriever
-from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
-# from langchain.vectorstores import Chroma
-
-
-load_dotenv()
-
-os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-os.environ['CUDA_CACHE_DISABLE'] = '1'
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-### Utils
-hf_token, yi_token = load_env_variables()
-
-# tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
-# Lazy load model
-model = None
-
-@spaces.GPU
-def load_model():
-    global model
-    if model is None:
-        from transformers import AutoModel
-        model = AutoModel.from_pretrained(model_name, token=hf_token, trust_remote_code=True).to(device)
-    return model
-
-# Load model
-nvidiamodel = load_model()
-# nvidiamodel.set_pooling_include_prompt(include_prompt=False)
-
-def clear_cuda_cache():
-    torch.cuda.empty_cache()
-
-client = OpenAI(api_key=yi_token, base_url=API_BASE)
-
-chroma_client = chromadb.Client(Settings())
-
-# Create a collection
-chroma_collection = chroma_client.create_collection("all-my-documents")
-
-@spaces.GPU
-class MyEmbeddingFunction(EmbeddingFunction):
-    def __init__(self, model_name: str, token: str, intention_client):
-        self.model_name = model_name
-        self.token = token
-        self.intention_client = intention_client
-        self.hf_embeddings = HuggingFaceInstructEmbeddings(
-            model_name=model_name,
-            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'},
-            encode_kwargs={'normalize_embeddings': True}
-        )
-
-    def create_embedding_generator(self):
-        return self.hf_embeddings
-
-    def __call__(self, input: Documents) -> (List[List[float]], List[Dict[str, Any]]):
-        embeddings_with_metadata = [self.compute_embeddings(doc.page_content) for doc in input]
-        embeddings = [item[0] for item in embeddings_with_metadata]
-        metadata = [item[1] for item in embeddings_with_metadata]
-        embeddings_flattened = [emb for sublist in embeddings for emb in sublist]
-        metadata_flattened = [meta for sublist in metadata for meta in sublist]
-        return embeddings_flattened, metadata_flattened
-
-    @spaces.GPU
-    def compute_embeddings(self, input_text: str):
-        escaped_input_text = escape_special_characters(input_text)
-
-        # Get the intention
-        intention_completion = self.intention_client.chat.completions.create(
-            model="yi-large",
-            messages=[
-                {"role": "system", "content": escape_special_characters(intention_prompt)},
-                {"role": "user", "content": escaped_input_text}
-            ]
-        )
-        intention_output = intention_completion.choices[0].message.content
-        parsed_task = parse_and_route(intention_output)
-        selected_task = parsed_task if parsed_task in tasks else "DEFAULT"
-        task_description = tasks[selected_task]
-        # query_prefix = "Instruct: " + tasks[selected_task] + "\nQuery: "
-        # Construct the embed_instruction and query_instruction dynamically
-        embed_instruction = f"Instruct: {task_description}" + "\nQuery:"
-        # query_instruction = f""
-
-        # Update the hf_embeddings object with the new instructions
-        self.hf_embeddings.embed_instruction = embed_instruction
-        # self.hf_embeddings.query_instruction = query_instruction
-
-        # Get the metadata
-        metadata_completion = self.intention_client.chat.completions.create(
-            model="yi-large",
-            messages=[
-                {"role": "system", "content": escape_special_characters(metadata_prompt)},
-                {"role": "user", "content": escaped_input_text}
-            ]
-        )
-        metadata_output = metadata_completion.choices[0].message.content
-        metadata = self.extract_metadata(metadata_output)
-
-        # Get the embeddings
-        embeddings = self.hf_embeddings.embed_documents([escaped_input_text])
-        return embeddings[0], metadata
-
-    def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
-        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
-        matches = pattern.findall(metadata_output)
-        metadata = {key: value for key, value in matches}
-        return metadata
-
-def load_documents(file_path: str, mode: str = "elements"):
-    loader = UnstructuredFileLoader(file_path, mode=mode)
-    docs = loader.load()
-    return [doc.page_content for doc in docs]
-
-def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
-    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
-    return db
-
-def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
-    for doc in documents:
-        embeddings, metadata = embedding_function.compute_embeddings(doc)
-        for embedding, meta in zip(embeddings, metadata):
-            chroma_collection.add(
-                ids=[str(uuid.uuid1())],
-                documents=[doc],
-                embeddings=[embedding],
-                metadatas=[meta]
-            )
-
-def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction):
-    model = load_model()
-    query_embeddings, query_metadata = embedding_function.compute_embeddings(query_text)
-    result_docs = chroma_collection.query(
-        query_texts=[query_text],
-        n_results=3
-    )
-    return result_docs
-
-
-def answer_query(message: str, chat_history: List[Tuple[str, str]]):
-    base_compressor = LLMChainExtractor.from_llm(intention_client)
-    db = Chroma(persist_directory="output/general_knowledge", embedding_function=embedding_function)
-    base_retriever = db.as_retriever()
-    mq_retriever = MultiQueryRetriever.from_llm(retriever=base_retriever, llm=intention_client)
-    compression_retriever = ContextualCompressionRetriever(base_compressor=base_compressor, base_retriever=mq_retriever)
-
-    matched_docs = compression_retriever.get_relevant_documents(query=message)
-    context = ""
-    for doc in matched_docs:
-        page_content = doc.page_content
-        context += page_content
-        context += "\n\n"
-
-    template = """
-    Answer the following question only by using the context given below in the triple backticks, do not use any other information to answer the question.
-    If you can't answer the given question with the given context, you can return an empty string ('')
-    Context: ```{context}```
-    ----------------------------
-    Question: {query}
-    ----------------------------
-    Answer: """
-
-    human_message_prompt = HumanMessagePromptTemplate.from_template(template=template)
-    chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
-    prompt = chat_prompt.format_prompt(query=message, context=context)
-    response = intention_client.chat(messages=prompt.to_messages()).content
-    chat_history.append((message, response))
-    return "", chat_history
-
-
-# Initialize clients
-intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
-embedding_function = MyEmbeddingFunction(model_name=model_name, token=hf_token, intention_client=intention_client)
-chroma_db = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)
-
-def upload_documents(files):
-    for file in files:
-        loader = UnstructuredFileLoader(file.name)
-        documents = loader.load()
-        add_documents_to_chroma(documents, embedding_function)
-    return "Documents uploaded and processed successfully!"
-
-def query_documents(query):
-    model = load_model()
-    results = query_chroma(query)
-    return "\n\n".join([result.content for result in results])
-
-with gr.Blocks() as demo:
-    with gr.Tab("Upload Documents"):
-        document_upload = gr.File(file_count="multiple", file_types=["document"])
-        upload_button = gr.Button("Upload and Process")
-        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
-
-    with gr.Tab("Ask Questions"):
-        with gr.Row():
-            chat_interface = gr.ChatInterface(
-                answer_query,
-                additional_inputs=[
-                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
-                ],
-            )
-        query_input = gr.Textbox(label="Query")
-        query_button = gr.Button("Query")
-        query_output = gr.Textbox()
-        query_button.click(query_documents, inputs=query_input, outputs=query_output)
-
-if __name__ == "__main__":
-    # os.system("chroma run --host localhost --port 8000 &")
-    demo.launch()
yijinaembed.py ADDED
@@ -0,0 +1,231 @@
+# app.py
+import os
+import re
+import uuid
+import gradio as gr
+import torch
+import torch.nn.functional as F
+from dotenv import load_dotenv
+from typing import List, Tuple, Dict, Any
+from transformers import AutoTokenizer, AutoModel
+from openai import OpenAI
+from langchain_community.document_loaders import UnstructuredFileLoader
+from langchain_chroma import Chroma
+from chromadb import Documents, EmbeddingFunction, Embeddings
+from chromadb.config import Settings
+import chromadb
+from utils import load_env_variables, parse_and_route, escape_special_characters
+from globalvars import API_BASE, intention_prompt, tasks, system_message, metadata_prompt, model_name
+import spaces
+from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
+from langchain_community.document_compressors.jina_rerank import JinaRerank
+from langchain import hub
+from langchain.chains import create_retrieval_chain
+# create_stuff_documents_chain lives in langchain.chains.combine_documents, not langchain.chains.retrieval
+from langchain.chains.combine_documents import create_stuff_documents_chain
+
+load_dotenv()
+
+# os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:180'
+# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+# os.environ['CUDA_CACHE_DISABLE'] = '1'
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+hf_token, yi_token = load_env_variables()
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token, trust_remote_code=True)
+model = None
+
+@spaces.GPU
+def load_model():
+    global model
+    if model is None:
+        model = AutoModel.from_pretrained(model_name, token=hf_token, trust_remote_code=True).to(device)
+    return model
+
+# Load model
+jina_model = load_model()
+
+def clear_cuda_cache():
+    torch.cuda.empty_cache()
+
+client = OpenAI(api_key=yi_token, base_url=API_BASE)
+
+chroma_client = chromadb.Client(Settings())
+
+chroma_collection = chroma_client.create_collection("all-my-documents")
+
+class JinaEmbeddingFunction(EmbeddingFunction):
+    def __init__(self, model, tokenizer, intention_client):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.intention_client = intention_client
+
+    def __call__(self, input: Documents) -> Tuple[List[List[float]], List[Dict[str, Any]]]:
+        embeddings_with_metadata = [self.compute_embeddings(doc) for doc in input]
+        embeddings = [item[0] for item in embeddings_with_metadata]
+        metadata = [item[1] for item in embeddings_with_metadata]
+        return embeddings, metadata
+
+    @spaces.GPU
+    def compute_embeddings(self, input_text: str):
+        escaped_input_text = escape_special_characters(input_text)
+
+        # Get the intention
+        intention_completion = self.intention_client.chat.completions.create(
+            model="yi-large",
+            messages=[
+                {"role": "system", "content": escape_special_characters(intention_prompt)},
+                {"role": "user", "content": escaped_input_text}
+            ]
+        )
+        intention_output = intention_completion.choices[0].message.content
+        parsed_task = parse_and_route(intention_output)
+        selected_task = parsed_task if parsed_task in tasks else "DEFAULT"
+        # jina-embeddings-v3 selects its LoRA adapter by task *name* (e.g. 'retrieval.query'),
+        # not by the description string stored in `tasks`; fall back to 'text-matching' for DEFAULT
+        task = selected_task if selected_task != "DEFAULT" else "text-matching"
+
+        # Get the metadata
+        metadata_completion = self.intention_client.chat.completions.create(
+            model="yi-large",
+            messages=[
+                {"role": "system", "content": escape_special_characters(metadata_prompt)},
+                {"role": "user", "content": escaped_input_text}
+            ]
+        )
+        metadata_output = metadata_completion.choices[0].message.content
+        metadata = self.extract_metadata(metadata_output)
+
+        # Compute embeddings using the Jina model
+        encoded_input = self.tokenizer(escaped_input_text, padding=True, truncation=True, return_tensors="pt").to(device)
+        with torch.no_grad():
+            model_output = self.model(**encoded_input, task=task)
+
+        embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
+        embeddings = F.normalize(embeddings, p=2, dim=1)
+
+        return embeddings.cpu().numpy().tolist()[0], metadata
+
+    def extract_metadata(self, metadata_output: str) -> Dict[str, str]:
+        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
+        matches = pattern.findall(metadata_output)
+        metadata = {key: value for key, value in matches}
+        return metadata
+
+    @staticmethod
+    def mean_pooling(model_output, attention_mask):
+        # Average token embeddings, masking out padding positions
+        token_embeddings = model_output[0]
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+def load_documents(file_path: str, mode: str = "elements"):
+    loader = UnstructuredFileLoader(file_path, mode=mode)
+    docs = loader.load()
+    return [doc.page_content for doc in docs]
+
+def initialize_chroma(collection_name: str, embedding_function: JinaEmbeddingFunction):
+    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
+    return db
+
+@spaces.GPU
+def add_documents_to_chroma(documents: list, embedding_function: JinaEmbeddingFunction):
+    for doc in documents:
+        embeddings, metadata = embedding_function.compute_embeddings(doc)
+        chroma_collection.add(
+            ids=[str(uuid.uuid1())],
+            documents=[doc],
+            embeddings=[embeddings],
+            metadatas=[metadata]
+        )
+
+@spaces.GPU
+def rerank_documents(query: str, documents: List[str]) -> List[str]:
+    # JinaRerank calls the hosted Jina rerank API (requires JINA_API_KEY in the environment);
+    # note it re-retrieves from chroma_db rather than using the `documents` argument directly
+    compressor = JinaRerank()
+    retriever = chroma_db.as_retriever(search_kwargs={"k": 20})
+    compression_retriever = ContextualCompressionRetriever(
+        base_compressor=compressor, base_retriever=retriever
+    )
+
+    compressed_docs = compression_retriever.get_relevant_documents(query)
+
+    return [doc.page_content for doc in compressed_docs]
+
+def query_chroma(query_text: str, embedding_function: JinaEmbeddingFunction):
+    query_embeddings, query_metadata = embedding_function.compute_embeddings(query_text)
+    result_docs = chroma_collection.query(
+        query_embeddings=[query_embeddings],
+        n_results=3
+    )
+    return result_docs
+
+@spaces.GPU
+def answer_query(message: str, chat_history: List[Tuple[str, str]], system_message: str, max_new_tokens: int, temperature: float, top_p: float):
+    # Query Chroma for relevant documents
+    results = query_chroma(message, embedding_function)
+    # results['documents'][0] is already a list of document strings
+    context = "\n\n".join(results['documents'][0])
+
+    # Rerank the documents
+    reranked_docs = rerank_documents(message, context.split("\n\n"))
+    reranked_context = "\n\n".join(reranked_docs)
+
+    # Generate response using the Yi model
+    response = client.chat.completions.create(
+        model="yi-large",
+        messages=[
+            {"role": "system", "content": system_message},
+            {"role": "user", "content": f"Context: {reranked_context}\n\nHuman: {message}"}
+        ],
+        max_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p
+    )
+
+    assistant_response = response.choices[0].message.content
+    chat_history.append((message, assistant_response))
+    return "", chat_history
+
+# Initialize clients
+intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
+embedding_function = JinaEmbeddingFunction(jina_model, tokenizer, intention_client)
+chroma_db = initialize_chroma(collection_name="Jina-embeddings", embedding_function=embedding_function)
+
+@spaces.GPU
+def upload_documents(files):
+    for file in files:
+        loader = UnstructuredFileLoader(file.name)
+        documents = loader.load()
+        add_documents_to_chroma([doc.page_content for doc in documents], embedding_function)
+    return "Documents uploaded and processed successfully!"
+
+@spaces.GPU
+def query_documents(query):
+    results = query_chroma(query, embedding_function)
+    reranked_docs = rerank_documents(query, [result for result in results['documents'][0]])
+    return "\n\n".join(reranked_docs)
+
+with gr.Blocks() as demo:
+    with gr.Tab("Upload Documents"):
+        document_upload = gr.File(file_count="multiple", file_types=["document"])
+        upload_button = gr.Button("Upload and Process")
+        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())
+
+    with gr.Tab("Ask Questions"):
+        with gr.Row():
+            chat_interface = gr.ChatInterface(
+                answer_query,
+                additional_inputs=[
+                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
+                ],
+            )
+        query_input = gr.Textbox(label="Query")
+        query_button = gr.Button("Query")
+        query_output = gr.Textbox()
+        query_button.click(query_documents, inputs=query_input, outputs=query_output)
+
+if __name__ == "__main__":
+    demo.launch()
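
yijinaembed.py routes reranking through LangChain's contextual compression wrapper. A standalone sketch of that stage, assuming `JinaRerank` from langchain_community (which scores documents via the hosted Jina rerank API and reads JINA_API_KEY from the environment):

```python
# Standalone sketch of the rerank stage wired up in yijinaembed.py.
# Assumes JINA_API_KEY is set; `vectorstore` is any LangChain vector store,
# e.g. the chroma_db instance created above.
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.document_compressors.jina_rerank import JinaRerank

def build_reranking_retriever(vectorstore, fetch_k: int = 20):
    # Over-fetch from the vector store, then let the reranker keep the best hits
    base_retriever = vectorstore.as_retriever(search_kwargs={"k": fetch_k})
    return ContextualCompressionRetriever(
        base_compressor=JinaRerank(),
        base_retriever=base_retriever,
    )

# Usage: docs = build_reranking_retriever(chroma_db).get_relevant_documents("my query")
```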