|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM |
|
import gradio as gr |
|
import torch |
|
|
|
|
|
# --- Intent classifier: CamemBERT encoder with a sequence-classification head.
# NOTE(review): "camembert-base" is the raw pretrained checkpoint — the
# classification head is randomly initialized, so intent predictions are not
# meaningful until a fine-tuned checkpoint is substituted. TODO confirm.
camembert_model_name = "camembert-base"

camembert_tokenizer = AutoTokenizer.from_pretrained(camembert_model_name)

camembert_model = AutoModelForSequenceClassification.from_pretrained(camembert_model_name)


# --- Response generator: small French GPT-2 used for free-text answers.
gpt2_model_name = "dbddv01/gpt2-french-small"

gpt2_tokenizer = AutoTokenizer.from_pretrained(gpt2_model_name)

gpt2_model = AutoModelForCausalLM.from_pretrained(gpt2_model_name)


# Maps the classifier's argmax class index to a human-readable intent label.
# Any index outside this mapping is treated as "inconnu" (unknown) in chatbot().
intent_dict = {

    0: "salutation",

    1: "question_faq",

    2: "aide"

}
|
|
|
|
|
def chatbot(user_input):
    """Classify the intent of a French user message and produce a reply.

    Parameters
    ----------
    user_input : str
        Raw text typed by the user.

    Returns
    -------
    str
        A canned reply for "salutation"/"aide", a GPT-2 continuation of the
        input for "question_faq", or a fallback message for unknown intents.
    """
    # Tokenize for the CamemBERT classifier; truncate long inputs to 128 tokens.
    cam_inputs = camembert_tokenizer(
        user_input, return_tensors="pt", max_length=128, truncation=True
    )
    # Inference only — disable autograd so no gradient graph is built per request.
    with torch.no_grad():
        cam_outputs = camembert_model(**cam_inputs)

    # Highest-scoring class index, mapped to a human-readable intent label.
    # NOTE(review): with the untrained "camembert-base" head these logits are
    # effectively random — verify a fine-tuned checkpoint is used in production.
    intent = torch.argmax(cam_outputs.logits, dim=1).item()
    detected_intent = intent_dict.get(intent, "inconnu")

    if detected_intent == "salutation":
        response = "Bonjour! Comment puis-je vous aider aujourd'hui ?"
    elif detected_intent == "question_faq":
        # Generate a free-text answer with French GPT-2.
        gpt2_inputs = gpt2_tokenizer(user_input, return_tensors="pt")
        with torch.no_grad():
            gpt2_outputs = gpt2_model.generate(
                gpt2_inputs["input_ids"],
                # Pass the attention mask explicitly: without it, generate()
                # warns and may treat padding incorrectly when pad == eos.
                attention_mask=gpt2_inputs["attention_mask"],
                max_length=100,
                pad_token_id=gpt2_tokenizer.eos_token_id,
            )
        response = gpt2_tokenizer.decode(gpt2_outputs[0], skip_special_tokens=True)
    elif detected_intent == "aide":
        response = "Je suis là pour vous aider ! Que puis-je faire pour vous ?"
    else:
        response = "Je ne suis pas sûr de comprendre. Pouvez-vous reformuler votre question ?"

    return response
|
|
|
|
|
# Minimal Gradio UI: a single text box in, chatbot's text reply out.
demo = gr.Interface(

    fn=chatbot,

    inputs="text",

    outputs="text",

    title="Chatbot BERT-GPT en Français"

)


# Start the local Gradio web server (blocks until the app is stopped).
demo.launch()
|
|
|
|