batuergun commited on
Commit
abc2497
1 Parent(s): c37150a

model output

Browse files
Files changed (3) hide show
  1. app.py +12 -4
  2. client_server_interface.py +2 -3
  3. server.py +4 -1
app.py CHANGED
@@ -317,8 +317,7 @@ def decrypt_output(user_id):
317
  user_id (int): The current user's ID.
318
 
319
  Returns:
320
- bool: The decrypted output (True if seizure detected, False otherwise)
321
-
322
  """
323
  if user_id == "":
324
  raise gr.Error("Please generate the private key first.")
@@ -345,14 +344,23 @@ def decrypt_output(user_id):
345
  # Deserialize, decrypt and post-process the encrypted output
346
  try:
347
  decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)
 
 
 
 
 
 
 
 
 
 
 
348
  except RuntimeError as e:
349
  logger.error(f"Error during deserialization: {str(e)}")
350
  raise gr.Error("Failed to deserialize the encrypted output. The data might be corrupted or in an unexpected format.")
351
  except Exception as e:
352
  logger.error(f"Unexpected error during decryption: {str(e)}")
353
  raise gr.Error(f"An unexpected error occurred during decryption: {str(e)}")
354
-
355
- return "Seizure detected" if decrypted_output else "No seizure detected"
356
 
357
  def resize_img(img, width=256, height=256):
358
  """Resize the image."""
 
317
  user_id (int): The current user's ID.
318
 
319
  Returns:
320
+ str: The decrypted output message
 
321
  """
322
  if user_id == "":
323
  raise gr.Error("Please generate the private key first.")
 
344
  # Deserialize, decrypt and post-process the encrypted output
345
  try:
346
  decrypted_output = client.deserialize_decrypt_post_process(encrypted_output)
347
+
348
+ # The decrypted output should be a 1D array with 2 elements
349
+ if isinstance(decrypted_output, np.ndarray) and decrypted_output.shape == (2,):
350
+ predicted_class = np.argmax(decrypted_output)
351
+ confidence = decrypted_output[predicted_class]
352
+ result = "Seizure detected" if predicted_class == 1 else "No seizure detected"
353
+ return f"{result} (Confidence: {confidence:.2f})"
354
+ else:
355
+ logger.error(f"Unexpected decrypted output format: {decrypted_output}")
356
+ raise ValueError("Unexpected output format from the model")
357
+
358
  except RuntimeError as e:
359
  logger.error(f"Error during deserialization: {str(e)}")
360
  raise gr.Error("Failed to deserialize the encrypted output. The data might be corrupted or in an unexpected format.")
361
  except Exception as e:
362
  logger.error(f"Unexpected error during decryption: {str(e)}")
363
  raise gr.Error(f"An unexpected error occurred during decryption: {str(e)}")
 
 
364
 
365
  def resize_img(img, width=256, height=256):
366
  """Resize the image."""
client_server_interface.py CHANGED
@@ -146,6 +146,5 @@ class FHEClient:
146
  output = self.client.decrypt(encrypted_output)
147
 
148
  # Post-process the output (if needed)
149
- seizure_detected = self.seizure_detector.post_processing(output)
150
-
151
- return seizure_detected
 
146
  output = self.client.decrypt(encrypted_output)
147
 
148
  # Post-process the output (if needed)
149
+ # Assuming the output is already in the correct format (2-element array)
150
+ return output
 
server.py CHANGED
@@ -114,7 +114,10 @@ def run_fhe(user_id: str = Form()):
114
  # gc.collect()
115
 
116
  # Placeholder output
117
- placeholder_output = np.random.randint(0, 2**64-1, size=(1, 2, 32769), dtype=np.uint64)
 
 
 
118
  encrypted_output = placeholder_output.tobytes()
119
 
120
  fhe_execution_time = round(time.time() - start, 2)
 
114
  # gc.collect()
115
 
116
  # Placeholder output
117
+ # Generate a random 2-element array with values between 0 and 1
118
+ placeholder_output = np.random.rand(2)
119
+ # Ensure the sum of the two elements is 1 (to mimic softmax output)
120
+ placeholder_output = placeholder_output / np.sum(placeholder_output)
121
  encrypted_output = placeholder_output.tobytes()
122
 
123
  fhe_execution_time = round(time.time() - start, 2)