johannoriel committed on
Commit 645a356
1 Parent(s): d7dc2a6
Files changed (2):
  1. plugins/ragllm.py +76 -19
  2. requirements.txt +1 -1
plugins/ragllm.py CHANGED
@@ -10,10 +10,18 @@ from typing import List, Dict, Any
 import requests
 import torch
 from transformers import AutoTokenizer, AutoModel
+from huggingface_hub import InferenceClient
+from langchain_huggingface import HuggingFaceEmbeddings

 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
 MAX_LENGTH = 512
-CHUNK_SIZE = 200 # Nombre de mots par chunk
+CHUNK_SIZE = 200
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+

 def mean_pooling(model_output, attention_mask):
     token_embeddings = model_output[0]
@@ -43,6 +51,7 @@ translations["en"].update({
     "rag_error_fetching_models_ollama": "Error fetching Ollama models: ",
     "rag_error_calling_llm": "Error calling LLM: ",
     "rag_processing" : "Processing...",
+    "rag_hf_api_key": "HuggingFace API Token",
 })

 translations["fr"].update({
@@ -67,28 +76,36 @@ translations["fr"].update({
     "rag_error_fetching_models_ollama": "Erreur lors de la récupération des modèles Ollama : ",
     "rag_error_calling_llm": "Erreur lors de l'appel au LLM : ",
     "rag_processing" : "En cours de traitement...",
+    "rag_hf_api_key": "Token API HuggingFace",
 })

 class RagllmPlugin(Plugin):
     def __init__(self, name: str, plugin_manager):
         super().__init__(name, plugin_manager)
-        self.config = self.load_llm_config()
+        try:
+            self.config = self.load_llm_config()
+        except:
+            self.config = {}
         self.embeddings = None
         self.chunks = None
+        self.hf_client = None

     def load_llm_config(self) -> Dict:
-        with open('.llm-config.yml', 'r') as file:
-            return yaml.safe_load(file)
+        try:
+            with open('.llm-config.yml', 'r') as file:
+                return yaml.safe_load(file)
+        except:
+            return {}

     def get_tabs(self):
         return [{"name": "RAG", "plugin": "ragllm"}]

     def get_config_fields(self):
-        return {
+        fields = {
             "provider": {
                 "type": "select",
                 "label": t("rag_model_provider"),
-                "options": [("ollama", "Ollama"), ("groq", "Groq")],
+                "options": [("ollama", "Ollama"), ("groq", "Groq"), ("huggingface", "HuggingFace")],
                 "default": "ollama"
             },
             "llm_model": {
@@ -132,6 +149,15 @@ class RagllmPlugin(Plugin):
                 "default": 3
             }
         }
+        # Add HuggingFace API key field if provider is huggingface
+        if 'provider' in self.config and self.config.get('provider') == 'huggingface':
+            fields["hf_api_key"] = {
+                "type": "password",
+                "label": t("rag_hf_api_key"),
+                "default": ""
+            }
+
+        return fields

     def get_config_ui(self, config):
         updated_config = {}
@@ -201,6 +227,8 @@
             return ["ollama/qwen2"]
         elif provider == 'groq':
             return ["groq/llama3-70b-8192", "groq/mixtral-8x7b-32768"]
+        elif provider == 'huggingface':
+            return ["HuggingFaceH4/zephyr-7b-beta"]
         else:
             return ["none"]

@@ -211,12 +239,23 @@
         self.embeddings = np.vstack([self.get_embedding(c, embedder) for c in self.chunks])

     def get_embedding(self, text: str, model: str) -> np.ndarray:
-        tokenizer = AutoTokenizer.from_pretrained(model)
-        model = AutoModel.from_pretrained(model, trust_remote_code=True).to(DEVICE)
-        inputs = tokenizer(text, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
-        with torch.no_grad():
-            model_output = model(**inputs)
-        return mean_pooling(model_output, inputs['attention_mask']).cpu().numpy()
+        if self.config.get('provider') == 'huggingface':
+            if not hasattr(self, 'hf_embeddings'):
+                self.hf_embeddings = HuggingFaceEmbeddings(
+                    model_name=model,
+                    task="feature-extraction",
+                    encode_kwargs={'normalize': True}
+                )
+            embedding = self.hf_embeddings.embed_query(text)
+            return np.array(embedding).reshape(1, -1)
+        else:
+            # Original embedding logic
+            tokenizer = AutoTokenizer.from_pretrained(model)
+            model = AutoModel.from_pretrained(model, trust_remote_code=True).to(DEVICE)
+            inputs = tokenizer(text, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
+            with torch.no_grad():
+                model_output = model(**inputs)
+            return mean_pooling(model_output, inputs['attention_mask']).cpu().numpy()

     def calculate_similarity(self, query_embedding: np.ndarray, method: str) -> np.ndarray:
         if method == 'cosine':
@@ -238,13 +277,31 @@
     def call_llm(self, prompt: str, sysprompt: str) -> str:
         try:
             llm_model = st.session_state.ragllm_llm_model
-            #print(f"---------------------------------------\nCalling LLM {llm_model} \n with sysprompt {sysprompt} \n and prompt {prompt} \n and context len of {len(context)}")
-            messages = [
-                {"role": "system", "content": sysprompt},
-                {"role": "user", "content": prompt}
-            ]
-            response = completion(model=llm_model, messages=messages)
-            return response['choices'][0]['message']['content']
+            if self.config.get('provider') == 'huggingface':
+                if not self.hf_client:
+                    self.hf_client = InferenceClient(token=self.config.get('hf_api_key'))
+
+                messages = [
+                    {"role": "system", "content": sysprompt},
+                    {"role": "user", "content": prompt}
+                ]
+
+                response = self.hf_client.text_generation(
+                    model=llm_model,
+                    prompt=prompt,
+                    max_new_tokens=512,
+                    temperature=0.7,
+                    stream=False
+                )
+                return response
+            else:
+                messages = [
+                    {"role": "system", "content": sysprompt},
+                    {"role": "user", "content": prompt}
+                ]
+                response = completion(model=llm_model, messages=messages)
+                return response['choices'][0]['message']['content']
+
         except Exception as e:
             return f"{t('rag_error_calling_llm')}{str(e)}"
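For reference, a minimal standalone sketch of the two HuggingFace paths this commit wires in: embeddings through langchain_huggingface and text generation through huggingface_hub's InferenceClient. The model names and the HF_TOKEN environment variable below are illustrative assumptions, not values taken from the plugin.

import os
import numpy as np
from huggingface_hub import InferenceClient
from langchain_huggingface import HuggingFaceEmbeddings

# Local embedding model (assumed name); sentence-transformers normalizes the vectors here.
embedder = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    encode_kwargs={"normalize_embeddings": True},
)
doc_vec = np.array(embedder.embed_query("RAG retrieves relevant chunks before calling the LLM.")).reshape(1, -1)
query_vec = np.array(embedder.embed_query("What does RAG do?")).reshape(1, -1)
print("cosine similarity:", float(doc_vec @ query_vec.T))  # dot product equals cosine for normalized vectors

# Hosted text generation (token assumed to be exported as HF_TOKEN).
client = InferenceClient(token=os.environ["HF_TOKEN"])
answer = client.text_generation(
    "Explain retrieval-augmented generation in one sentence.",
    model="HuggingFaceH4/zephyr-7b-beta",
    max_new_tokens=128,
    temperature=0.7,
)
print(answer)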
requirements.txt CHANGED
@@ -9,4 +9,4 @@ PyDictionary
 matplotlib
 litellm
 sentencepiece
-
+langchain_huggingface
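The plugin also reads a .llm-config.yml next to the app (loaded with yaml.safe_load, now falling back to an empty dict when the file is missing). The commit does not define the file's schema; the sketch below is hypothetical and only uses the two keys the new code actually reads, 'provider' and 'hf_api_key', with a placeholder token.

import yaml

example_config = """\
provider: huggingface
hf_api_key: hf_xxxxxxxxxxxxxxxx  # placeholder token, not a real credential
"""

# Write the hypothetical config, then load it the same way load_llm_config does.
with open(".llm-config.yml", "w") as f:
    f.write(example_config)

with open(".llm-config.yml", "r") as f:
    config = yaml.safe_load(f) or {}

print(config.get("provider"), "token set:", bool(config.get("hf_api_key")))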