File size: 5,127 Bytes
59654d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
from sentence_transformers import SentenceTransformer, util
from transformers import pipeline, GPT2Tokenizer
import os

# Define paths and models
filename = "output_country_details.txt"  # Adjust the filename as needed
retrieval_model_name = 'output/sentence-transformer-finetuned/'  # local fine-tuned SentenceTransformer dir
gpt2_model_name = "gpt2"  # GPT-2 checkpoint used by the text-generation pipeline
# Load the tokenizer from the same checkpoint name as the generation pipeline,
# so changing gpt2_model_name cannot leave the tokenizer pointing elsewhere.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)

# Load models. Failures are reported but not raised; the app still starts,
# and the query helpers catch and print the errors that follow.
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    gpt_model = pipeline("text-generation", model=gpt2_model_name)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Failed to load models: {e}")

# Load the country-details corpus used for retrieval: one entry per non-blank line.
def load_and_preprocess_text(filename):
    """Return the non-blank lines of *filename*, each stripped of surrounding whitespace.

    The file is read as UTF-8. If anything raises while opening or reading it,
    the error is printed and an empty list is returned instead.
    """
    try:
        with open(filename, 'r', encoding='utf-8') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            kept = [text for text in stripped_lines if text]
        print("Text loaded and preprocessed successfully.")
        return kept
    except Exception as e:
        print(f"Failed to load or preprocess text: {e}")
        return []

# Retrieval corpus, loaded once at import time (empty list if loading failed).
segments = load_and_preprocess_text(filename=filename)

# Most recently encoded corpus and its embeddings. The corpus is fixed for the
# life of the app, so re-embedding it on every query is wasted model inference.
_segment_embedding_cache = {"key": None, "embeddings": None}


def _encode_segments(segments):
    """Return retrieval embeddings for *segments*, reusing the cached result if the contents are unchanged."""
    key = tuple(segments)  # content-based key, so in-place edits to the list are detected
    if _segment_embedding_cache["key"] != key:
        # Store the embeddings before the key, so a failed encode never leaves a stale key behind.
        _segment_embedding_cache["embeddings"] = retrieval_model.encode(segments)
        _segment_embedding_cache["key"] = key
    return _segment_embedding_cache["embeddings"]


def find_relevant_segment(user_query, segments):
    """Return the entry of *segments* most similar to *user_query*.

    Similarity is cosine similarity between retrieval-model embeddings.

    Args:
        user_query: The user's question text.
        segments: Candidate text segments (list of str).

    Returns:
        The best-matching segment, or ``""`` if *segments* is empty or any
        error occurs (the error is printed).
    """
    if not segments:
        print("Error finding relevant segment: no segments loaded")
        return ""
    try:
        query_embedding = retrieval_model.encode(user_query)
        segment_embeddings = _encode_segments(segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        best_idx = int(similarities.argmax())  # plain int rather than a 0-d tensor
        print("Relevant segment found:", segments[best_idx])
        return segments[best_idx]
    except Exception as e:
        print(f"Error finding relevant segment: {e}")
        return ""

def generate_response(user_query, relevant_segment):
    """Generate a chatbot reply around *relevant_segment* using the GPT-2 pipeline.

    Args:
        user_query: The user's question. NOTE(review): currently unused; the
            prompt is built only from the retrieved segment. It is kept so the
            signature stays compatible with callers.
        relevant_segment: The retrieved fact to present to the user.

    Returns:
        The generated text after ``clean_up_response``, or ``""`` if
        generation fails (the error is printed).
    """
    try:
        # The prompt is a fixed lead-in plus the retrieved fact; the user's
        # query itself is not included.
        prompt = f"Thank you for your question! this is an additional fact about your topic: {relevant_segment}"

        # Allow up to 50 tokens beyond the prompt. max_new_tokens states this
        # budget directly, instead of re-tokenizing the prompt with a separate
        # tokenizer to compute an absolute max_length.
        response = gpt_model(prompt, max_new_tokens=50, temperature=0.25)[0]['generated_text']

        # Drop sentences that merely repeat the retrieved segment, then tidy punctuation.
        return clean_up_response(response, relevant_segment)
    except Exception as e:
        print(f"Error generating response: {e}")
        return ""

def clean_up_response(response, segments):
    """Tidy generated text into deduplicated, period-joined sentences.

    *response* is split on '.', and each whitespace-stripped piece is dropped if
    it is empty, contained in *segments*, or identical to an earlier kept piece.
    When *segments* is a string, "contained" is a substring test; when it is a
    list, it is a membership test.

    The surviving pieces are joined with '. '. A '.' is appended unless the
    result is empty or already ends in '.', '!' or '?'.
    """
    kept = []
    seen = set()
    for piece in (part.strip() for part in response.split('.')):
        if not piece or piece in segments or piece in seen:
            continue
        seen.add(piece)
        kept.append(piece)

    result = '. '.join(kept).strip()

    # Only appends terminal punctuation; it does not drop a truncated final fragment.
    if result and result[-1] not in ".!?":
        result += "."

    return result

# Markdown header shown at the top of the page; query_model also returns it for blank input.
welcome_message = """
# Welcome to VISABOT!

## Your AI-driven visa assistant for all travel-related queries.

"""

# Markdown list of supported question topics.
topics = """
### Feel Free to ask me anything from the topics below!
- Visa issuance
- Documents needed
- Application process
- Processing time
- Recommended Vaccines
- Health Risks
- Healthcare Facilities
- Currency Information
- Embassy Information 
- Allowed stay
"""

# Markdown list of supported countries, each prefixed by its flag emoji
# (a pair of Unicode regional-indicator symbols; the file must be saved as UTF-8).
countries = """
### Our chatbot can currently answer questions for these countries!
- 🇨🇳 China
- 🇫🇷 France
- 🇬🇹 Guatemala
- 🇱🇧 Lebanon
- 🇲🇽 Mexico
- 🇵🇭 Philippines
- 🇷🇸 Serbia
- 🇸🇱 Sierra Leone
- 🇿🇦 South Africa
- 🇻🇳 Vietnam
"""

# Gradio callback: route a user question through retrieval and generation.
def query_model(question):
    """Answer *question* for the UI, or show the welcome message for blank input.

    Blank means ``None``, an empty string, or whitespace only. Such input
    returns ``welcome_message`` instead of being sent to the models, which
    would otherwise retrieve an arbitrary segment and generate an unrelated reply.

    Returns:
        The response text from ``generate_response``, which may be ``""`` on error.
    """
    if not question or not question.strip():
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    return generate_response(question, relevant_segment)

# Build the page top to bottom: header, topic/country columns, banner image,
# then the question/answer panel wired to query_model. Components appear in
# the order (and nesting) in which they are created inside these contexts.
with gr.Blocks() as demo:
    gr.Markdown(value=welcome_message)

    # Two side-by-side columns: supported topics (left), supported countries (right).
    with gr.Row():
        with gr.Column():
            gr.Markdown(value=topics)
        with gr.Column():
            gr.Markdown(value=countries)

    # Banner image, resolved against the current working directory.
    with gr.Row():
        img = gr.Image(
            value=os.path.join(os.getcwd(), "final.png"),
            width=500,
        )

    # Input box, read-only output box, and the button that connects them.
    with gr.Row():
        with gr.Column():
            question = gr.Textbox(
                label="Your question",
                placeholder="What do you want to ask about?",
            )
            answer = gr.Textbox(
                label="VisaBot Response",
                placeholder="VisaBot will respond here...",
                interactive=False,
                lines=10,
            )
            submit_button = gr.Button(value="Submit")
            submit_button.click(fn=query_model, inputs=question, outputs=answer)

# Launch the app. This call is unguarded (no `if __name__ == "__main__":`),
# so it runs whenever this module is executed or imported.
demo.launch()