File size: 4,329 Bytes
282b01b
f9af090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282b01b
 
f9af090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282b01b
f9af090
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import random


import json
from pathlib import Path
from pprint import pprint

import uuid
import chromadb
from chromadb.utils import embedding_functions


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

# Friendly display name -> Hugging Face Hub repository id of each candidate model.
models = {
    "wizardLM-7B-HF": "TheBloke/wizardLM-7B-HF",
    "wizard-vicuna-13B-GPTQ": "TheBloke/wizard-vicuna-13B-GPTQ",
    "Wizard-Vicuna-13B-Uncensored": "ehartford/Wizard-Vicuna-13B-Uncensored",
    "WizardLM-13B": "TheBloke/WizardLM-13B-V1.0-Uncensored-GPTQ",
    "Llama-2-7B": "TheBloke/Llama-2-7b-Chat-GPTQ",
    "Vicuna-13B": "TheBloke/vicuna-13B-v1.5-GPTQ",
    "WizardLM-13B-V1.2": "TheBloke/WizardLM-13B-V1.2-GPTQ",  # Trained from Llama-2 13b
    "Mistral-7B": "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
}


# The model actually served by this app; must be one of the keys of ``models``.
model_name = "Mistral-7B"

tokenizer = AutoTokenizer.from_pretrained(models[model_name])


# Half-precision weights; device_map="auto" delegates device placement to
# transformers/accelerate rather than moving the model by hand.
model = AutoModelForCausalLM.from_pretrained(
    models[model_name],
    torch_dtype=torch.float16,
    device_map="auto",
)


# FAQ source. The code below reads data['questions'], and respond() reads the
# 'question' / 'answer' keys of each entry — other fields are not used here.
file_path = './data/faq_dataset.json'
# Explicit UTF-8: the default text encoding is locale-dependent (e.g. cp1252 on
# Windows) and would mis-decode or reject non-ASCII FAQ text.
data = json.loads(Path(file_path).read_text(encoding="utf-8"))


# Chroma client with default settings (ephemeral, in-process storage).
client = chromadb.Client()

emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="BAAI/bge-small-en-v1.5")

# get_or_create_collection: re-executing this module in the same process
# (notebook re-run, reload mode) reuses the collection instead of raising
# "collection already exists".
collection = client.get_or_create_collection(
    name="retrieval_qa",
    embedding_function=emb_fn,
    metadata={"hnsw:space": "cosine"},  # l2 is the default
)

documents = [json.dumps(q) for q in data['questions']]  # encode QnA as json strings for generating embeddings
metadatas = data['questions']  # retain QnA as dict in metadatas
# Stable, position-based ids so that upsert() below overwrites entries on a
# re-run instead of adding duplicates (random ids would accumulate copies).
ids = [f"faq-{index}" for index in range(len(documents))]


collection.upsert(
    documents=documents,
    metadatas=metadatas,
    ids=ids,
)

# Initial rows of the "Similar questions" Dataset. Each row is a one-element
# list because the Dataset is built with a single component (the question
# textbox). respond() rebinds this name after every query.
samples = [
    [seed_text]
    for seed_text in (
        "How can I return a product?",
        "What is the return policy?",
        "How can I contact customer support?",
    )
]


def respond(query, n_results=3, max_new_tokens=100):
    """Answer ``query`` with the LLM, grounded in the most similar FAQ entries.

    Retrieves the ``n_results`` FAQ entries closest to ``query`` from the
    Chroma ``collection``, embeds them as context in the prompt, and generates
    an answer with ``model``.

    Args:
        query: The user's question from the "Question" textbox.
        n_results: Number of FAQ entries to retrieve as context (default 3).
        max_new_tokens: Generation budget for the answer (default 100).

    Yields:
        One ``[answer, references, related]`` list, matching the Gradio outputs
        ``[answer_block, references_block, examples]``:
        - the decoded answer text;
        - a Markdown "References" section quoting the retrieved entries;
        - an update that replaces the "Similar questions" samples.

    Side effects:
        Rebinds the module-level ``samples`` to the retrieved questions, so that
        ``load_example`` can map a clicked row index back to its text.
        NOTE(review): that global is shared by every session of the app, so
        concurrent users can overwrite each other's list; per-session
        ``gr.State`` would avoid this.
    """
    global samples

    # A blank question has nothing to retrieve or answer. Leave the current
    # similar-questions list in place instead of querying the index with "".
    if not query or not query.strip():
        yield ["", "## References\n", gr.update(samples=samples)]
        return

    results = collection.query(query_texts=[query], n_results=n_results)
    # Chroma returns one result list per query text; exactly one was sent.
    hits = results['metadatas'][0]

    system_message = (
        "You are a helpful, respectful and honest support executive. "
        "Always be as helpful as possible, while being correct. "
        "If a question does not make any sense, or is not factually coherent, "
        "explain why instead of answering something not correct. "
        "Use the following piece of context to answer the questions. "
        "If the information is not present in the provided context, "
        "answer that you don't know. Please don't share false information."
    )
    system_message += "".join(
        f"\n Question: {hit['question']} \n Answer: {hit['answer']}" for hit in hits
    )

    references = "## References\n" + "".join(
        f"**{hit['question']}**\n\n> {hit['answer']}\n\n" for hit in hits
    )
    # One single-cell row per question, matching the Dataset's one component.
    related_questions = [[hit['question']] for hit in hits]

    # Some chat templates (Mistral-Instruct's among them) reject a "system"
    # role and require strict user/assistant alternation. Putting the
    # instructions and retrieved context into the single user turn works with
    # every template.
    chat = [{"role": "user", "content": f"{system_message}\n\n{query}"}]

    # add_generation_prompt appends the assistant-turn prefix for templates
    # that define one, so the model starts answering rather than continuing
    # the user message.
    input_ids = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    generated_ids = model.generate(
        input_ids,
        attention_mask=torch.ones_like(input_ids),  # single, unpadded sequence
        pad_token_id=tokenizer.eos_token_id,  # no dedicated pad token; avoids the warning
        streamer=TextStreamer(tokenizer),  # echoes the generation to stdout
        temperature=0.01,
        do_sample=True,
        max_new_tokens=max_new_tokens,
    )
    # Decode only the newly generated tokens. skip_special_tokens removes </s>
    # and any other special tokens the model emits.
    answer = tokenizer.decode(
        generated_ids[0, input_ids.shape[1]:], skip_special_tokens=True
    ).strip()

    samples = related_questions

    # gr.update works for component updates across Gradio versions, whereas
    # the component-specific gr.Dataset.update was removed in Gradio 4.
    yield [answer, references, gr.update(samples=related_questions)]


def load_example(example_id):
    """Return the question text for the clicked "Similar questions" row.

    ``example_id`` is a row index, because the Dataset is created with
    ``type="index"``. It is resolved against the module-level ``samples``
    list, whose rows each hold the question text in their first cell.
    """
    global samples
    clicked_row = samples[example_id]
    return clicked_row[0]


# Two-column layout: question/answer controls on the left, retrieved FAQ
# references on the right.
with gr.Blocks() as chatbot:
    with gr.Row():
        with gr.Column():
            answer_block = gr.Textbox(label="Answers", lines=2)
            question = gr.Textbox(label="Question")
            # type="index": a click reports the row index, which load_example
            # resolves against the module-level ``samples`` list.
            examples = gr.Dataset(
                samples=samples,
                components=[question],
                label="Similar questions",
                type="index",
            )
            generate = gr.Button(value="Ask")
        with gr.Column():
            references_block = gr.Markdown("## References\n", label="global variable")

        # Clicking a similar question copies its text into the question box.
        examples.click(load_example, inputs=[examples], outputs=[question])
        # "Ask" runs retrieval + generation, then refreshes the answer, the
        # references and the similar-questions list.
        generate.click(
            fn=respond,
            inputs=[question],
            outputs=[answer_block, references_block, examples],
        )

# Enable the request queue (needed for generator handlers such as respond),
# then start the server; queue() returns the Blocks, so the calls chain.
chatbot.queue().launch()