import streamlit as st
import matplotlib.pyplot as plt
import torch
from torch.optim import AdamW  # transformers' AdamW is deprecated/removed; use torch's
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset, Dataset
from evaluate import load as load_metric
from torch.utils.data import DataLoader
import pandas as pd
import random
from collections import OrderedDict
import flwr as fl
from logging import INFO
from flwr.common.logger import log
import re
import plotly.graph_objects as go

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fl.common.logger.configure(identifier="myFlowerExperiment", filename="./log.txt")
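# Minimal collator: pads every per-token field in a batch to the longest
# sequence so ragged HF-tokenized examples can be stacked into tensors.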
class CustomDataCollator:
    def __init__(self, pad_token_id=0):
        self.pad_token_id = pad_token_id

    def __call__(self, features):
        max_length = max(len(f["input_ids"]) for f in features)
        batch = {}
        for key in features[0].keys():
            values = [f[key] for f in features]
            if isinstance(values[0], list):
                # Pad all sequence fields (input_ids, attention_mask, token_type_ids, ...);
                # only input_ids pads with the tokenizer's pad id, masks pad with 0.
                pad_value = self.pad_token_id if key == "input_ids" else 0
                values = [v + [pad_value] * (max_length - len(v)) for v in values]
            batch[key] = torch.tensor(values)
        return batch
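# Loads and tokenizes the selected dataset, then draws a small random shard of
# train/test examples for each simulated client.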
def load_data(dataset_name, train_size=20, test_size=20, num_clients=2, use_utf8=False, model_name="bert-base-uncased"):
    raw_datasets = load_dataset(dataset_name)
    raw_datasets = raw_datasets.shuffle(seed=42)
    # Only imdb ships an "unsupervised" split; drop it if present instead of
    # deleting unconditionally (amazon_polarity and ag_news do not have it).
    raw_datasets.pop("unsupervised", None)
    # AutoTokenizer resolves the right tokenizer for every listed model. For
    # google/byt5-small that is the byte-level ByT5 tokenizer, which converts
    # raw text to UTF-8 byte ids itself, so no manual .encode('utf-8') step is
    # needed; `use_utf8` is kept for the UI checkbox.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # amazon_polarity stores its review body under "content" rather than "text".
    text_column = "text" if "text" in raw_datasets["train"].column_names else "content"
    def tokenize_function(examples):
        return tokenizer(examples[text_column], truncation=True)
    # Drop all raw columns except the label so only tensor-friendly fields remain.
    drop_columns = [c for c in raw_datasets["train"].column_names if c != "label"]
    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=drop_columns)
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
    train_datasets = []
    test_datasets = []
    for _ in range(num_clients):
        train_dataset = tokenized_datasets["train"].select(random.sample(range(len(tokenized_datasets["train"])), train_size))
        test_dataset = tokenized_datasets["test"].select(random.sample(range(len(tokenized_datasets["test"])), test_size))
        train_datasets.append(train_dataset)
        test_datasets.append(test_dataset)
    data_collator = CustomDataCollator(pad_token_id=tokenizer.pad_token_id)
    return train_datasets, test_datasets, data_collator, raw_datasets
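# One pass of standard supervised fine-tuning per epoch on a client's shard.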
def train(net, trainloader, epochs):
    optimizer = AdamW(net.parameters(), lr=5e-5)
    net.train()
    for _ in range(epochs):
        for batch in trainloader:
            batch = {k: v.to(DEVICE) for k, v in batch.items()}
            outputs = net(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
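# Evaluation pass: mean cross-entropy loss over batches plus accuracy computed
# with the `evaluate` library.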
def test(net, testloader):
    metric = load_metric("accuracy")
    net.eval()
    loss = 0.0
    for batch in testloader:
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = net(**batch)
        logits = outputs.logits
        loss += outputs.loss.item()
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    loss /= len(testloader)
    accuracy = metric.compute()["accuracy"]
    return loss, accuracy
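# Flower NumPyClient: exchanges model weights with the server as NumPy arrays
# and tracks per-round loss/accuracy for plotting.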
class CustomClient(fl.client.NumPyClient):
    def __init__(self, net, trainloader, testloader, client_id):
        self.net = net
        self.trainloader = trainloader
        self.testloader = testloader
        self.client_id = client_id
        self.losses = []
        self.accuracies = []

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        log(INFO, f"Client {self.client_id} is starting fit()")
        self.set_parameters(parameters)
        train(self.net, self.trainloader, epochs=1)
        loss, accuracy = test(self.net, self.testloader)
        self.losses.append(loss)
        self.accuracies.append(accuracy)
        log(INFO, f"Client {self.client_id} finished fit() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return self.get_parameters(config={}), len(self.trainloader.dataset), {"loss": loss, "accuracy": accuracy}

    def evaluate(self, parameters, config):
        log(INFO, f"Client {self.client_id} is starting evaluate()")
        self.set_parameters(parameters)
        loss, accuracy = test(self.net, self.testloader)
        log(INFO, f"Client {self.client_id} finished evaluate() with loss: {loss:.4f} and accuracy: {accuracy:.4f}")
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy), "loss": float(loss)}
    def plot_metrics(self, round_num, plot_placeholder):
        if self.losses and self.accuracies:
            plot_placeholder.write(f"#### Client {self.client_id} Metrics for Round {round_num}")
            plot_placeholder.write(f"Loss: {self.losses[-1]:.4f}")
            plot_placeholder.write(f"Accuracy: {self.accuracies[-1]:.4f}")
            fig, ax1 = plt.subplots()
            color = 'tab:red'
            ax1.set_xlabel('Round')
            ax1.set_ylabel('Loss', color=color)
            ax1.plot(range(1, len(self.losses) + 1), self.losses, color=color)
            ax1.tick_params(axis='y', labelcolor=color)
            ax2 = ax1.twinx()  # second y-axis sharing the same x-axis
            color = 'tab:blue'
            ax2.set_ylabel('Accuracy', color=color)
            ax2.plot(range(1, len(self.accuracies) + 1), self.accuracies, color=color)
            ax2.tick_params(axis='y', labelcolor=color)
            fig.tight_layout()
            plot_placeholder.pyplot(fig)
def read_log_file(log_path='./log.txt'):
    with open(log_path, 'r') as file:
        return file.readlines()
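# Pulls round numbers, per-client messages, and memory-usage samples out of
# the Flower log with simple regexes.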
def parse_log(log_lines):
    rounds = []
    clients = {}
    memory_usage = []
    round_pattern = re.compile(r'ROUND (\d+)')
    client_pattern = re.compile(r'Client (\d+) \| (INFO|DEBUG) \| (.*)')
    memory_pattern = re.compile(r'memory used=(\d+\.\d+)GB')
    current_round = None
    for line in log_lines:
        round_match = round_pattern.search(line)
        client_match = client_pattern.search(line)
        memory_match = memory_pattern.search(line)
        if round_match:
            current_round = int(round_match.group(1))
            rounds.append(current_round)
        elif client_match:
            client_id = int(client_match.group(1))
            log_level = client_match.group(2)
            message = client_match.group(3)
            if client_id not in clients:
                clients[client_id] = {'rounds': [], 'messages': []}
            clients[client_id]['rounds'].append(current_round)
            clients[client_id]['messages'].append((log_level, message))
        elif memory_match:
            memory_usage.append(float(memory_match.group(1)))
    return rounds, clients, memory_usage
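# Streamlit rendering of the parsed log: overall memory-usage curve plus
# per-client message dumps and loss/accuracy plots.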
def plot_metrics(rounds, clients, memory_usage):
    st.write("## Metrics Overview")
    st.write("### Memory Usage")
    fig = plt.figure()
    plt.plot(range(len(memory_usage)), memory_usage, label='Memory Usage (GB)')
    plt.xlabel('Step')
    plt.ylabel('Memory Usage (GB)')
    plt.legend()
    st.pyplot(fig)
    for client_id, data in clients.items():
        st.write(f"### Client {client_id} Metrics")
        info_messages = [msg for level, msg in data['messages'] if level == 'INFO']
        debug_messages = [msg for level, msg in data['messages'] if level == 'DEBUG']
        st.write("#### INFO Messages")
        for msg in info_messages:
            st.write(msg)
        st.write("#### DEBUG Messages")
        for msg in debug_messages:
            st.write(msg)
        # Loss and accuracy values are scraped from DEBUG messages when present.
        losses = [float(re.search(r'loss=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'loss=' in msg]
        accuracies = [float(re.search(r'accuracy=([\d\.]+)', msg).group(1)) for msg in debug_messages if 'accuracy=' in msg]
        # Plot against the sequence index: not every logged message carries a
        # metric, so data['rounds'] can be longer than these lists.
        if losses:
            fig = plt.figure()
            plt.plot(range(1, len(losses) + 1), losses, label='Loss')
            plt.xlabel('Round')
            plt.ylabel('Loss')
            plt.legend()
            st.pyplot(fig)
        if accuracies:
            fig = plt.figure()
            plt.plot(range(1, len(accuracies) + 1), accuracies, label='Accuracy')
            plt.xlabel('Round')
            plt.ylabel('Accuracy')
            plt.legend()
            st.pyplot(fig)
def read_log_file2():
    with open("./log.txt", "r") as file:
        return file.read()
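# Streamlit entry point: pick dataset/model, build one dataloader pair and one
# CustomClient per simulated client, then drive Flower simulation rounds.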
def main():
    st.write("## Federated Learning with Dynamic Models and Datasets for Mobile Devices")
    logs = read_log_file2()
    pattern = re.compile(r"memory|loss|accuracy|round|client", re.IGNORECASE)
    filtered_logs = [line for line in logs.splitlines() if pattern.search(line)]
    # st.markdown expects a string, so join the filtered lines first.
    st.markdown("\n\n".join(filtered_logs))
    st.download_button(
        label="Download Logs",
        data="\n".join(filtered_logs),
        file_name="log.txt",
        mime="text/plain"
    )
    dataset_name = st.selectbox("Dataset", ["imdb", "amazon_polarity", "ag_news"])
    model_name = st.selectbox("Model", ["bert-base-uncased", "facebook/hubert-base-ls960", "distilbert-base-uncased", "google/byt5-small"])
    NUM_CLIENTS = st.slider("Number of Clients", min_value=1, max_value=10, value=2)
    NUM_ROUNDS = st.slider("Number of Rounds", min_value=1, max_value=10, value=3)
    use_utf8 = st.checkbox("Train on Byte UTF-8 Dataset", value=False)
    train_datasets, test_datasets, data_collator, raw_datasets = load_data(dataset_name, num_clients=NUM_CLIENTS, use_utf8=use_utf8, model_name=model_name)
    trainloaders = []
    testloaders = []
    clients = []
    for i in range(NUM_CLIENTS):
        st.write(f"### Client {i+1} Datasets")
        train_df = pd.DataFrame(train_datasets[i])
        test_df = pd.DataFrame(test_datasets[i])
        st.write("#### Train Dataset (Words)")
        st.dataframe(raw_datasets["train"].select(random.sample(range(len(raw_datasets["train"])), 20)))
        st.write("#### Train Dataset (Tokens)")
        edited_train_df = st.data_editor(train_df, key=f"train_{i}")
        st.write("#### Test Dataset (Words)")
        st.dataframe(raw_datasets["test"].select(random.sample(range(len(raw_datasets["test"])), 20)))
        st.write("#### Test Dataset (Tokens)")
        edited_test_df = st.data_editor(test_df, key=f"test_{i}")
        edited_train_dataset = Dataset.from_pandas(edited_train_df)
        edited_test_dataset = Dataset.from_pandas(edited_test_df)
        trainloader = DataLoader(edited_train_dataset, shuffle=True, batch_size=32, collate_fn=data_collator)
        testloader = DataLoader(edited_test_dataset, batch_size=32, collate_fn=data_collator)
        trainloaders.append(trainloader)
        testloaders.append(testloader)
        net = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2).to(DEVICE)
        client = CustomClient(net, trainloader, testloader, client_id=i+1)
        clients.append(client)
if st.button("Start Training"): | |
def client_fn(cid): | |
return clients[int(cid)].to_client() | |
def weighted_average(metrics): | |
accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics] | |
losses = [num_examples * m["loss"] for num_examples, m in metrics] | |
examples = [num_examples for num_examples, _ in metrics] | |
return {"accuracy": sum(accuracies) / sum(examples), "loss": sum(losses) / sum(examples)} | |
strategy = fl.server.strategy.FedAvg( | |
fraction_fit=1.0, | |
fraction_evaluate=1.0, | |
evaluate_metrics_aggregation_fn=weighted_average, | |
) | |
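        # Each UI round re-renders the filtered log, re-parses aggregate
        # metrics from it, and then runs a single Flower round in simulation.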
        for round_num in range(NUM_ROUNDS):
            st.write(f"### Round {round_num + 1} ✅")
            logs = read_log_file2()
            filtered_log_list = [line for line in logs.splitlines() if pattern.search(line)]
            filtered_logs = "\n".join(filtered_log_list)
            st.markdown(filtered_logs)
            accuracy_pattern = re.compile(r"'accuracy': \{(\d+), ([\d.]+)\}")
            loss_pattern = re.compile(r"'loss': \{(\d+), ([\d.]+)\}")
            accuracy_matches = accuracy_pattern.findall(filtered_logs)
            loss_matches = loss_pattern.findall(filtered_logs)
            rounds = [int(match[0]) for match in accuracy_matches]
            accuracies = [float(match[1]) for match in accuracy_matches]
            losses = [float(match[1]) for match in loss_matches]
            accuracy_fig = go.Figure()
            accuracy_fig.add_trace(go.Scatter(x=rounds, y=accuracies, mode='lines+markers', name='Accuracy'))
            accuracy_fig.update_layout(title='Accuracy over Rounds', xaxis_title='Round', yaxis_title='Accuracy')
            loss_fig = go.Figure()
            loss_fig.add_trace(go.Scatter(x=rounds, y=losses, mode='lines+markers', name='Loss'))
            loss_fig.update_layout(title='Loss over Rounds', xaxis_title='Round', yaxis_title='Loss')
            st.plotly_chart(accuracy_fig)
            st.plotly_chart(loss_fig)
            data = {
                'Round': rounds,
                'Accuracy': accuracies,
                'Loss': losses
            }
            df = pd.DataFrame(data)
            st.write("## Training Metrics")
            st.table(df)
            plot_placeholders = [st.empty() for _ in range(NUM_CLIENTS)]
            fl.simulation.start_simulation(
                client_fn=client_fn,
                num_clients=NUM_CLIENTS,
                config=fl.server.ServerConfig(num_rounds=1),
                strategy=strategy,
                client_resources={"num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)},
                ray_init_args={"log_to_driver": True, "num_cpus": 1, "num_gpus": (1 if torch.cuda.is_available() else 0)}
            )
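            # Note: with the Ray-based simulation each client runs in its own
            # process, so the driver-side CustomClient objects may not see the
            # metrics recorded inside fit(); the log file is the reliable record.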
            for i, client in enumerate(clients):
                client.plot_metrics(round_num + 1, plot_placeholders[i])
            st.write(" ")
        st.success("Training completed successfully!")
st.write("## Final Client Metrics") | |
for client in clients: | |
st.write(f"### Client {client.client_id}") | |
if client.losses and client.accuracies: | |
st.write(f"Final Loss: {client.losses[-1]:.4f}") | |
st.write(f"Final Accuracy: {client.accuracies[-1]:.4f}") | |
client.plot_metrics(NUM_ROUNDS, st.empty()) | |
else: | |
st.write("No metrics available.") | |
st.write(" ") | |
st.write("## Training Log") | |
st.write(read_log_file2()) | |
st.write("## Training Log Analysis") | |
log_lines = read_log_file() | |
rounds, clients, memory_usage = parse_log(log_lines) | |
plot_metrics(rounds, clients, memory_usage) | |
else: | |
st.write("Click the 'Start Training' button to start the training process.") | |
if __name__ == "__main__":
    main()