import gradio as gr
import random
import json
from pathlib import Path
from pprint import pprint
import uuid
import chromadb
from chromadb.utils import embedding_functions
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
print(f"Is CUDA available: {torch.cuda.is_available()}")
print(
    f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
models = {
    "wizardLM-7B-HF": "TheBloke/wizardLM-7B-HF",
    "wizard-vicuna-13B-GPTQ": "TheBloke/wizard-vicuna-13B-GPTQ",
    "Wizard-Vicuna-13B-Uncensored": "ehartford/Wizard-Vicuna-13B-Uncensored",
    "WizardLM-13B": "TheBloke/WizardLM-13B-V1.0-Uncensored-GPTQ",
    "Llama-2-7B": "TheBloke/Llama-2-7b-Chat-GPTQ",
    "Vicuna-13B": "TheBloke/vicuna-13B-v1.5-GPTQ",
    "WizardLM-13B-V1.2": "TheBloke/WizardLM-13B-V1.2-GPTQ",  # Trained from Llama-2 13b
    "Mistral-7B": "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ",
}
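
# Load the tokenizer and the model in fp16; device_map="auto" lets accelerate
# place the weights on the available GPU(s).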
model_name = "Mistral-7B"
tokenizer = AutoTokenizer.from_pretrained(models[model_name])
# tokenizer.use_default_system_prompt = True
tokenizer.chat_template = tokenizer.default_chat_template
model = AutoModelForCausalLM.from_pretrained(models[model_name],
                                             torch_dtype=torch.float16,
                                             device_map="auto")
file_path = './data/faq_dataset.json'
data = json.loads(Path(file_path).read_text())
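
# Optional sanity check: peek at the first loaded FAQ record
pprint(data['questions'][:1])

# In-memory Chroma collection using BGE-small embeddings and cosine distance
# for similarity search over the FAQ entries.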
client = chromadb.Client()
emb_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-small-en-v1.5")
collection = client.create_collection(
    name="retrieval_qa",
    embedding_function=emb_fn,
    metadata={"hnsw:space": "cosine"}  # l2 is the default
)
# encode QnA as json strings for generating embeddings
documents = [json.dumps(q) for q in data['questions']]
metadatas = data['questions'] # retain QnA as dict in metadatas
ids = [str(uuid.uuid1()) for _ in documents]
collection.add(
    documents=documents,
    metadatas=metadatas,
    ids=ids
)
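
# Seed questions shown in the "Similar questions" examples panel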
samples = [
    ["How can I return a product?"],
    ["What is the return policy?"],
    ["How can I contact customer support?"],
]
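

# RAG loop: retrieve the top matching FAQ entries, stuff them into the system
# prompt, generate an answer, and return the answer, a references section, and
# an updated list of related questions for the examples panel.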
def respond(query):
    global samples
    # Retrieve the 3 most similar FAQ entries for the query
    docs = collection.query(query_texts=[query], n_results=3)
    chat = []
    related_questions = []
    references = "## References\n"
    system_message = "You are a helpful, respectful and honest support executive. Always be as helpful as possible, while being correct. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. Use the following pieces of context to answer the questions. If the information is not present in the provided context, answer that you don't know. Please don't share false information."
    for d in docs['metadatas'][0]:
        # Prepare chat template: append the retrieved Q&A pairs to the system prompt
        system_message += f"\n Question: {d['question']} \n Answer: {d['answer']}"
        # Update references
        references += f"**{d['question']}**\n\n"
        references += f"> {d['answer']}\n\n"
        # Update related questions
        related_questions.append([d['question']])
    chat.append({"role": "system", "content": system_message})
    chat.append({"role": "user", "content": query})
    encodeds = tokenizer.apply_chat_template(chat, return_tensors="pt")
    model_inputs = encodeds.to(model.device)
    streamer = TextStreamer(tokenizer)
    generated_ids = model.generate(
        model_inputs, streamer=streamer, temperature=0.01, max_new_tokens=100, do_sample=True)
    # Decode only the newly generated tokens (skip the prompt)
    answer = tokenizer.batch_decode(
        generated_ids[:, model_inputs.shape[1]:])[0]
    answer = answer.replace('</s>', '')
    samples = related_questions
    related = gr.update(samples=related_questions)
    yield [answer, references, related]
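

# Map a click on the examples Dataset (type="index") back to its question text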
def load_example(example_id):
    global samples
    return samples[example_id][0]
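

# Gradio UI: question input, answer box and example questions on the left,
# retrieved references rendered as Markdown on the right.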
with gr.Blocks() as chatbot:
    with gr.Row():
        with gr.Column():
            answer_block = gr.Textbox(label="Answers", lines=2)
            question = gr.Textbox(label="Question")
            examples = gr.Dataset(samples=samples, components=[question],
                                  label="Similar questions", type="index")
            generate = gr.Button(value="Ask")
        with gr.Column():
            references_block = gr.Markdown("## References\n", label="References")

    examples.click(load_example, inputs=[examples], outputs=[question])
    generate.click(respond, inputs=question,
                   outputs=[answer_block, references_block, examples])

chatbot.queue()
chatbot.launch()