Spaces:
Runtime error
Runtime error
File size: 5,127 Bytes
59654d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer
import os
# Define paths and models
filename = "output_country_details.txt"  # Adjust the filename as needed
retrieval_model_name = 'output/sentence-transformer-finetuned/'
gpt2_model_name = "gpt2"  # GPT-2 model

# Load the tokenizer for the same checkpoint as the generation pipeline so
# that token counts computed with it match what the generator sees.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Bind the model handles up front so the names exist even if loading fails
# below; otherwise any later reference would raise NameError.
retrieval_model = None
gpt_model = None

# Load models
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    # Best-effort start-up: report the failure and keep the module importable.
    print(f"Failed to load models: {e}")
# Load and preprocess text from the country details file
def load_and_preprocess_text(filename):
    """Read *filename* as UTF-8 and return its non-blank lines, stripped.

    Each line has surrounding whitespace removed; lines that are empty after
    stripping are dropped. Any failure is printed and an empty list is
    returned instead of raising.
    """
    non_empty = []
    try:
        with open(filename, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                text = raw_line.strip()
                if text:
                    non_empty.append(text)
        print("Text loaded and preprocessed successfully.")
    except Exception as e:
        print(f"Failed to load or preprocess text: {e}")
        return []
    return non_empty
# Read the knowledge-base segments once at import time; this list is what
# query_model hands to find_relevant_segment for every question.
segments = load_and_preprocess_text(filename)
# Embeddings of the most recently encoded corpus, keyed by its contents, so
# the whole corpus is not re-encoded on every query.
_segment_embedding_cache = {}


def _encode_segments(segments):
    """Return embeddings for *segments*, reusing the last result if the corpus is unchanged."""
    key = tuple(segments)
    if key not in _segment_embedding_cache:
        embeddings = retrieval_model.encode(segments)
        _segment_embedding_cache.clear()  # hold at most one corpus in memory
        _segment_embedding_cache[key] = embeddings
    return _segment_embedding_cache[key]


def find_relevant_segment(user_query, segments):
    """Return the entry of *segments* most similar to *user_query*.

    The query and the segments are embedded with the module-level
    ``retrieval_model`` and compared by cosine similarity; the segment with
    the highest score is returned.

    Args:
        user_query: The question text to match.
        segments: List of candidate text segments.

    Returns:
        The best-matching segment, or ``""`` when there are no segments or
        when embedding/scoring fails (the error is printed, not raised).
    """
    if not segments:
        # Nothing to search; skip the model calls entirely.
        print("Error finding relevant segment: no segments available")
        return ""
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = _encode_segments(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        # argmax() yields a tensor; convert it to a plain int for list indexing.
        best_idx = int(similarities.argmax())
        print("Relevant segment found:", segments[best_idx])
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""
def generate_response(user_query, relevant_segment):
    """Generate the bot's reply from the retrieved segment.

    A fixed template containing *relevant_segment* is used as the prompt for
    the module-level ``gpt_model`` pipeline, and the generated text is passed
    through ``clean_up_response``.

    Args:
        user_query: The user's question. It is accepted for interface
            compatibility but is not part of the prompt.
        relevant_segment: The retrieved text to build the prompt around.

    Returns:
        The cleaned generated text, or ``""`` when there is no segment to
        work from or when generation fails (the error is printed, not raised).
    """
    if not relevant_segment:
        # Retrieval produced nothing; generating from the bare template would
        # present an invented continuation as a "fact".
        return ""
    try:
        # Construct the prompt around the retrieved segment
        prompt = f"Thank you for your question! this is an additional fact about your topic: {relevant_segment}"
        # Allow up to 50 generated tokens beyond the prompt itself
        max_tokens = len(tokenizer(prompt)['input_ids']) + 50
        # temperature only takes effect in sampling mode, so enable sampling
        # explicitly; without do_sample=True the setting is ignored.
        response = gpt_model(
            prompt,
            max_length=max_tokens,
            do_sample=True,
            temperature=0.25,
        )[0]['generated_text']
        # Clean and format the response
        return clean_up_response(response, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""
def clean_up_response(response, segments):
    """Drop empty, repeated and already-known sentences from *response*.

    The text is split on ``'.'``. A stripped piece is kept only if it is
    non-empty, is not contained in *segments* (``in`` test: substring match
    for a string, membership for a list) and has not been kept already. The
    kept pieces are joined with ``'. '`` and a trailing ``'.'`` is added
    unless the result is empty or already ends in ``.``, ``!`` or ``?``.
    """
    kept = []
    for piece in response.split('.'):
        candidate = piece.strip()
        if not candidate:
            continue
        if candidate in segments or candidate in kept:
            continue
        kept.append(candidate)

    text = '. '.join(kept).strip()
    needs_terminator = bool(text) and not text.endswith((".", "!", "?"))
    return text + "." if needs_terminator else text
# Greeting shown at the top of the page, written as Markdown headings so it
# renders in large fonts.
welcome_message = (
    "\n"
    "# Welcome to VISABOT!\n"
    "## Your AI-driven visa assistant for all travel-related queries.\n"
)
# Markdown list of the topics users are invited to ask about.
topics = (
    "\n"
    "### Feel Free to ask me anything from the topics below!\n"
    "- Visa issuance\n"
    "- Documents needed\n"
    "- Application process\n"
    "- Processing time\n"
    "- Recommended Vaccines\n"
    "- Health Risks\n"
    "- Healthcare Facilities\n"
    "- Currency Information\n"
    "- Embassy Information\n"
    "- Allowed stay\n"
)
# Markdown list of the supported countries, each prefixed with its flag emoji
# (a pair of Unicode regional-indicator symbols).
countries = """
### Our chatbot can currently answer questions for these countries!
- 🇨🇳 China
- 🇫🇷 France
- 🇬🇹 Guatemala
- 🇱🇧 Lebanon
- 🇲🇽 Mexico
- 🇵🇭 Philippines
- 🇷🇸 Serbia
- 🇸🇱 Sierra Leone
- 🇿🇦 South Africa
- 🇻🇳 Vietnam
"""
# Callback wired to the Gradio submit button
def query_model(question):
    """Answer *question*, or return the greeting when no question was given.

    Args:
        question: Text from the question box.

    Returns:
        ``welcome_message`` when *question* is empty, ``None`` or only
        whitespace; otherwise the generated reply for the best-matching
        segment.
    """
    # Treat None and whitespace-only input like an empty question instead of
    # running retrieval and generation on it.
    if not question or not question.strip():
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    return generate_response(question, relevant_segment)
# Create Gradio Blocks interface for custom layout
# NOTE(review): the nesting below follows the inline comments (topics on the
# left, countries on the right); confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown(welcome_message)  # Display the welcome message with large fonts
    # First row: two columns, topics and supported countries.
    with gr.Row():
        with gr.Column():
            gr.Markdown(topics)  # Display the topics on the left
        with gr.Column():
            gr.Markdown(countries)  # Display the countries with flag emojis on the right
    # Second row: image read from "final.png" in the current working directory.
    with gr.Row():
        img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500)  # Adjust width as needed
    # Third row: question input, read-only answer box and the submit button.
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
            answer = gr.Textbox(label="VisaBot Response", placeholder="VisaBot will respond here...", interactive=False, lines=10)
            submit_button = gr.Button("Submit")
    # Clicking the button passes the question text to query_model and writes
    # its return value into the answer box.
    submit_button.click(fn=query_model, inputs=question, outputs=answer)
# Launch the app
demo.launch()
|