File size: 1,840 Bytes
99e744f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.chains import MapReduceDocumentsChain, ReduceDocumentsChain, LLMChain, StuffDocumentsChain
from langchain.prompts import PromptTemplate

def get_map_reduce_chain(
    pipeline_or_llm,
    model_type: str,
    *,
    token_max: int = 16384,
) -> MapReduceDocumentsChain:
    """Build a map-reduce document chain around a single LLM.

    Each input document is first sent through a "map" LLM call. The mapped
    outputs are then stuffed into one "reduce" LLM call. If the mapped outputs
    exceed ``token_max``, they are collapsed in groups first, using the same
    stuff/reduce chain.

    NOTE(review): ``LLMChain`` and ``MapReduceDocumentsChain`` are marked
    deprecated in newer LangChain releases; confirm against the pinned version.

    Args:
        pipeline_or_llm: When ``model_type == "openai"``, a LangChain LLM that
            is used as-is. Otherwise it is wrapped with
            ``HuggingFacePipeline(pipeline=...)``, presumably a ``transformers``
            pipeline (confirm with callers).
        model_type: ``"openai"`` selects the theme-extraction and consolidation
            prompts. Any other value uses pass-through ``"{docs}"`` prompts, so
            only the raw document text reaches the model.
        token_max: Token budget given to ``ReduceDocumentsChain`` that decides
            when mapped outputs must be collapsed before the final reduce.
            Defaults to 16384.

    Returns:
        A ``MapReduceDocumentsChain`` whose documents are bound to the ``docs``
        prompt variable and which does not return intermediate map outputs.
    """
    if model_type == "openai":
        llm = pipeline_or_llm
        # Built from explicitly joined lines so that no source-code
        # indentation leaks into the prompt text sent to the model.
        map_prompt = PromptTemplate.from_template(
            "The following is a set of documents\n"
            "{docs}\n"
            "Based on this list of docs, please identify the main themes.\n"
            "Helpful Answer:"
        )
        reduce_prompt = PromptTemplate.from_template(
            "The following is a set of summaries:\n"
            "{docs}\n"
            "Take these and distill into a final, consolidated summary "
            "of the main themes.\n"
            "Helpful Answer:"
        )
    else:
        # Any non-OpenAI model receives the document text with no
        # instruction wrapper around it.
        map_prompt = PromptTemplate.from_template(template="{docs}")
        reduce_prompt = PromptTemplate.from_template(template="{docs}")
        llm = HuggingFacePipeline(pipeline=pipeline_or_llm)

    map_chain = LLMChain(llm=llm, prompt=map_prompt)
    reduce_chain = LLMChain(llm=llm, prompt=reduce_prompt, verbose=True)

    # Stuffs all (mapped) documents into the reduce prompt's {docs} slot.
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain, document_variable_name="docs"
    )
    # The same stuff chain is reused for the intermediate "collapse" step
    # that runs when the mapped outputs do not fit within token_max.
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=combine_documents_chain,
        token_max=token_max,
        verbose=True,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name="docs",
        return_intermediate_steps=False,
        verbose=True,
    )