kcelia committed on
Commit
d812385
1 Parent(s): 9a2d521

chore: handling user query

Browse files
Files changed (2) hide show
  1. app.py +27 -14
  2. utils_demo.py +28 -1
app.py CHANGED
@@ -7,6 +7,8 @@ from openai import OpenAI
7
  import os
8
  import json
9
  import re
 
 
10
 
11
  anonymizer = FHEAnonymizer()
12
 
@@ -15,6 +17,17 @@ client = OpenAI(
15
  )
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
18
  def deidentify_text(input_text):
19
  anonymized_text, identified_words_with_prob = anonymizer(input_text)
20
 
@@ -74,10 +87,6 @@ def query_chatgpt(anonymized_query):
74
  return anonymized_response, deanonymized_response
75
 
76
 
77
- # Default demo text from the file
78
- with open("demo_text.txt", "r") as file:
79
- default_demo_text = file.read()
80
-
81
  with open("files/original_document.txt", "r") as file:
82
  original_document = file.read()
83
 
@@ -128,19 +137,23 @@ with demo:
128
  # """
129
  # )
130
 
 
131
  with gr.Row():
132
- input_text = gr.Textbox(
133
- value=default_demo_text,
134
- lines=1,
135
- placeholder="Input text here...",
136
- label="Input",
 
 
 
137
  )
138
 
139
- # List of example queries for easy access
140
- example_queries = ["Example Query 1", "Example Query 2", "Example Query 3"]
141
- examples_radio = gr.Radio(choices=example_queries, label="Example Queries")
142
-
143
- examples_radio.change(lambda example_query: example_query, inputs=[examples_radio], outputs=[input_text])
144
 
145
  anonymized_text_output = gr.Textbox(label="Anonymized Text with FHE", lines=1, interactive=True)
146
 
 
7
  import os
8
  import json
9
  import re
10
+ from utils_demo import *
11
+ from typing import List, Dict, Tuple
12
 
13
  anonymizer = FHEAnonymizer()
14
 
 
17
  )
18
 
19
 
20
def check_user_query_fn(user_query: str) -> Dict:
    """Vet the user query and produce the new content of the query textbox.

    Note that `is_user_query_valid` is treated here as a rejection flag: a
    truthy result replaces the query with an error message, while a falsy
    result keeps the query and only collapses runs of spaces.

    Args:
        user_query (str): The query text received from the `input_text` box.

    Returns:
        Dict: A mapping from the `input_text` component to the `gr.update`
            that carries its new value.
    """
    if not is_user_query_valid(user_query):
        # Accepted query: squeeze every run of spaces down to a single space.
        return {
            input_text: gr.update(
                value=re.sub(" +", " ", user_query),
            ),
        }

    # TODO: check if the query is related to our context
    rejection_notice = (
        "Unable to process ❌: The request exceeds the length limit or falls "
        "outside the scope of this document. Please refine your query."
    )
    print(rejection_notice)
    return {input_text: gr.update(value=rejection_notice)}
30
+
31
  def deidentify_text(input_text):
32
  anonymized_text, identified_words_with_prob = anonymizer(input_text)
33
 
 
87
  return anonymized_response, deanonymized_response
88
 
89
 
 
 
 
 
90
  with open("files/original_document.txt", "r") as file:
91
  original_document = file.read()
92
 
 
137
  # """
138
  # )
139
 
140
+ ########################## User Query Part ##########################
141
  with gr.Row():
142
+ input_text = gr.Textbox(value="Who lives in Maine?", label="User query", interactive=True)
143
+
144
+ default_query_box = gr.Radio(choices=list(DEFAULT_QUERIES.keys()), label="Example Queries")
145
+
146
+ default_query_box.change(
147
+ fn=lambda default_query_box: DEFAULT_QUERIES[default_query_box],
148
+ inputs=[default_query_box],
149
+ outputs=[input_text]
150
  )
151
 
152
+ input_text.change(
153
+ check_user_query_fn,
154
+ inputs=[input_text],
155
+ outputs=[input_text],
156
+ )
157
 
158
  anonymized_text_output = gr.Textbox(label="Anonymized Text with FHE", lines=1, interactive=True)
159
 
utils_demo.py CHANGED
@@ -1,6 +1,15 @@
1
  import torch
2
  import numpy as np
3
- import random
 
 
 
 
 
 
 
 
 
4
 
5
  def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
6
  """
@@ -20,3 +29,21 @@ def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
20
  mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy())
21
  return np.array(mean_pooled_batch)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
+
4
+
5
+ MAX_USER_QUERY_LEN = 35  # upper bound on len() of a free-form user query
6
+
7
+ # Example queries keyed by their display label; is_user_query_valid exempts these values from the MAX_USER_QUERY_LEN check
8
+ DEFAULT_QUERIES = {
9
+ "Example Query 1": "Who visited microsoft.com on September 18?",
10
+ "Example Query 2": "Does Kate has drive ?",
11
+ "Example Query 3": "What phone number can be used to contact David Johnson?",
12
+ }
13
 
14
  def get_batch_text_representation(texts, model, tokenizer, batch_size=1):
15
  """
 
29
  mean_pooled_batch.extend(mean_pooled.cpu().detach().numpy())
30
  return np.array(mean_pooled_batch)
31
 
32
+
33
def is_user_query_valid(user_query: str) -> bool:
    """Report whether `user_query` has to be rejected.

    NOTE: despite the function name, a True result flags a query that should
    NOT be processed; a False result means the query is acceptable.

    Args:
        user_query (str): The query text to check; may be None.

    Returns:
        bool: False when the query is one of the `DEFAULT_QUERIES` values, or
            when it is not None and at most `MAX_USER_QUERY_LEN` characters
            long. True in every other case (None, or a non-default query
            longer than the limit).
    """
    # Example queries are always accepted, whatever their length.
    matches_example = user_query in DEFAULT_QUERIES.values()

    # A non-None query that stays within the length limit is acceptable.
    fits_length_limit = (
        user_query is not None
        and len(user_query) <= MAX_USER_QUERY_LEN
    )

    return not (matches_example or fits_length_limit)
48
+
49
+