NimaBoscarino committed
Commit 04a30fc
1 Parent(s): 061b29d

WIP: Substra orchestrator
- app.py +12 -0
- requirements.txt +3 -0
- substra_launcher.py +17 -0
- substra_template/Dockerfile +27 -0
- substra_template/__init__.py +0 -0
- substra_template/app.py +13 -0
- substra_template/mlflow_live_performances.py +45 -0
- substra_template/requirements.txt +12 -0
- substra_template/run.sh +3 -0
- substra_template/run_compute_plan.py +34 -0
- substra_template/substra_helpers/__init__.py +0 -0
- substra_template/substra_helpers/dataset.py +29 -0
- substra_template/substra_helpers/dataset_assets/description.md +18 -0
- substra_template/substra_helpers/dataset_assets/opener.py +20 -0
- substra_template/substra_helpers/model.py +25 -0
- substra_template/substra_helpers/substra_runner.py +194 -0
- tests/test_substra_launcher.py +25 -0
- tests/test_substra_runner.py +55 -0
app.py
ADDED
@@ -0,0 +1,12 @@
+import gradio as gr
+from huggingface_hub import HfApi
+from substra_launcher import launch_substra_space
+
+api = HfApi()
+
+gr.Interface(
+    fn=lambda *args, **kwargs: launch_substra_space(api, *args, **kwargs),
+    inputs="text",
+    outputs="text",
+    examples=[["NimaBoscarino/substra-test"]]
+).launch()
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+gradio
+pytest
+huggingface_hub
substra_launcher.py
ADDED
@@ -0,0 +1,17 @@
+from huggingface_hub import HfApi, RepoUrl
+
+
+def launch_substra_space(hf_api: HfApi, repo_id: str) -> RepoUrl:
+    repo_url = hf_api.create_repo(
+        repo_id=repo_id,
+        repo_type="space",
+        space_sdk="docker"
+    )
+
+    hf_api.upload_folder(
+        repo_id=repo_id,
+        repo_type="space",
+        folder_path="substra_template/"
+    )
+
+    return repo_url
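Note: a minimal sketch of calling the launcher directly, outside the Gradio UI. It assumes you are already authenticated with the Hub (e.g. via huggingface-cli login), and "your-username/substra-demo" is a placeholder repo id:

from huggingface_hub import HfApi

from substra_launcher import launch_substra_space

api = HfApi()  # uses your cached Hub token

# Creates the Docker Space and uploads the substra_template/ folder into it.
repo_url = launch_substra_space(api, repo_id="your-username/substra-demo")
print(repo_url)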
substra_template/Dockerfile
ADDED
@@ -0,0 +1,27 @@
+FROM python:3.10-slim-bullseye
+
+# Set the working directory to /code
+WORKDIR /code
+
+# Copy the current directory contents into the container at /code
+COPY ./requirements.txt /code/requirements.txt
+
+# Install requirements.txt
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+# Switch to the "user" user
+USER user
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+EXPOSE 7860
+CMD ["bash", "-c", "/code/run.sh"]
substra_template/__init__.py
ADDED
File without changes
substra_template/app.py
ADDED
@@ -0,0 +1,13 @@
+import gradio as gr
+
+
+def read_logs():
+    with open("output.log", "r") as f:
+        return f.read()
+
+
+with gr.Blocks() as demo:
+    logs = gr.Plot()
+    demo.load(read_logs, None, logs, every=1)
+
+demo.queue().launch()
substra_template/mlflow_live_performances.py
ADDED
@@ -0,0 +1,45 @@
+import pandas as pd
+import json
+from pathlib import Path
+from mlflow import log_metric
+import time
+import os
+from glob import glob
+
+TIMEOUT = 60  # Number of seconds to stop the script after the last update of the json file
+POLLING_FREQUENCY = 10  # Try to read the updates in the file every 10 seconds
+
+# Wait for the file to be found
+start = time.time()
+while not len(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))) > 0:
+    time.sleep(POLLING_FREQUENCY)
+    if time.time() - start >= TIMEOUT:
+        raise TimeoutError("The performance file does not exist, maybe no test task has been executed yet.")
+
+path_to_json = Path(glob(str(Path("local-worker") / "live_performances" / "*" / "performances.json"))[0])
+
+logged_rows = []
+last_update = time.time()
+
+while (time.time() - last_update) <= TIMEOUT:
+
+    if last_update == os.path.getmtime(str(path_to_json)):
+        time.sleep(POLLING_FREQUENCY)
+        continue
+
+    last_update = os.path.getmtime(str(path_to_json))
+
+    time.sleep(1)  # Waiting for the json to be fully written
+    dict_perf = json.load(path_to_json.open())
+
+    df = pd.DataFrame(dict_perf)
+
+    for _, row in df.iterrows():
+        if row["testtask_key"] in logged_rows:
+            continue
+
+        logged_rows.append(row["testtask_key"])
+
+        step = int(row["round_idx"]) if row["round_idx"] is not None else int(row["testtask_rank"])
+
+        log_metric(f"{row['metric_name']}_{row['worker']}", row["performance"], step)
substra_template/requirements.txt
ADDED
@@ -0,0 +1,12 @@
+gradio
+substrafl
+datasets
+torch
+torchvision
+scikit-learn
+numpy==1.23.0
+Pillow
+transformers
+matplotlib
+pandas
+mlflow
substra_template/run.sh
ADDED
@@ -0,0 +1,3 @@
+PYTHONPATH=/Users/nima/Work/society-ethics/substra/substra_template python run_compute_plan.py &
+PYTHONPATH=/Users/nima/Work/society-ethics/substra/substra_template python mlflow_live_performances.py &
+mlflow ui
substra_template/run_compute_plan.py
ADDED
@@ -0,0 +1,34 @@
+from substra_helpers.substra_runner import SubstraRunner, algo_generator
+from substra_helpers.model import CNN
+from substra_helpers.dataset import TorchDataset
+from substrafl.strategies import FedAvg
+
+import torch
+
+seed = 42
+torch.manual_seed(seed)
+model = CNN()
+optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+criterion = torch.nn.CrossEntropyLoss()
+
+runner = SubstraRunner()
+runner.set_up_clients()
+runner.prepare_data()
+runner.register_data()
+runner.register_metric()
+
+runner.algorithm = algo_generator(
+    model=model,
+    criterion=criterion,
+    optimizer=optimizer,
+    index_generator=runner.index_generator,
+    dataset=TorchDataset,
+    seed=seed
+)()
+
+runner.strategy = FedAvg()
+
+runner.set_aggregation()
+runner.set_testing()
+
+runner.run_compute_plan()
substra_template/substra_helpers/__init__.py
ADDED
File without changes
substra_template/substra_helpers/dataset.py
ADDED
@@ -0,0 +1,29 @@
+import torch
+from torch.utils import data
+import torch.nn.functional as F
+import numpy as np
+
+
+class TorchDataset(data.Dataset):
+    def __init__(self, datasamples, is_inference: bool):
+        self.x = datasamples["image"]
+        self.y = datasamples["label"]
+        self.is_inference = is_inference
+
+    def __getitem__(self, idx):
+
+        if self.is_inference:
+            x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
+            return x
+
+        else:
+            x = torch.FloatTensor(np.array(self.x[idx])[None, ...]) / 255
+
+            y = torch.tensor(self.y[idx]).type(torch.int64)
+            y = F.one_hot(y, 10)
+            y = y.type(torch.float32)
+
+            return x, y
+
+    def __len__(self):
+        return len(self.x)
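A minimal local sanity-check sketch for the dataset wrapper above, assuming the helpers are importable as in run_compute_plan.py and using a hypothetical two-sample datasamples dict shaped like the MNIST opener output:

import numpy as np

from substra_helpers.dataset import TorchDataset

# Hypothetical samples: two 28x28 uint8 "images" and integer labels in 0-9.
datasamples = {
    "image": [np.zeros((28, 28), dtype=np.uint8), np.full((28, 28), 255, dtype=np.uint8)],
    "label": [3, 7],
}

train_ds = TorchDataset(datasamples, is_inference=False)
x, y = train_ds[0]
print(len(train_ds), x.shape, y.shape)  # 2, torch.Size([1, 28, 28]), torch.Size([10])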
substra_template/substra_helpers/dataset_assets/description.md
ADDED
@@ -0,0 +1,18 @@
+# Mnist
+
+This dataset is [THE MNIST DATABASE of handwritten digits](http://yann.lecun.com/exdb/mnist/).
+
+The target is the number (0 -> 9) represented by the pixels.
+
+## Data repartition
+
+### Train and test
+
+### Split data between organizations
+
+## Opener usage
+
+The opener exposes 2 methods:
+
+- `get_data` returns a dictionary containing the images and the labels as numpy arrays
+- `fake_data` returns a fake data sample of images and labels in a dict
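For reference, a small sketch of exercising the opener defined in the next file on its own; fake_data needs nothing on disk, while get_data expects a folder previously written with Dataset.save_to_disk (the import path is an assumption, matching the substra_helpers layout):

from substra_helpers.dataset_assets.opener import MnistOpener

opener = MnistOpener()

# Random images/labels with the MNIST shape; no data folders required.
fake = opener.fake_data(n_samples=10)
print(fake["image"].shape, fake["label"].shape)  # (10, 28, 28) (10,)

# With real data, point it at a folder produced by Dataset.save_to_disk, e.g.:
# data = opener.get_data(["data/<organization_id>/train"])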
substra_template/substra_helpers/dataset_assets/opener.py
ADDED
@@ -0,0 +1,20 @@
+import numpy as np
+import substratools as tools
+from datasets import load_from_disk
+from transformers import ImageFeatureExtractionMixin
+
+
+class MnistOpener(tools.Opener):
+    def fake_data(self, n_samples=None):
+        N_SAMPLES = n_samples if n_samples and n_samples <= 100 else 100
+
+        fake_images = np.random.randint(256, size=(N_SAMPLES, 28, 28))
+
+        fake_labels = np.random.randint(10, size=N_SAMPLES)
+
+        data = {"image": fake_images, "label": fake_labels}
+
+        return data
+
+    def get_data(self, folders):
+        return load_from_disk(folders[0])
substra_template/substra_helpers/model.py
ADDED
@@ -0,0 +1,25 @@
+from torch import nn
+import torch.nn.functional as F
+
+
+# TODO: Would be cool to use a simple Transformer model... then I could use the Trainer API 👀
+class CNN(nn.Module):
+    def __init__(self):
+        super(CNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
+        self.conv2 = nn.Conv2d(32, 32, kernel_size=5)
+        self.conv3 = nn.Conv2d(32, 64, kernel_size=5)
+        self.fc1 = nn.Linear(3 * 3 * 64, 256)
+        self.fc2 = nn.Linear(256, 10)
+
+    def forward(self, x, eval=False):
+        x = F.relu(self.conv1(x))
+        x = F.relu(F.max_pool2d(self.conv2(x), 2))
+        x = F.dropout(x, p=0.5, training=not eval)
+        x = F.relu(F.max_pool2d(self.conv3(x), 2))
+        x = F.dropout(x, p=0.5, training=not eval)
+        x = x.view(-1, 3 * 3 * 64)
+        x = F.relu(self.fc1(x))
+        x = F.dropout(x, p=0.5, training=not eval)
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
substra_template/substra_helpers/substra_runner.py
ADDED
@@ -0,0 +1,194 @@
+import pathlib
+import shutil
+from typing import Optional, List
+
+from substra import Client, BackendType
+
+from substra.sdk.schemas import (
+    DatasetSpec,
+    Permissions,
+    DataSampleSpec
+)
+
+from substrafl.strategies import Strategy
+from substrafl.dependency import Dependency
+from substrafl.remote.register import add_metric
+from substrafl.index_generator import NpIndexGenerator
+from substrafl.algorithms.pytorch import TorchFedAvgAlgo
+
+from substrafl.nodes import TrainDataNode, AggregationNode, TestDataNode
+from substrafl.evaluation_strategy import EvaluationStrategy
+
+from substrafl.experiment import execute_experiment
+from substra.sdk.models import ComputePlan
+
+from datasets import load_dataset, Dataset
+from sklearn.metrics import accuracy_score
+import numpy as np
+
+import torch
+
+
+class SubstraRunner:
+    def __init__(self):
+        self.num_clients = 3
+        self.clients = {}
+        self.algo_provider: Optional[Client] = None
+
+        self.datasets: List[Dataset] = []
+        self.test_dataset: Optional[Dataset] = None
+        self.path = pathlib.Path(__file__).parent.resolve()
+
+        self.dataset_keys = {}
+        self.train_data_sample_keys = {}
+        self.test_data_sample_keys = {}
+
+        self.metric_key: Optional[str] = None
+
+        NUM_UPDATES = 100
+        BATCH_SIZE = 32
+
+        self.index_generator = NpIndexGenerator(
+            batch_size=BATCH_SIZE,
+            num_updates=NUM_UPDATES,
+        )
+
+        self.algorithm: Optional[TorchFedAvgAlgo] = None
+        self.strategy: Optional[Strategy] = None
+
+        self.aggregation_node: Optional[AggregationNode] = None
+        self.train_data_nodes = list()
+        self.test_data_nodes = list()
+        self.eval_strategy: Optional[EvaluationStrategy] = None
+
+        self.NUM_ROUNDS = 3
+        self.compute_plan: Optional[ComputePlan] = None
+
+        self.experiment_folder = self.path / "experiment_summaries"
+
+    def set_up_clients(self):
+        self.algo_provider = Client(backend_type=BackendType.LOCAL_SUBPROCESS)
+
+        self.clients = {
+            c.organization_info().organization_id: c
+            for c in [Client(backend_type=BackendType.LOCAL_SUBPROCESS) for _ in range(self.num_clients - 1)]
+        }
+
+    def prepare_data(self):
+        dataset = load_dataset("mnist", split="train").shuffle()
+        self.datasets = [dataset.shard(num_shards=self.num_clients - 1, index=i) for i in range(self.num_clients - 1)]
+
+        self.test_dataset = load_dataset("mnist", split="test")
+
+        data_path = self.path / "data"
+        if data_path.exists() and data_path.is_dir():
+            shutil.rmtree(data_path)
+
+        for i, client_id in enumerate(self.clients):
+            ds = self.datasets[i]
+            ds.save_to_disk(data_path / client_id / "train")
+            self.test_dataset.save_to_disk(data_path / client_id / "test")
+
+    def register_data(self):
+        for client_id, client in self.clients.items():
+            permissions_dataset = Permissions(public=False, authorized_ids=[
+                self.algo_provider.organization_info().organization_id
+            ])
+
+            dataset = DatasetSpec(
+                name="MNIST",
+                type="npy",
+                data_opener=self.path / pathlib.Path("dataset_assets/opener.py"),
+                description=self.path / pathlib.Path("dataset_assets/description.md"),
+                permissions=permissions_dataset,
+                logs_permission=permissions_dataset,
+            )
+            self.dataset_keys[client_id] = client.add_dataset(dataset)
+            assert self.dataset_keys[client_id], "Missing dataset key"
+
+            self.train_data_sample_keys[client_id] = client.add_data_sample(DataSampleSpec(
+                data_manager_keys=[self.dataset_keys[client_id]],
+                path=self.path / "data" / client_id / "train",
+            ))
+
+            data_sample = DataSampleSpec(
+                data_manager_keys=[self.dataset_keys[client_id]],
+                path=self.path / "data" / client_id / "test",
+            )
+            self.test_data_sample_keys[client_id] = client.add_data_sample(data_sample)
+
+    def register_metric(self):
+        permissions_metric = Permissions(
+            public=False,
+            authorized_ids=[
+                self.algo_provider.organization_info().organization_id
+            ] + list(self.clients.keys())
+        )
+
+        metric_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "scikit-learn==1.1.1"])
+
+        def accuracy(datasamples, predictions_path):
+            y_true = datasamples["label"]
+            y_pred = np.load(predictions_path)
+
+            return accuracy_score(y_true, np.argmax(y_pred, axis=1))
+
+        self.metric_key = add_metric(
+            client=self.algo_provider,
+            metric_function=accuracy,
+            permissions=permissions_metric,
+            dependencies=metric_deps,
+        )
+
+    def set_aggregation(self):
+        self.aggregation_node = AggregationNode(self.algo_provider.organization_info().organization_id)
+
+        for org_id in self.clients:
+            train_data_node = TrainDataNode(
+                organization_id=org_id,
+                data_manager_key=self.dataset_keys[org_id],
+                data_sample_keys=[self.train_data_sample_keys[org_id]],
+            )
+            self.train_data_nodes.append(train_data_node)
+
+    def set_testing(self):
+        for org_id in self.clients:
+            test_data_node = TestDataNode(
+                organization_id=org_id,
+                data_manager_key=self.dataset_keys[org_id],
+                test_data_sample_keys=[self.test_data_sample_keys[org_id]],
+                metric_keys=[self.metric_key],
+            )
+            self.test_data_nodes.append(test_data_node)
+
+        self.eval_strategy = EvaluationStrategy(test_data_nodes=self.test_data_nodes, rounds=1)
+
+    def run_compute_plan(self):
+        algo_deps = Dependency(pypi_dependencies=["numpy==1.23.1", "torch==1.11.0"])
+
+        self.compute_plan = execute_experiment(
+            client=self.algo_provider,
+            algo=self.algorithm,
+            strategy=self.strategy,
+            train_data_nodes=self.train_data_nodes,
+            evaluation_strategy=self.eval_strategy,
+            aggregation_node=self.aggregation_node,
+            num_rounds=self.NUM_ROUNDS,
+            experiment_folder=self.experiment_folder,
+            dependencies=algo_deps,
+        )
+
+
+def algo_generator(model, criterion, optimizer, index_generator, dataset, seed):
+    class MyAlgo(TorchFedAvgAlgo):
+        def __init__(self):
+            super().__init__(
+                model=model,
+                criterion=criterion,
+                optimizer=optimizer,
+                index_generator=index_generator,
+                dataset=dataset,
+                seed=seed,
+            )
+
+    return MyAlgo
tests/test_substra_launcher.py
ADDED
@@ -0,0 +1,25 @@
+import pytest
+from unittest.mock import Mock
+
+from substra_launcher import launch_substra_space
+
+
+class TestSubstraLauncher:
+
+    @pytest.fixture
+    def mock_hf_api(self):
+        mock_hf_api = Mock()
+        mock_hf_api.create_repo = Mock(side_effect=lambda repo_id, *args, **kwargs: f"https://hf.space/{repo_id}")
+        return mock_hf_api
+
+    def test_launch_substra_space(self, mock_hf_api):
+        repo_id = "user/space"
+        repo_link = launch_substra_space(mock_hf_api, repo_id=repo_id)
+        mock_hf_api.create_repo.assert_called_once_with(
+            repo_id=repo_id, repo_type="space", space_sdk="docker"
+        )
+        mock_hf_api.upload_folder.assert_called_once_with(
+            repo_id=repo_id, repo_type="space", folder_path="substra_template/"
+        )
+
+        assert repo_link == f"https://hf.space/{repo_id}"
tests/test_substra_runner.py
ADDED
@@ -0,0 +1,55 @@
+import pytest
+from unittest.mock import Mock, call
+from datasets import Dataset
+
+from substra_template.substra_runner import SubstraRunner
+
+
+class TestSubstraRunner:
+    @pytest.fixture
+    def mock_substra_client_class(self, monkeypatch):
+        mock_substra_client_class = Mock()
+        monkeypatch.setattr("substra_template.substra_runner.Client", mock_substra_client_class)
+
+        return mock_substra_client_class
+
+    @pytest.fixture
+    def mock_load_dataset(self, monkeypatch):
+        mock_load_dataset = Mock()
+        monkeypatch.setattr("substra_template.substra_runner.load_dataset", mock_load_dataset)
+
+        return mock_load_dataset
+
+    def test_set_up_clients(self, mock_substra_client_class):
+        runner = SubstraRunner()
+        runner.set_up_clients()
+
+        mock_substra_client_class.assert_called()
+
+    def test_prepare_data(self, mock_load_dataset):
+        runner = SubstraRunner()
+        runner.prepare_data()
+
+        mock_load_dataset.assert_has_calls(calls=[
+            call("mnist", split="train"),
+            call("mnist", split="test"),
+        ], any_order=True)
+
+        assert len(runner.datasets) == runner.num_clients - 1
+
+    def test_register_data(self, mock_load_dataset):
+        runner = SubstraRunner()
+        runner.datasets = [Dataset.from_dict({}) for _ in range(runner.num_clients - 1)]
+
+        runner.register_data()
+
+    def test_register_metric(self):
+        runner = SubstraRunner()
+        runner.set_up_clients()
+        runner.register_metric()
+
+    def test_set_aggregation(self):
+        pass
+
+    def test_set_testing(self):
+        pass