WenqingZhang committed on
Commit
4b86909
1 Parent(s): 7f74d35

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +314 -38
  2. server.py +45 -0
app.py CHANGED
@@ -1,49 +1,325 @@
1
  import gradio as gr
2
- import easyocr
3
-
4
- # 创建一个 EasyOCR 阅读器实例
5
- reader = easyocr.Reader(['en']) # 选择语言,这里选择英文
6
-
7
- def toggle_visibility(input_type):
8
- """根据输入类型控制文本框和文件上传控件的可见性"""
9
- user_input_visible = input_type == "Text Input"
10
- file_upload_visible = input_type == "File Upload"
11
- return gr.update(visible=user_input_visible), gr.update(visible=file_upload_visible)
12
-
13
- def process_input(input_type, user_input, uploaded_file):
14
- print('ooooocr')
15
- if input_type == "File Upload" and uploaded_file is not None:
16
- # 读取上传的文件
17
- with open(uploaded_file.name, "rb") as f:
18
- image = f.read()
19
- results = reader.readtext(image)
20
- # 提取识别的文本
21
- extracted_text = ' '.join([text[1] for text in results])
22
- print("提取的文本:")
23
- print(extracted_text)
24
- return extracted_text
25
- elif input_type == "Text Input":
26
- return user_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  demo = gr.Blocks()
29
 
 
 
30
  with demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  gr.Markdown("# Step 1: Generate the keys")
32
- b_gen_key_and_install = gr.Button("Generate the keys and send public part to server")
33
- evaluation_key = gr.Textbox(label="Evaluation key (truncated):", max_lines=4, interactive=False)
34
- user_id = gr.Textbox(label="", max_lines=4, interactive=False, visible=False)
35
 
36
- gr.Markdown("# Step 2: Choose Input Method")
37
- input_type = gr.Radio(choices=["Text Input", "File Upload"], label="Select Input Method")
38
- user_input = gr.Textbox(label="Enter Text", placeholder="Type here...", visible=False) # Initially hidden
39
- file_upload = gr.File(label="Upload File", file_types=[".jpg", ".png"], visible=False) # Initially hidden
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # 使用change事件来调用toggle_visibility函数
42
- input_type.change(toggle_visibility, inputs=input_type, outputs=[user_input, file_upload])
43
 
44
- submit_button = gr.Button("Submit")
45
- output_text = gr.Textbox(label="Extracted Text")
 
 
 
 
 
 
 
 
46
 
47
- submit_button.click(process_input, inputs=[input_type, user_input, file_upload], outputs=output_text)
 
48
 
49
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
+ from requests import head
3
+ from transformer_vectorizer import TransformerVectorizer
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+
6
+ from concrete.ml.deployment import FHEModelClient
7
+ import numpy
8
+ import os
9
+ from pathlib import Path
10
+ import requests
11
+ import json
12
+ import base64
13
+ import subprocess
14
+ import shutil
15
+ import time
16
+
17
# This repository's directory
REPO_DIR = Path(__file__).parent

# Start the FastAPI FHE model server (server.py) as a child process of this
# Gradio app; it serves the /predict_sentiment and /legal_rating routes.
subprocess.Popen(["uvicorn", "server:app"], cwd=REPO_DIR)

# Wait 5 sec for the server to start
# NOTE(review): a fixed sleep is a race — consider polling the server's
# root route until it answers.
time.sleep(5)

# Encrypted data limit for the browser to display
# (encrypted data is too large to display in the browser)
ENCRYPTED_DATA_BROWSER_LIMIT = 500
# Maximum number of per-user key directories kept under .fhe_keys/.
N_USER_KEY_STORED = 20
# Names of the rating tasks this app knows about.
model_names = ['financial_rating', 'legal_rating']
# Deployment artifacts for each task's FHE model (client specs live here).
FHE_MODEL_PATH = "deployment/financial_rating"
FHE_LEGAL_PATH = "deployment/legal_rating"
#FHE_LEGAL_PATH="deployment/legal_rating"

print("Loading the transformer model...")

# Initialize the transformer vectorizer
transformer_vectorizer = TransformerVectorizer()
# NOTE(review): this TfidfVectorizer is never fitted or used in this file —
# confirm whether it is needed.
vectorizer = TfidfVectorizer()
39
+
40
def clean_tmp_directory():
    """Evict the oldest user key directories and their temp files.

    At most N_USER_KEY_STORED key directories are kept under ``.fhe_keys/``;
    older ones (by mtime) are deleted together with every ``tmp/*<user_id>.npy``
    artifact they produced.
    """
    keys_root = Path(".fhe_keys/")
    # First run: no keys have been generated yet, nothing to clean.
    if not keys_root.is_dir():
        return

    path_sub_directories = sorted(
        (f for f in keys_root.iterdir() if f.is_dir()), key=os.path.getmtime
    )

    user_ids = []
    if len(path_sub_directories) > N_USER_KEY_STORED:
        n_dirs_to_delete = len(path_sub_directories) - N_USER_KEY_STORED
        for p in path_sub_directories[:n_dirs_to_delete]:
            user_ids.append(p.name)
            shutil.rmtree(p)

    tmp_dir = Path("tmp/")
    # tmp/ may not exist yet either.
    if not tmp_dir.is_dir():
        return
    # Delete all files related to the evicted user_ids.
    for file in tmp_dir.iterdir():
        if any(file.name.endswith(f"{user_id}.npy") for user_id in user_ids):
            file.unlink()
# (removed dead `model_nams=[]` — a typo'd assignment that nothing referenced)
59
+
60
def keygen(selected_tasks):
    """Generate fresh FHE keys for the selected rating tasks.

    Returns a 2-item list matching the Gradio outputs: the truncated
    serialized evaluation key (for display) and the user_id under which the
    full key material is stored on disk.

    Raises:
        gr.Error: if no task was selected.
    """
    # Evict old user keys / temp files before creating new ones.
    clean_tmp_directory()

    print("Initializing FHEModelClient...")

    if not selected_tasks:
        # Raise instead of returning a bare string: this callback has two
        # outputs, and a single string cannot be unpacked into them.
        raise gr.Error("Please choose at least one task first.")

    # Record the requested tasks locally; appending to the module-level
    # `model_names` list made it grow on every keygen call.
    # NOTE(review): nothing reads this yet — confirm intended use.
    requested_tasks = [
        name for name in ("legal_rating", "financial_rating") if name in selected_tasks
    ]

    # Create a fresh user_id; all key material lives under .fhe_keys/<user_id>.
    user_id = numpy.random.randint(0, 2**32)
    fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
    fhe_api.load()

    # Generate a fresh private/evaluation key pair.
    fhe_api.generate_private_and_evaluation_keys(force=True)
    evaluation_key = fhe_api.get_serialized_evaluation_keys()

    # Save evaluation_key in a file, since it is too large to pass through
    # regular Gradio buttons, https://github.com/gradio-app/gradio/issues/1877
    # Single underscore in the filename: run_fhe() loads
    # tmp/tmp_evaluation_key_{user_id}.npy (the previous double underscore
    # made the save and load paths disagree, so predictions always failed).
    Path("tmp").mkdir(exist_ok=True)
    eval_key_path = Path(f"tmp/tmp_evaluation_key_{user_id}.npy")
    numpy.save(eval_key_path, evaluation_key)

    # Load the legal-rating client as well.
    # NOTE(review): no keys are generated or saved for this client and its
    # user id is never returned — confirm how legal_rating keys are meant to
    # reach the server.
    user_id_legal = numpy.random.randint(0, 2**32)
    fhe_api_legal = FHEModelClient(FHE_LEGAL_PATH, f".fhe_keys/{user_id_legal}")
    fhe_api_legal.load()

    # Return the plain user_id (not wrapped in a list) so the hidden user_id
    # Textbox holds a value usable directly in f".fhe_keys/{user_id}" paths.
    return [list(evaluation_key)[:ENCRYPTED_DATA_BROWSER_LIMIT], user_id]
103
def encode_quantize_encrypt(text, user_id):
    """Vectorize `text`, then quantize and encrypt it with the user's key.

    Returns a 3-tuple for display: the clear transformer encoding, its
    quantized form, and a truncated hex rendering of the ciphertext. The
    full ciphertext is written to tmp/ because it is too large to pass
    through regular Gradio buttons.

    Raises:
        gr.Error: if no keys have been generated for this session.
    """
    if not user_id:
        raise gr.Error("You need to generate FHE keys first.")

    fhe_api = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
    fhe_api.load()

    encodings = transformer_vectorizer.transform([text])
    # Quantized view is for UI display only; quantize_encrypt_serialize
    # below quantizes again internally from the clear encodings.
    quantized_encodings = fhe_api.model.quantize_input(encodings).astype(numpy.uint8)
    encrypted_quantized_encoding = fhe_api.quantize_encrypt_serialize(encodings)

    # Save encrypted_quantized_encoding in a file, since too large to pass
    # through regular Gradio buttons, https://github.com/gradio-app/gradio/issues/1877
    Path("tmp").mkdir(exist_ok=True)  # first run: tmp/ may not exist yet
    numpy.save(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy", encrypted_quantized_encoding)

    # Truncate the ciphertext for the browser and render it as hex.
    encrypted_quantized_encoding_shorten = list(encrypted_quantized_encoding)[:ENCRYPTED_DATA_BROWSER_LIMIT]
    encrypted_quantized_encoding_shorten_hex = ''.join(f'{i:02x}' for i in encrypted_quantized_encoding_shorten)
    return (
        encodings[0],
        quantized_encodings[0],
        encrypted_quantized_encoding_shorten_hex,
    )
125
+
126
+
127
+
128
def run_fhe(user_id):
    """Send the encrypted encoding and evaluation key to the local FHE
    server and persist the encrypted prediction it returns.

    Returns a truncated hex rendering of the encrypted prediction.

    Raises:
        gr.Error: if keys, encrypted data, or the evaluation key file are missing.
        requests.HTTPError: if the server answers with an error status.
    """
    if not user_id:
        raise gr.Error("You need to generate FHE keys first.")

    encoded_data_path = Path(f"tmp/tmp_encrypted_quantized_encoding_{user_id}.npy")
    if not encoded_data_path.is_file():
        raise gr.Error("No encrypted data was found. Encrypt the data before trying to predict.")

    # Must match the exact path keygen() saves the evaluation key to;
    # fail with a UI error instead of a raw FileNotFoundError.
    evaluation_key_path = Path(f"tmp/tmp_evaluation_key_{user_id}.npy")
    if not evaluation_key_path.is_file():
        raise gr.Error("No evaluation key was found. Generate the keys first.")

    # Read encrypted_quantized_encoding and evaluation_key back from disk.
    encrypted_quantized_encoding = numpy.load(encoded_data_path)
    evaluation_key = numpy.load(evaluation_key_path)

    # Use base64 to encode the encodings and evaluation key for JSON transport.
    encrypted_quantized_encoding = base64.b64encode(encrypted_quantized_encoding).decode()
    encoded_evaluation_key = base64.b64encode(evaluation_key).decode()

    query = {
        "evaluation_key": encoded_evaluation_key,
        "encrypted_encoding": encrypted_quantized_encoding,
    }
    headers = {"Content-type": "application/json"}
    response = requests.post(
        "http://localhost:8000/predict_sentiment", data=json.dumps(query), headers=headers
    )
    # Surface server-side failures instead of KeyError-ing on the JSON below.
    response.raise_for_status()
    encrypted_prediction = base64.b64decode(response.json()["encrypted_prediction"])

    # Save encrypted_prediction in a file, since too large to pass through
    # regular Gradio buttons, https://github.com/gradio-app/gradio/issues/1877
    numpy.save(f"tmp/tmp_encrypted_prediction_{user_id}.npy", encrypted_prediction)

    encrypted_prediction_shorten = list(encrypted_prediction)[:ENCRYPTED_DATA_BROWSER_LIMIT]
    return ''.join(f'{i:02x}' for i in encrypted_prediction_shorten)
160
+
161
+
162
def decrypt_prediction(user_id):
    """Decrypt the stored encrypted prediction with the user's private key.

    Returns a mapping of the three relevance labels to their scores.
    """
    prediction_path = Path(f"tmp/tmp_encrypted_prediction_{user_id}.npy")
    if not user_id:
        raise gr.Error("You need to generate FHE keys first.")
    if not prediction_path.is_file():
        raise gr.Error("No encrypted prediction was found. Run the prediction over the encrypted data first.")

    # Raw ciphertext bytes of the prediction, as saved by run_fhe().
    ciphertext = numpy.load(prediction_path).tobytes()

    client = FHEModelClient(FHE_MODEL_PATH, f".fhe_keys/{user_id}")
    client.load()

    # We need to retrieve the private key that matches the client specs (see issue #18)
    client.generate_private_and_evaluation_keys(force=False)

    scores = client.deserialize_decrypt_dequantize(ciphertext)
    print(scores)

    # Pair each class score with its display label.
    labels = ("low_relative", "medium_relative", "high_relative")
    return dict(zip(labels, scores[0]))
185
+
186
 
187
demo = gr.Blocks()


print("Starting the demo...")
with demo:
    # Header banner: Zama logo, demo title, and project links.
    gr.Markdown(
        """
        <p align="center">
        <img width=200 src="https://user-images.githubusercontent.com/5758427/197816413-d9cddad3-ba38-4793-847d-120975e1da11.png">
        </p>
        <h2 align="center">Sentiment Analysis On Encrypted Data Using Homomorphic Encryption</h2>
        <p align="center">
        <a href="https://github.com/zama-ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197972109-faaaff3e-10e2-4ab6-80f5-7531f7cfb08f.png">Concrete-ML</a>

        <a href="https://docs.zama.ai/concrete-ml"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197976802-fddd34c5-f59a-48d0-9bff-7ad1b00cb1fb.png">Documentation</a>

        <a href="https://zama.ai/community"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197977153-8c9c01a7-451a-4993-8e10-5a6ed5343d02.png">Community</a>

        <a href="https://twitter.com/zama_fhe"> <img style="vertical-align: middle; display:inline-block; margin-right: 3px;" width=15 src="https://user-images.githubusercontent.com/5758427/197975044-bab9d199-e120-433b-b3be-abd73b211a54.png">@zama_fhe</a>
        </p>
        <p align="center">
        <img src="https://user-images.githubusercontent.com/56846628/219329304-6868be9e-5ce8-4279-9123-4cb1bc0c2fb5.png" width="60%" height="60%">
        </p>
        """
    )

    # Vertical spacer (intentionally empty paragraphs).
    gr.Markdown(
        """
        <p align="center">
        </p>
        <p align="center">
        </p>
        """
    )

    gr.Markdown("## Notes")
    gr.Markdown(
        """
        - The private key is used to encrypt and decrypt the data and shall never be shared.
        - The evaluation key is a public key that the server needs to process encrypted data.
        """
    )

    # Step 0: task selection feeds keygen() below.
    gr.Markdown("# Step 0: Select Tasks")
    task_checkbox = gr.CheckboxGroup(
        choices=["legal_rating", "financial_rating"],
        label="select_tasks"
    )

    gr.Markdown("# Step 1: Generate the keys")

    b_gen_key_and_install = gr.Button("Generate all the keys and send public part to server")

    # Truncated evaluation key, shown for transparency only.
    evaluation_key = gr.Textbox(
        label="Evaluation key (truncated):",
        max_lines=4,
        interactive=False,
    )

    # Hidden session identifier; threaded through every later callback.
    user_id = gr.Textbox(
        label="",
        max_lines=4,
        interactive=False,
        visible=False
    )

    gr.Markdown("# Step 2: Provide a message")
    gr.Markdown("## Client side")
    gr.Markdown(
        "Enter a sensitive text message you received and would like to do sentiment analysis on (ideas: the last text message of your boss.... or lover)."
    )
    text = gr.Textbox(label="Enter a message:", value="I really like your work recently")

    gr.Markdown("# Step 3: Encode the message with the private key")
    b_encode_quantize_text = gr.Button(
        "Encode, quantize and encrypt the text with transformer vectorizer, and send to server"
    )

    # Three side-by-side views of the same input: clear, quantized, encrypted.
    with gr.Row():
        encoding = gr.Textbox(
            label="Transformer representation:",
            max_lines=4,
            interactive=False,
        )
        quantized_encoding = gr.Textbox(
            label="Quantized transformer representation:", max_lines=4, interactive=False
        )
        encrypted_quantized_encoding = gr.Textbox(
            label="Encrypted quantized transformer representation (truncated):",
            max_lines=4,
            interactive=False,
        )

    gr.Markdown("# Step 4: Run the FHE evaluation")
    gr.Markdown("## Server side")
    gr.Markdown(
        "The encrypted value is received by the server. Thanks to the evaluation key and to FHE, the server can compute the (encrypted) prediction directly over encrypted values. Once the computation is finished, the server returns the encrypted prediction to the client."
    )

    b_run_fhe = gr.Button("Run FHE execution there")
    encrypted_prediction = gr.Textbox(
        label="Encrypted prediction (truncated):",
        max_lines=4,
        interactive=False,
    )

    gr.Markdown("# Step 5: Decrypt the sentiment")
    gr.Markdown("## Client side")
    gr.Markdown(
        "The encrypted sentiment is sent back to client, who can finally decrypt it with its private key. Only the client is aware of the original tweet and the prediction."
    )
    b_decrypt_prediction = gr.Button("Decrypt prediction")

    labels_sentiment = gr.Label(label="Sentiment:")

    # Button for key generation
    b_gen_key_and_install.click(keygen, inputs=[task_checkbox], outputs=[evaluation_key, user_id])

    # Button to quantize and encrypt
    b_encode_quantize_text.click(
        encode_quantize_encrypt,
        inputs=[text, user_id],
        outputs=[
            encoding,
            quantized_encoding,
            encrypted_quantized_encoding,
        ],
    )

    # Button to send the encodings to the server using post at (localhost:8000/predict_sentiment)
    b_run_fhe.click(run_fhe, inputs=[user_id], outputs=[encrypted_prediction])

    # Button to decrypt the prediction on the client
    b_decrypt_prediction.click(decrypt_prediction, inputs=[user_id], outputs=[labels_sentiment])

    gr.Markdown(
        "The app was built with [Concrete-ML](https://github.com/zama-ai/concrete-ml), a Privacy-Preserving Machine Learning (PPML) open-source set of tools by [Zama](https://zama.ai/). Try it yourself and don't forget to star on Github &#11088;."
    )

# Launch the app locally (share=False: no public Gradio tunnel).
demo.launch(share=False)
server.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from joblib import load
from concrete.ml.deployment import FHEModelServer
from pydantic import BaseModel
import base64
from pathlib import Path

# This file's directory (kept for deployment-relative paths).
current_dir = Path(__file__).parent

# Load one FHE model server per task.
fhe_model = FHEModelServer("deployment/financial_rating")
fhe_legal_model = FHEModelServer("deployment/legal_rating")


class PredictRequest(BaseModel):
    # Both fields are base64-encoded binary blobs produced by the client.
    evaluation_key: str
    encrypted_encoding: str


# Initialize an instance of FastAPI
app = FastAPI()


def _run_encrypted_inference(model_server, query: PredictRequest):
    """Decode the request, run FHE inference, and base64-encode the result."""
    encrypted_encoding = base64.b64decode(query.encrypted_encoding)
    evaluation_key = base64.b64decode(query.evaluation_key)
    prediction = model_server.run(encrypted_encoding, evaluation_key)
    # Encode base64 the prediction for JSON transport back to the client.
    return {"encrypted_prediction": base64.b64encode(prediction).decode()}


# Define the default route
@app.get("/")
def root():
    return {"message": "Welcome to Your Sentiment Classification FHE Model Server!"}


@app.post("/predict_sentiment")
def predict_sentiment(query: PredictRequest):
    """Run the financial_rating model over the encrypted encoding."""
    return _run_encrypted_inference(fhe_model, query)


@app.post("/legal_rating")
def predict_legal_rating(query: PredictRequest):
    """Run the legal_rating model over the encrypted encoding.

    Renamed from `predict_sentiment`: the duplicate definition shadowed the
    first handler's name at module level (both routes still registered, but
    the redefinition hid the original function and trips linters).
    """
    return _run_encrypted_inference(fhe_legal_model, query)