# ClairVault / fhe_utils.py
# Uploaded by VaultChem ("Upload 3 files", commit 6c570a1, 6.87 kB).
import sys
import os
import pdb
import numpy as np
import random
import json
import shutil
import time
from scipy.stats import pearsonr
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import xgboost as xgb
from tqdm import tqdm
random.seed(42)
import gzip
import numpy as np
import pandas as pd
import requests
from io import BytesIO
from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.sklearn import DecisionTreeClassifier as DecisionTreeClassifierZAMA
from concrete.ml.sklearn import LinearSVC as LinearSVCZAMA
from sklearn.svm import LinearSVR as LinearSVR
import time
from shutil import copyfile
from tempfile import TemporaryDirectory
import pickle
import os
import time
import numpy as np
def convert_numpy(obj):
    """Return a plain-Python equivalent of a NumPy scalar or array.

    NumPy integer scalars become ``int``, NumPy floating scalars become
    ``float`` and ``ndarray`` objects become (nested) lists via ``tolist()``.
    Any other object is handed back unchanged.
    """
    conversions = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, lambda array: array.tolist()),
    )
    for numpy_type, to_python in conversions:
        if isinstance(obj, numpy_type):
            return to_python(obj)
    return obj
class OnDiskNetwork:
    """Simulate a client/server/dev network on disk.

    Each party (server, client, dev) owns its own ``TemporaryDirectory``;
    "sending" something means writing or copying a file into the recipient's
    directory. Call :meth:`cleanup` when finished to delete all three folders.
    """

    def __init__(self):
        # One temporary folder per party: server, client and dev.
        self.server_dir = TemporaryDirectory()
        self.client_dir = TemporaryDirectory()
        self.dev_dir = TemporaryDirectory()

    def client_send_evaluation_key_to_server(self, serialized_evaluation_keys):
        """Write the client's serialized evaluation keys into the server folder.

        Args:
            serialized_evaluation_keys: bytes produced by the client; stored as
                ``serialized_evaluation_keys.ekl`` in the server folder.
        """
        key_path = os.path.join(self.server_dir.name, "serialized_evaluation_keys.ekl")
        with open(key_path, "wb") as f:
            f.write(serialized_evaluation_keys)

    def client_send_input_to_server_for_prediction(self, encrypted_input):
        """Evaluate *encrypted_input* on the server in FHE and store the result.

        Reads the evaluation keys written by
        :meth:`client_send_evaluation_key_to_server`, builds an
        ``FHEModelServer`` over the server folder (presumably expects the
        ``server.zip`` copied by :meth:`dev_send_model_to_server` — confirm),
        runs it, and writes the encrypted prediction to
        ``encrypted_prediction.enc`` in the server folder.

        Args:
            encrypted_input: serialized encrypted input from the client.

        Returns:
            float: elapsed seconds covering ``FHEModelServer`` construction and
            ``run``; reading the keys and writing the result are excluded.
        """
        key_path = os.path.join(self.server_dir.name, "serialized_evaluation_keys.ekl")
        with open(key_path, "rb") as f:
            serialized_evaluation_keys = f.read()
        # perf_counter is monotonic, so the interval cannot be skewed by
        # system-clock adjustments the way time.time() can.
        time_begin = time.perf_counter()
        encrypted_prediction = FHEModelServer(self.server_dir.name).run(
            encrypted_input, serialized_evaluation_keys
        )
        time_end = time.perf_counter()
        prediction_path = os.path.join(self.server_dir.name, "encrypted_prediction.enc")
        with open(prediction_path, "wb") as f:
            f.write(encrypted_prediction)
        return time_end - time_begin

    def dev_send_model_to_server(self):
        """Copy ``server.zip`` from the dev folder into the server folder."""
        copyfile(
            os.path.join(self.dev_dir.name, "server.zip"),
            os.path.join(self.server_dir.name, "server.zip"),
        )

    def server_send_encrypted_prediction_to_client(self):
        """Return the bytes of ``encrypted_prediction.enc`` from the server folder."""
        prediction_path = os.path.join(self.server_dir.name, "encrypted_prediction.enc")
        with open(prediction_path, "rb") as f:
            encrypted_prediction = f.read()
        return encrypted_prediction

    def dev_send_clientspecs_and_modelspecs_to_client(self):
        """Copy ``client.zip`` from the dev folder into the client folder."""
        copyfile(
            os.path.join(self.dev_dir.name, "client.zip"),
            os.path.join(self.client_dir.name, "client.zip"),
        )

    def cleanup(self):
        """Delete the server, client and dev temporary folders."""
        self.server_dir.cleanup()
        self.client_dir.cleanup()
        self.dev_dir.cleanup()
def generate_fingerprint(smiles, radius=2, bits=512):
    """Compute a Morgan fingerprint bit vector for a SMILES string with RDKit.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan radius passed to ``GetMorganFingerprintAsBitVect``.
        bits: folded fingerprint length passed as ``nBits``.

    Returns:
        A NumPy array built from the RDKit bit vector (presumably one 0/1
        entry per bit), or ``np.nan`` when ``Chem.MolFromSmiles`` returns
        ``None`` (the SMILES could not be parsed).
    """
    # Chem/AllChem were referenced but never imported in this module, so every
    # call raised NameError. Importing locally also keeps RDKit optional for
    # users of the other helpers in this file.
    from rdkit import Chem
    from rdkit.Chem import AllChem

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.nan
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=bits)
    return np.array(fp)
def train_xgb_regressor(X_train, y_train, param_grid=None, verbose=10):
    """Grid-search an XGBoost regressor with shuffled 5-fold cross-validation.

    Args:
        X_train: training features.
        y_train: training targets.
        param_grid: ``GridSearchCV`` parameter grid; when ``None`` a small
            default grid over depth, learning rate and column sampling is used.
        verbose: verbosity level forwarded to ``GridSearchCV``.

    Returns:
        tuple: ``(best_params_, best_score_, best_estimator_)`` of the fitted
        search. No ``scoring`` is set, so ``best_score_`` uses the estimator's
        default ``score`` (R^2 for sklearn-style regressors — confirm for the
        installed xgboost).
    """
    # GridSearchCV was used below but never imported at module level, so every
    # call raised NameError.
    from sklearn.model_selection import GridSearchCV

    if param_grid is None:
        param_grid = {
            "max_depth": [3, 6],
            "learning_rate": [0.01, 0.1, 0.2],
            "n_estimators": [20],
            "colsample_bytree": [0.3, 0.7],
        }
    xgb_regressor = xgb.XGBRegressor(objective="reg:squarederror")
    # Fixed random_state makes the fold assignment reproducible.
    kfold = KFold(n_splits=5, shuffle=True, random_state=42)
    grid_search = GridSearchCV(
        estimator=xgb_regressor,
        param_grid=param_grid,
        cv=kfold,
        verbose=verbose,
        n_jobs=-1,  # use all available cores
    )
    grid_search.fit(X_train, y_train)
    return (
        grid_search.best_params_,
        grid_search.best_score_,
        grid_search.best_estimator_,
    )
def evaluate_model(model, X_test, y_test):
    """Score *model* by the Pearson correlation of its predictions with the truth.

    Args:
        model: any object exposing ``predict(X)``.
        X_test: test features passed to ``model.predict``.
        y_test: true target values.

    Returns:
        The Pearson correlation coefficient between ``y_test`` and the
        predictions.
    """
    # Flatten both sides so column-shaped (n, 1) outputs or targets still
    # compare as 1-D series.
    y_pred = np.ravel(model.predict(X_test))
    y_true = np.ravel(y_test)
    # Index the result instead of using ``.statistic`` (SciPy >= 1.9 only):
    # older SciPy returns a plain (r, p) tuple, and newer result objects still
    # support tuple indexing for backward compatibility.
    pearsonr_score = pearsonr(y_true, y_pred)[0]
    return pearsonr_score
def setup_network(model_dev):
    """Create an on-disk network and save *model_dev* into its dev folder.

    The model is saved through ``FHEModelDev(...).save(via_mlir=True)`` into
    ``network.dev_dir``.

    Returns:
        tuple: ``(network, fhemodel_dev)`` — the new ``OnDiskNetwork`` and the
        ``FHEModelDev`` used to save the model.
    """
    net = OnDiskNetwork()
    developer = FHEModelDev(net.dev_dir.name, model_dev)
    developer.save(via_mlir=True)
    return net, developer
def copy_directory(source, destination="deployment"):
    """Copy the contents of *source* into *destination*, merging with what exists.

    ``destination`` is created if missing. Each top-level entry of ``source`` is
    copied into it: sub-directories with ``shutil.copytree(...,
    dirs_exist_ok=True)`` and files with ``shutil.copy2`` (which also copies
    file metadata). Files with the same name in ``destination`` are overwritten.

    Args:
        source: directory whose contents are copied.
        destination: target directory; defaults to ``"deployment"``.

    Returns:
        tuple: ``(True, None)`` on success, otherwise ``(False, message)``.
        Exceptions derived from ``Exception`` are caught and returned as text
        rather than raised.
    """
    try:
        # isdir (not just exists) so a plain file is rejected up front with a
        # clear message instead of a NotADirectoryError from os.listdir.
        if not os.path.isdir(source):
            return False, "Source directory does not exist."
        # exist_ok avoids the race between checking for and creating the folder.
        os.makedirs(destination, exist_ok=True)
        for item in os.listdir(source):
            src_path = os.path.join(source, item)
            dst_path = os.path.join(destination, item)
            if os.path.isdir(src_path):
                # dirs_exist_ok requires Python 3.8+.
                shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
            else:
                shutil.copy2(src_path, dst_path)
        return True, None
    except Exception as e:
        # Deliberately best-effort: callers receive the error text as a value.
        return False, str(e)
def client_server_interaction(network, fhemodel_client, X_client):
    """Send every row of *X_client* through the encrypted client/server round trip.

    For each row: the client quantizes, encrypts and serializes it; the server
    evaluates it via ``network``; the encrypted result is fetched back and the
    client deserializes, decrypts and dequantizes it.

    Args:
        network: object providing ``client_send_input_to_server_for_prediction``
            and ``server_send_encrypted_prediction_to_client``.
        fhemodel_client: client providing ``quantize_encrypt_serialize`` and
            ``deserialize_decrypt_dequantize``.
        X_client: 2-D array of inputs, one sample per row.

    Returns:
        tuple: ``(decrypted_predictions, execution_time)`` — the first element
        of each decrypted output, and the per-sample values returned by
        ``network.client_send_input_to_server_for_prediction``.
    """
    predictions, timings = [], []
    n_samples = X_client.shape[0]
    for row in tqdm(range(n_samples)):
        # Index with a list so (for a NumPy array) the sample stays 2-D.
        sample = X_client[[row], :]
        payload = fhemodel_client.quantize_encrypt_serialize(sample)
        timings.append(network.client_send_input_to_server_for_prediction(payload))
        encrypted_result = network.server_send_encrypted_prediction_to_client()
        decrypted = fhemodel_client.deserialize_decrypt_dequantize(encrypted_result)
        predictions.append(decrypted[0])
    return predictions, timings
def train_zama(X_train, y_train, model_cls=None):
    """Fit and FHE-compile a Concrete ML model, printing progress messages.

    Args:
        X_train: training features; also passed to ``compile`` (presumably as
            the representative input set for quantization — confirm).
        y_train: training labels.
        model_cls: zero-argument class or factory for the model; defaults to
            ``LinearSVCZAMA``. ``DecisionTreeClassifierZAMA`` (also imported in
            this module) is an alternative the original code toggled by hand.

    Returns:
        The fitted and compiled model instance.
    """
    # Resolved at call time so the default stays the previously hard-coded class.
    if model_cls is None:
        model_cls = LinearSVCZAMA
    model_dev = model_cls()
    print("Training Zama model...")
    model_dev.fit(X_train, y_train)
    print("compiling model...")
    model_dev.compile(X_train)
    print("done")
    return model_dev
def time_prediction(model, X_sample):
    """Time one FHE prediction of *model* on *X_sample*.

    Calls ``model.predict(X_sample, fhe="execute")`` and discards the result.

    Returns:
        float: elapsed seconds for the ``predict`` call.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump if
    # the system clock is adjusted during the measurement.
    start = time.perf_counter()
    model.predict(X_sample, fhe="execute")
    return time.perf_counter() - start
def setup_client(network, key_dir):
    """Build an FHE client over the network's client folder and generate its keys.

    Args:
        network: object whose ``client_dir.name`` is the client folder path.
        key_dir: directory passed to ``FHEModelClient`` as ``key_dir``.

    Returns:
        tuple: ``(fhemodel_client, serialized_evaluation_keys)`` — the client
        and the evaluation keys it serialized after key generation.
    """
    client = FHEModelClient(network.client_dir.name, key_dir=key_dir)
    client.generate_private_and_evaluation_keys()
    return client, client.get_serialized_evaluation_keys()