Jkalonji committed on
Commit
19ca331
1 Parent(s): 07b3f91

Creating necessary RAG functions

Files changed (1)
  1. RAG.py +78 -0
RAG.py CHANGED
@@ -0,0 +1,78 @@
+ import os
+ import re
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import DirectoryLoader, PyPDFLoader
+
+ from langchain.vectorstores import Chroma
+ from langchain.embeddings import SentenceTransformerEmbeddings
+ from langchain.prompts import PromptTemplate
+ from langchain.llms import HuggingFaceHub
+ from langchain.chains import RetrievalQA
+ from config import HUGGINGFACEHUB_API_TOKEN
+
+ from transformers import pipeline
+
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = HUGGINGFACEHUB_API_TOKEN
+
+ # You can choose from the many models available on Hugging Face (https://huggingface.co/models)
+ model_name = "llmware/industry-bert-insurance-v0.1"
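+ # e.g. "sentence-transformers/all-MiniLM-L6-v2" is a common general-purpose alternative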
+
+ def remove_special_characters(string):
+     # Replace newline characters with spaces so each chunk is a single-line string
+     return re.sub(r"\n", " ", string)
+
+
+ def RAG_Langchain(query):
+     embeddings = SentenceTransformerEmbeddings(model_name=model_name)
+     repo_id = "llmware/bling-sheared-llama-1.3b-0.1"  # LLM repo id, not used in this function yet
+
+     loader = DirectoryLoader('data/', glob="**/*.pdf", show_progress=True, loader_cls=PyPDFLoader)
+
+     documents = loader.load()
+
+     # The chunk size is an important parameter for the quality of the retrieved
+     # information; there are several methods for choosing its value.
+     # The overlap is the number of characters shared between a chunk and the next one.
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+
+     texts = text_splitter.split_documents(documents)
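+     # For example, with chunk_size=20 and chunk_overlap=5, split_text("a" * 50)
+     # would yield ~20-character pieces whose last 5 characters reappear at the
+     # start of the next piece; the 1000/100 setting above behaves the same way at scale.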
+
+     # Data preparation: strip newline characters from every chunk
+     for chunk in texts:
+         chunk.page_content = remove_special_characters(chunk.page_content)
+
+     # Load all the documents into the vector database so they can be used later
+     vector_stores = Chroma.from_documents(texts, embeddings, collection_metadata={"hnsw:space": "cosine"}, persist_directory="stores/insurance_cosine")
+
+     # Retrieval
+     load_vector_store = Chroma(persist_directory="stores/insurance_cosine", embedding_function=embeddings)
+
+     # Use k=1 for now; how to select the context results will be addressed later
+     docs = load_vector_store.similarity_search_with_score(query=query, k=1)
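+     # similarity_search_with_score returns (Document, score) pairs; with the
+     # "hnsw:space": "cosine" collection setting the score is a cosine distance,
+     # so lower values mean closer matches.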
+     results = {"Score": [], "Content": [], "Metadata": []}
+
+     for doc, score in docs:
+         results['Score'].append(score)
+         results['Content'].append(doc.page_content)
+         results['Metadata'].append(doc.metadata)
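+
+     # Example shape of the returned dict (illustrative values only):
+     #   {"Score": [0.21], "Content": ["...best-matching chunk..."], "Metadata": [{"source": "data/file.pdf", "page": 3}]}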
+
+     return results
+
+
+ def generateResponseBasedOnContext(model_name, context_string, query):
+     question_answerer = pipeline("question-answering", model=model_name)
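+     # The "question-answering" pipeline runs extractive QA: the model picks a span
+     # out of the context and returns {"score": ..., "start": ..., "end": ..., "answer": ...}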
+     context_prompt = "You are a sports expert. Answer the user's question by using the following context: "
+
+     context = context_prompt + context_string
+     print("context : ", context)
+
+     result = question_answerer(question=query, context=context)
+     return result['answer']
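
A minimal usage sketch of how the two functions could be chained (not part of this commit; deepset/roberta-base-squad2 is an assumed stand-in for an extractive QA checkpoint):

    from RAG import RAG_Langchain, generateResponseBasedOnContext

    query = "What does the policy cover?"

    # Retrieve the closest chunk(s) and join their text into one context string
    results = RAG_Langchain(query)
    context_string = " ".join(results["Content"])

    # Answer the question against the retrieved context
    answer = generateResponseBasedOnContext("deepset/roberta-base-squad2", context_string, query)
    print(answer)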