# Hugging Face Space — status: Paused
import html
import json
import os
import re
import shutil

import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import requests
import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
# Pick the compute device once, falling back to CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding function attached to the Chroma collection; it embeds the query
# texts passed to `collection.query`. NOTE(review): presumably the same model
# that was used to index the "corrected" collection — confirm.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="intfloat/multilingual-e5-base", device=device
)
client = chromadb.PersistentClient(path="education_corrected")
collection = client.get_collection(name="corrected", embedding_function=sentence_transformer_ef)

# Generation settings. NOTE(review): in this file these globals are not read by
# MistralChatBot.predict, which builds its own SamplingParams; kept so any
# external reference to them keeps working.
temperature = 0.2
max_new_tokens = 1000
top_p = 0.92
repetition_penalty = 1.7

model_name = "Pclanglais/Cassandre-Test"
# vLLM engine; max_model_len caps the model context length at 4096 tokens.
llm = LLM(model_name, max_model_len=4096)
# Vector search over the database
def vector_search(collection, text, n_results=5):
    """Retrieve the passages closest to ``text`` from a Chroma collection.

    Args:
        collection: Object exposing ``query(query_texts=..., n_results=...)``
            that returns a Chroma-style result dict with ``metadatas`` and
            ``documents`` keys, each holding one inner list per query text.
            Every metadata entry must provide ``identifier`` and ``context``.
        text: The query string to search with.
        n_results: Maximum number of passages to retrieve (default 5).

    Returns:
        A ``(document, document_html)`` tuple. ``document`` is a plain-text
        listing with one ``"<identifier> : <context> <passage>"`` entry per
        hit, separated by blank lines; it is inserted verbatim into the LLM
        prompt. ``document_html`` renders the same entries as
        ``<div class="source" id="<identifier>">`` blocks wrapped in
        ``<div id="source_listing">``, with all text HTML-escaped.
    """
    results = collection.query(
        query_texts=[text],
        n_results=n_results,
    )
    # One inner list per query text; a single text was sent, so take index 0.
    metadatas = results["metadatas"][0]
    passages = results["documents"][0]

    document = []
    document_html = []
    for metadata, passage in zip(metadatas, passages):
        link = str(metadata["identifier"])
        title = metadata["context"] + " " + passage
        document.append(link + " : " + title)
        # Escape only the HTML rendering; the prompt text must stay verbatim.
        safe_link = html.escape(link)
        document_html.append(
            '<div class="source" id="' + safe_link + '"><p><b>' + safe_link
            + "</b> : " + html.escape(title) + "</p></div>"
        )
    return (
        "\n\n".join(document),
        '<div id="source_listing">' + "".join(document_html) + "</div>",
    )
# CSS for the answer area (.generation), the source listing (.source) and the
# citation tooltips (.tooltip) rendered by the chatbot.
css = """
.generation {
margin-left:2em;
margin-right:2em;
}
:target {
background-color: #CCF3DF; /* Highlight the source a citation link points to */
}
.source {
float:left;
max-width:17%;
margin-left:2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%; /* Adjust this value as needed to control the vertical spacing between the text and the tooltip */
white-space: pre-wrap; /* Allows the text to wrap */
width: 500px; /* Sets a fixed maximum width for the tooltip */
max-width: 500px; /* Ensures the tooltip does not exceed the maximum width */
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Optional: Adds a subtle shadow for better visibility */
}"""
# Courtesy of ChatGPT (initial draft).
def format_references(text):
    """Render ``<ref text="passage">id</ref>`` citations as numbered HTML tooltips.

    Each well-formed reference becomes a ``<span class="tooltip">`` whose
    ``data-text`` attribute holds ``"<id>: <passage>"`` and whose link points
    to ``#<id>``. The ``data-text`` value is shown on hover via CSS
    ``attr(data-text)``. References are numbered ``[1]``, ``[2]``, ... in order
    of appearance. Text outside references is passed through unchanged.

    If a reference is malformed (no closing ``">`` or no ``</ref>``), the text
    from that reference onwards is kept, HTML-escaped. Escaping it keeps the
    content visible and stops the dangling tag from swallowing the surrounding
    markup. A common cause is generation being cut off mid-citation.

    Args:
        text: Generated text containing zero or more reference tags.

    Returns:
        The text with references rendered as HTML.
    """
    ref_start_marker = '<ref text="'
    ref_end_marker = '</ref>'
    parts = []
    current_pos = 0
    ref_number = 1
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No more references: keep the remaining text as-is.
            parts.append(text[current_pos:])
            break
        parts.append(text[current_pos:start_pos])
        # The quoted passage ends at the first '">' after the marker...
        end_pos = text.find('">', start_pos)
        # ...and the identifier runs from there up to '</ref>'.
        ref_end_pos = text.find(ref_end_marker, end_pos) if end_pos != -1 else -1
        if ref_end_pos == -1:
            # Malformed reference: keep the tail as visible text instead of
            # silently dropping it.
            parts.append(html.escape(text[start_pos:]))
            break
        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_id = text[end_pos + 2:ref_end_pos].strip()
        # Both values land inside HTML attributes, so escape quotes too.
        ref_id_attr = html.escape(ref_id)
        tooltip_text = html.escape(f"{ref_id}: {ref_text}")
        parts.append(
            f'<span class="tooltip" data-refid="{ref_id_attr}" data-text="{tooltip_text}">'
            f'<a href="#{ref_id_attr}">[{ref_number}]</a></span>'
        )
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1
    return ''.join(parts)
# Class to encapsulate the Cassandre (vLLM) chatbot
class MistralChatBot:
    """Answer questions with retrieval-augmented generation and rendered citations."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): stored but not used by predict(), whose prompt hard-codes
        # its own system message; kept for interface compatibility.
        self.system_prompt = system_prompt

    @staticmethod
    def _build_prompt(user_message, fiches):
        """Return the ChatML prompt asking for a referenced answer.

        Args:
            user_message: The user's question, inserted verbatim.
            fiches: Plain-text listing of retrieved sources, inserted verbatim.
        """
        return (
            "<|im_start|>system\n"
            "Tu es Cassandre, le chatbot de l'Éducation nationale qui donne des réponses sourcées.<|im_end|>\n"
            "<|im_start|>user\n"
            f"Ecrit un texte référencé en réponse à cette question : {user_message}\n"
            "Les références doivent être citées de cette manière : texte rédigé"
            '<ref text="[passage pertinent dans la référence]">["identifiant de la référence"]</ref>'
            "Si les références ne permettent pas de répondre, qu'il n'y a pas de réponse.\n"
            f"Les cinq références disponibles : {fiches}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

    @staticmethod
    def _heading(title):
        """Return a centred ``<h2>`` heading followed by a newline."""
        return '<h2 style="text-align:center">' + title + '</h2>\n'

    def predict(self, user_message):
        """Answer ``user_message`` using passages retrieved by ``vector_search``.

        Returns:
            ``(answer_html, sources_html)``. ``answer_html`` is the generated
            answer, with citations rendered by ``format_references``, under a
            "Réponse" heading. ``sources_html`` is the HTML source listing
            under a "Sources" heading.
        """
        fiches, fiches_html = vector_search(collection, user_message)
        # NOTE(review): stopping at a double backtick presumably cuts off
        # runaway markdown/code fences — confirm intent.
        sampling_params = SamplingParams(
            temperature=.7, top_p=.95, max_tokens=2000, presence_penalty=1.5, stop=["``"]
        )
        prompts = [self._build_prompt(user_message, fiches)]
        outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text
        answer_html = (
            self._heading("Réponse")
            + '<div class="generation">' + format_references(generated_text) + "</div>"
        )
        sources_html = self._heading("Sources") + fiches_html
        return answer_html, sources_html
# Create the chatbot instance used by the Gradio callback
mistral_bot = MistralChatBot()

# Define the Gradio interface
title = "Motta"
description = "Le LLM répond à toutes les questions sur la SDN."
# NOTE(review): `description`, `examples` and `additional_inputs` are not wired
# into the interface below (the temperature slider has no effect); kept so any
# external reference to them keeps working.
examples = [
    [
        "Comment garantir la paix universelle?",  # user_message
        0.7,  # temperature
    ]
]
additional_inputs = [
    gr.Slider(
        label="Température",
        value=0.2,  # Default value
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
    ),
]

with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
    gr.HTML(f'<h1 style="text-align:center">{title}</h1>')
    text_input = gr.Textbox(label="Votre question ou votre instruction.", type="text", lines=1)
    text_button = gr.Button("Interroger Motta")
    text_output = gr.HTML(label="La réponse de Motta")
    embedding_output = gr.HTML(label="Les sources utilisées")
    # predict returns (answer_html, sources_html), matching these two outputs.
    text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output, embedding_output])

if __name__ == "__main__":
    demo.queue().launch()