File size: 4,783 Bytes
e2e8616
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import chromadb
from src.tools.retriever import Retriever
from src.tools.llm import LlmAgent
from src.model.block import Block
from src.model.doc import Doc
from chromadb.utils import embedding_functions
import gradio as gr


class Chatbot:
    """Retrieval-augmented chatbot.

    Combines a vector-store retriever with an LLM agent: for each user
    query, the most relevant document blocks are retrieved and passed to
    the LLM as context for the generated answer.
    """

    # Quote/space junk trimmed from the ends of raw LLM answers.
    _TRIM_CHARS = {"'", '"', " ", "`"}
    # Prefixes the LLM sometimes prepends to its answers.
    _ANSWER_PREFIXES = ("bot:", "Answer:", "Réponse:")

    def __init__(self, llm_agent: "LlmAgent" = None, retriever: "Retriever" = None, client_db=None):
        self.retriever = retriever   # similarity-search backend (may be replaced later by upload_doc)
        self.llm = llm_agent         # LLM wrapper used for translation / generation
        self.client_db = client_db   # chromadb client used to persist document collections

    def get_response(self, query, histo):
        """Answer *query* given the conversation history.

        :param query: the user's latest question.
        :param histo: list of past (query, answer) pairs.
        :return: tuple (cleaned answer string, list of source blocks used as context).
        """
        histo_conversation, histo_queries = self._get_histo(histo)
        language_of_query = self.llm.detect_language_v2(query).lower()
        queries = self.llm.translate_v2(histo_queries)
        # Only English is handled natively; everything else falls back to French.
        language_of_query = "en" if "en" in language_of_query else "fr"
        block_sources = self._select_best_sources(self.retriever.similarity_search(queries=queries))
        sources_contents = [
            f"Paragraph title : {s.title}\n-----\n{s.content}" if s.title
            else f"Paragraph {s.index}\n-----\n{s.content}"
            for s in block_sources
        ]
        context = '\n'.join(sources_contents)
        # Drop the least relevant sources one by one until the prompt fits the size budget.
        i = 1
        while (len(context) + len(histo_conversation) > 15000) and i < len(sources_contents):
            context = "\n".join(sources_contents[:-i])
            i += 1
        answer = self.llm.generate_paragraph_v2(query=query, histo=histo_conversation, context=context, language=language_of_query)
        return self._clean_chatgpt_answer(answer), block_sources

    @staticmethod
    def _select_best_sources(sources: "list[Block]", delta_1_2=0.15, delta_1_n=0.3, absolute=1.2, alpha=0.9) -> "list[Block]":
        """Select the best sources: not far from the very best (delta_1_n),
        not far from the previously selected one (delta_1_2), or good in
        absolute terms (absolute).  All thresholds shrink by *alpha* after
        each accepted source, so later picks must be increasingly strong.

        :param sources: candidate blocks, assumed sorted by ascending distance.
        :return: the accepted prefix of *sources*.
        """
        best_sources = []
        for idx, s in enumerate(sources):
            keep = (
                idx == 0
                or (s.distance - sources[idx - 1].distance < delta_1_2
                    and s.distance - sources[0].distance < delta_1_n)
                or s.distance < absolute
            )
            if not keep:
                break
            best_sources.append(s)
            delta_1_2 *= alpha
            delta_1_n *= alpha
            absolute *= alpha
        return best_sources

    @staticmethod
    def _get_histo(histo: "list[tuple[str, str]]") -> "tuple[str, str]":
        """Format the last 5 exchanges as (conversation transcript, newline-joined queries)."""
        histo_conversation = ""
        histo_queries = ""
        for (query, answer) in histo[-5:]:
            histo_conversation += f'user: {query} \n bot: {answer}\n'
            histo_queries += query + '\n'
        # [:-1] drops the transcript's trailing newline ("" stays "").
        return histo_conversation[:-1], histo_queries

    @staticmethod
    def _remove_prefix(text: str, prefix: str) -> str:
        """Drop the literal *prefix* from the start of *text* if present."""
        return text[len(prefix):] if text.startswith(prefix) else text

    @staticmethod
    def _clean_answer(answer: str) -> str:
        """Strip the 'bot:' prefix and surrounding quote/space junk; ensure a final period."""
        print(answer)
        # BUG FIX: the old `answer.strip('bot:')` treated its argument as a
        # character SET and ate legitimate b/o/t/: characters from both ends
        # (e.g. "bot: hello" became "hell."); remove the literal prefix instead.
        answer = Chatbot._remove_prefix(answer, 'bot:')
        while answer and answer[-1] in Chatbot._TRIM_CHARS:
            answer = answer[:-1]
        while answer and answer[0] in Chatbot._TRIM_CHARS:
            answer = answer[1:]
        if answer and answer[-1] != ".":
            answer += "."
        return answer

    def _clean_chatgpt_answer(self, answer: str) -> str:
        """Remove known answer prefixes ('bot:', 'Answer:', 'Réponse:') and trailing junk."""
        # Same character-set-strip bug fix as _clean_answer: remove literal
        # prefixes, then drop the space the LLM usually leaves after them.
        for prefix in Chatbot._ANSWER_PREFIXES:
            answer = Chatbot._remove_prefix(answer, prefix)
        answer = answer.lstrip()
        while answer and answer[-1] in Chatbot._TRIM_CHARS:
            answer = answer[:-1]
        return answer

    def upload_doc(self, input_doc, include_images_, actual_page_start):
        """Index an uploaded document into the vector store and build a retriever.

        :param input_doc: uploaded file object (gradio) exposing `.name`.
        :param include_images_: forwarded to Doc — whether to extract images.
        :param actual_page_start: forwarded to Doc — real number of the first page.
        :return: True on success, False for unsupported file types.
        """
        # NOTE(review): get_title is called unbound with the class as `self` —
        # looks intentional but confirm against Doc's definition.
        title = Doc.get_title(Doc, input_doc.name)
        extension = title.split('.')[-1]
        if extension not in ('docx', 'pdf', 'html'):
            return False

        open_ai_embedding = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.environ['OPENAI_API_KEY'], model_name="text-embedding-ada-002")
        # chromadb collection names must be alphanumeric/underscore.
        coll_name = "".join([c if c.isalnum() else "_" for c in title])
        collection = self.client_db.get_or_create_collection(name=coll_name, embedding_function=open_ai_embedding)

        if collection.count() == 0:
            # First upload of this document: parse it and populate the collection.
            gr.Info("Please wait while your document is being analysed")
            print("Database is empty")
            doc = Doc(path=input_doc.name, include_images=include_images_, actual_first_page=actual_page_start)
            retriever = Retriever(doc.container, collection=collection, llmagent=self.llm)
        else:
            # Collection already populated: reuse it without re-parsing.
            print("Database is not empty")
            retriever = Retriever(collection=collection, llmagent=self.llm)

        self.retriever = retriever
        return True