hibalaz committed on
Commit
3672892
1 Parent(s): 696d3e8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -55
app.py CHANGED
@@ -1,24 +1,30 @@
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
- from transformers import pipeline, GPT2Tokenizer
4
  import os
 
 
 
5
 
6
- # Define paths and models
7
- filename = "output_country_details.txt" # Adjust the filename as needed
8
  retrieval_model_name = 'output/sentence-transformer-finetuned/'
9
- gpt2_model_name = "gpt2" # GPT-2 model
10
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
11
 
12
- # Load models
 
 
 
 
13
  try:
14
  retrieval_model = SentenceTransformer(retrieval_model_name)
15
- gpt_model = pipeline("text-generation", model=gpt2_model_name)
16
  print("Models loaded successfully.")
17
  except Exception as e:
18
  print(f"Failed to load models: {e}")
19
 
20
- # Load and preprocess text from the country details file
21
  def load_and_preprocess_text(filename):
 
 
 
22
  try:
23
  with open(filename, 'r', encoding='utf-8') as file:
24
  segments = [line.strip() for line in file if line.strip()]
@@ -31,61 +37,82 @@ def load_and_preprocess_text(filename):
31
  segments = load_and_preprocess_text(filename)
32
 
33
  def find_relevant_segment(user_query, segments):
 
 
 
 
34
  try:
35
- query_embedding = retrieval_model.encode(user_query)
36
- segment_embeddings = retrieval_model.encode(segments)
 
 
 
 
 
 
 
 
 
37
  similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
38
  best_idx = similarities.argmax()
39
- print("Relevant segment found:", segments[best_idx])
40
- return segments[best_idx]
41
  except Exception as e:
42
- print(f"Error finding relevant segment: {e}")
43
  return ""
44
 
 
45
  def generate_response(user_query, relevant_segment):
 
 
 
46
  try:
47
- # Construct the prompt with the user query
48
- prompt = f"Thank you for your question! this is an additional fact about your topic: {relevant_segment}"
49
-
50
- # Generate response with adjusted max_length for completeness
51
- max_tokens = len(tokenizer(prompt)['input_ids']) + 50
52
- response = gpt_model(prompt, max_length=max_tokens, temperature=0.25)[0]['generated_text']
53
-
54
- # Clean and format the response
55
- response_cleaned = clean_up_response(response, relevant_segment)
56
- return response_cleaned
 
 
 
 
 
 
57
  except Exception as e:
58
- print(f"Error generating response: {e}")
59
- return ""
60
 
61
- def clean_up_response(response, segments):
62
- # Split the response into sentences
63
- sentences = response.split('.')
64
-
65
- # Remove empty sentences and any repetitive parts
66
- cleaned_sentences = []
67
- for sentence in sentences:
68
- if sentence.strip() and sentence.strip() not in segments and sentence.strip() not in cleaned_sentences:
69
- cleaned_sentences.append(sentence.strip())
70
-
71
- # Join the sentences back together
72
- cleaned_response = '. '.join(cleaned_sentences).strip()
73
-
74
- # Check if the last sentence ends with a complete sentence
75
- if cleaned_response and not cleaned_response.endswith((".", "!", "?")):
76
- cleaned_response += "."
77
-
78
- return cleaned_response
79
 
80
- # Define the welcome message with markdown for formatting and larger fonts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  welcome_message = """
82
  # Welcome to VISABOT!
83
 
84
  ## Your AI-driven visa assistant for all travel-related queries.
85
-
86
  """
87
 
88
- # Define topics and countries with flag emojis
89
  topics = """
90
  ### Feel Free to ask me anything from the topics below!
91
  - Visa issuance
@@ -114,24 +141,33 @@ countries = """
114
  - 🇻🇳 Vietnam
115
  """
116
 
117
- # Define the Gradio app interface
118
  def query_model(question):
119
- if question == "": # If there's no input, the bot will display the greeting message.
 
 
 
 
 
 
 
 
 
120
  return welcome_message
121
  relevant_segment = find_relevant_segment(question, segments)
122
  response = generate_response(question, relevant_segment)
123
  return response
124
 
125
- # Create Gradio Blocks interface for custom layout
126
  with gr.Blocks() as demo:
127
- gr.Markdown(welcome_message) # Display the welcome message with large fonts
128
  with gr.Row():
129
  with gr.Column():
130
- gr.Markdown(topics) # Display the topics on the left
131
  with gr.Column():
132
- gr.Markdown(countries) # Display the countries with flag emojis on the right
133
  with gr.Row():
134
- img = gr.Image(os.path.join(os.getcwd(), "final.png"), width=500) # Adjust width as needed
135
  with gr.Row():
136
  with gr.Column():
137
  question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
@@ -139,5 +175,6 @@ with gr.Blocks() as demo:
139
  submit_button = gr.Button("Submit")
140
  submit_button.click(fn=query_model, inputs=question, outputs=answer)
141
 
142
- # Launch the app
143
- demo.launch()
 
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer, util
3
+ import openai
4
  import os
5
+ import os
6
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
+
8
 
9
# Initialize paths and model identifiers for easy configuration and maintenance
filename = "output_country_details.txt"  # Path to the file storing country-specific details
retrieval_model_name = 'output/sentence-transformer-finetuned/'

# SECURITY: never commit an API key to source control. The key is read from
# the OPENAI_API_KEY environment variable (configure it as a deployment
# secret). The key that used to be hard-coded on this line has been published
# and must be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    print("Warning: OPENAI_API_KEY is not set; response generation will fail.")

# Attempt to load the necessary models and provide feedback on success or failure
try:
    retrieval_model = SentenceTransformer(retrieval_model_name)
    print("Models loaded successfully.")
except Exception as e:
    # Best-effort start-up: the app keeps running, but the name is always
    # defined so later lookups fail with a catchable error, not a NameError.
    retrieval_model = None
    print(f"Failed to load models: {e}")
23
 
 
24
  def load_and_preprocess_text(filename):
25
+ """
26
+ Load and preprocess text from a file, removing empty lines and stripping whitespace.
27
+ """
28
  try:
29
  with open(filename, 'r', encoding='utf-8') as file:
30
  segments = [line.strip() for line in file if line.strip()]
 
37
  segments = load_and_preprocess_text(filename)
38
 
def find_relevant_segment(user_query, segments):
    """
    Return the segment most similar to the user's query.

    Segments are first narrowed to those that mention a country named in the
    query; if the query names no known country (or no segment matches), every
    segment is considered. The winner is the segment whose sentence embedding
    has the highest cosine similarity with the query embedding.

    Args:
        user_query (str): The question typed by the user.
        segments (list[str]): Candidate text segments to search.

    Returns:
        str: The best-matching segment, or "" when the query or the segment
        list is empty, or when retrieval fails for any reason.
    """
    # Each tuple lists the spellings that refer to the same country, so a
    # query saying "United States" still matches a segment saying "U.S.".
    # NOTE(review): only these countries are recognised; extend the table if
    # the data file covers more.
    country_aliases = (
        ('guatemala',),
        ('mexico',),
        ('u.s.', 'united states'),
    )
    try:
        if not user_query or not segments:
            return ""

        # Lowercase the query for case-insensitive matching
        lower_query = user_query.lower()

        # Collect the spellings of every country the query actually mentions.
        # The previous version ignored the query here and kept any segment
        # that mentioned any known country.
        wanted = [
            alias
            for aliases in country_aliases
            if any(alias in lower_query for alias in aliases)
            for alias in aliases
        ]
        country_segments = [
            seg for seg in segments
            if any(alias in seg.lower() for alias in wanted)
        ]

        # If no specific country segments found, default to general matching
        if not country_segments:
            country_segments = segments

        query_embedding = retrieval_model.encode(lower_query)
        segment_embeddings = retrieval_model.encode(country_segments)
        similarities = util.pytorch_cos_sim(query_embedding, segment_embeddings)[0]
        best_idx = int(similarities.argmax())
        return country_segments[best_idx]
    except Exception as e:
        print(f"Error in finding relevant segment: {e}")
        return ""
62
 
63
+
def generate_response(user_query, relevant_segment):
    """
    Ask the chat model to answer the user's question from the retrieved segment.

    Args:
        user_query (str): The question typed by the user.
        relevant_segment (str): Reference text selected by the retrieval step.

    Returns:
        str: The model's answer with surrounding whitespace removed, or a
        string starting with "Error in generating response:" if the call fails.
    """
    try:
        system_message = "You are a visa chatbot specialized in providing country-specific visa requirement information."
        # The question itself must be part of the prompt; the previous version
        # sent only the retrieved segment, so the model never saw what was asked.
        user_message = (
            f"Question: {user_query}\n\n"
            f"Here's the information on visa requirements for your query: {relevant_segment}"
        )
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]
        # NOTE(review): `openai.ChatCompletion` is the pre-1.0 interface of the
        # openai package; confirm the installed version still provides it.
        response = openai.ChatCompletion.create(
            model="gpt-4-turbo",  # Verify model name
            messages=messages,
            max_tokens=150,
            temperature=0.7,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
        )
        return response['choices'][0]['message']['content'].strip()
    except Exception as e:
        print(f"Error in generating response: {e}")
        return f"Error in generating response: {e}"
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+
91
+
92
+
93
+
94
# Define and configure the Gradio application interface to interact with users.
# NOTE(review): `query_model` is defined again further down in this file; that
# later definition replaces this one when the module is loaded. Keep only one.
def query_model(question):
    """
    Process a question, find relevant information, and generate a response, specifically for U.S. visa questions.

    Args:
        question (str): User's input question.

    Returns:
        str: A greeting when the question is empty or only whitespace, a
        "refine your question" notice when no segment is found, otherwise the
        generated response.
    """
    # Treat None and whitespace-only input like an empty question so they are
    # never sent to the retrieval model or the chat API.
    if not question or not question.strip():
        return "Welcome to VisaBot! Ask me anything about U.S. visa processes."
    relevant_segment = find_relevant_segment(question, segments)
    if not relevant_segment:
        return "Could not find U.S.-specific information. Please refine your question."
    return generate_response(question, relevant_segment)
107
+
108
+
109
+ # Define the welcome message and specific topics and countries the chatbot can provide information about.
110
  welcome_message = """
111
  # Welcome to VISABOT!
112
 
113
  ## Your AI-driven visa assistant for all travel-related queries.
 
114
  """
115
 
 
116
  topics = """
117
  ### Feel Free to ask me anything from the topics below!
118
  - Visa issuance
 
141
  - 🇻🇳 Vietnam
142
  """
143
 
144
# Define and configure the Gradio application interface to interact with users.
def query_model(question):
    """
    Process a question, find relevant information, and generate a response.

    Args:
        question (str): User's input question.

    Returns:
        str: The welcome message if no question is provided, a short notice
        if no relevant segment could be found, otherwise the generated response.
    """
    # Treat None and whitespace-only input like an empty question.
    if not question or not question.strip():
        return welcome_message
    relevant_segment = find_relevant_segment(question, segments)
    if not relevant_segment:
        # Retrieval returns "" on failure; skip the chat API call when there
        # is no reference text to answer from.
        return "Sorry, I could not find relevant information for your question. Please try rephrasing it."
    return generate_response(question, relevant_segment)
160
 
161
+ # Setup the Gradio Blocks interface with custom layout components
162
  with gr.Blocks() as demo:
163
+ gr.Markdown(welcome_message) # Display the formatted welcome message
164
  with gr.Row():
165
  with gr.Column():
166
+ gr.Markdown(topics) # Show the topics on the left side
167
  with gr.Column():
168
+ gr.Markdown(countries) # Display the list of countries on the right side
169
  with gr.Row():
170
+ img = gr.Image(os.path.join(os.getcwd(), "poster.png"), width=500) # Include an image for visual appeal
171
  with gr.Row():
172
  with gr.Column():
173
  question = gr.Textbox(label="Your question", placeholder="What do you want to ask about?")
 
175
  submit_button = gr.Button("Submit")
176
  submit_button.click(fn=query_model, inputs=question, outputs=answer)
177
 
178
+ # Launch the Gradio app to allow user interaction
179
+ demo.launch(share= True)
180
+