Horizontal Federated Learning Task
This is an example of a horizontal federated learning task written using the 2PM Node Framework.
The data consists of the MNIST dataset distributed across multiple nodes, with each node holding a portion of the samples. The task involves training a convolutional neural network model for handwritten digit recognition.
Import Necessary Packages
The computation logic is written in PyTorch, so we first import numpy and torch, along with a few auxiliary tools. We then import the 2PM Framework components from the 2pm-task package: 2PMNode, which calls the node API to send horizontal federated learning tasks, and FaultTolerantFedAvg, which configures the secure aggregation strategy:
from typing import Any, Dict, Iterable, List, Tuple
import logging
import numpy as np
import torch
from PIL.Image import Image
from torch.utils.data import DataLoader, Dataset
from 2pm.2pm_node import 2PMNode
from 2pm.task.learning import HorizontalLearning, FaultTolerantFedAvg
import 2pm.dataset
Define the Neural Network Model
Next, let's define the neural network model, which follows the standard PyTorch structure for model definitions:
class LeNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5, padding=2)
        self.pool1 = torch.nn.AvgPool2d(2, stride=2)
        self.conv2 = torch.nn.Conv2d(16, 16, 5)
        self.pool2 = torch.nn.AvgPool2d(2, stride=2)
        self.dense1 = torch.nn.Linear(400, 100)
        self.dense2 = torch.nn.Linear(100, 10)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool2(x)
        x = x.view(-1, 400)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        return x
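As a quick standalone sanity check (not part of the original example), we can pass a dummy MNIST-sized batch through the network to confirm that the flattened feature size is 400 (16 channels x 5 x 5 after the second pooling layer) and that the output contains one logit per digit class:

model = LeNet()
dummy = torch.randn(8, 1, 28, 28)  # a batch of 8 fake 28x28 grayscale images
print(model(dummy).shape)  # torch.Size([8, 10])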
Define Privacy Computing Tasks
Now we can define our horizontal federated learning task, which trains the neural network model defined above across multiple nodes.
When defining the horizontal federated learning task, there are several components that users need to specify themselves:
Task Configuration: Configure the task in the super().__init__() method.
Dataset: Define the dataset required for the task in the dataset method.
Training Set DataLoader: Define the DataLoader for the training set in the make_train_dataloader method.
Validation Set DataLoader: Define the DataLoader for the validation set in the make_validate_dataloader method.
Model Training: Define the entire training process of the model in the train method, including both forward and backward propagation. The input for this method is the training set's DataLoader.
Model Validation: Define the entire validation process of the model in the validate method. The input is the validation set's DataLoader, and the output is a dictionary where the keys are the names of the computed metrics and the values are the corresponding metric values.
Model Parameters: Define all model parameters that require training and updating in the state_dict method.
def transform_data(data: List[Tuple[Image, str]]):
    """
    As a DataLoader's collate_fn, this function preprocesses the input.
    It reshapes and normalizes MNIST images and converts them into torch.Tensor format for processing.
    """
    xs, ys = [], []
    for x, y in data:
        xs.append(np.array(x).reshape((1, 28, 28)))
        ys.append(int(y))
    imgs = torch.tensor(xs)
    labels = torch.tensor(ys)
    imgs = imgs / 255 - 0.5
    return imgs, labels
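To see what the collate function produces, here is a small standalone check with dummy PIL images (illustrative only; in the actual task, the DataLoader calls transform_data on real MNIST samples):

from PIL import Image as PILImage

# Four fake all-black 28x28 grayscale images, each labeled "3".
fake_batch = [(PILImage.fromarray(np.zeros((28, 28), dtype=np.uint8)), "3") for _ in range(4)]
imgs, labels = transform_data(fake_batch)
print(imgs.shape, labels.shape)  # torch.Size([4, 1, 28, 28]) torch.Size([4])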
class Example(HorizontalLearning):
    def __init__(self) -> None:
        super().__init__(
            name="example",  # Task name
            max_rounds=2,  # Total number of training rounds; each round is one aggregation update
            validate_interval=1,  # Interval of validation rounds; 1 means validate after every round
            validate_frac=0.1,  # Proportion of the validation set, within the range (0, 1)
            strategy=FaultTolerantFedAvg(  # Secure aggregation strategy; FedAvg and FaultTolerantFedAvg are currently available under 2pm.task.learning
                min_clients=2,  # Minimum required number of clients, at least 2
                max_clients=3,  # Maximum supported number of clients, must be greater than or equal to min_clients
                merge_epoch=1,  # Aggregation interval; merge_epoch is the number of epochs between weight aggregations
                wait_timeout=30,  # Waiting timeout, controls the timeout of one computation round
                connection_timeout=10,  # Connection timeout, controls the timeout of each phase in the process
            )
        )
        self.model = LeNet()
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=1e-3,
            nesterov=True,
        )

    def dataset(self) -> 2pm.dataset.Dataset:
        """
        Define the dataset required for the task.
        return: A 2pm.dataset.Dataset
        """
        return 2pm.dataset.Dataset(dataset="mnist")

    def make_train_dataloader(self, dataset: Dataset) -> DataLoader:
        """
        Define the DataLoader for the training set, which can transform and preprocess the dataset.
        dataset: The training set Dataset
        return: DataLoader for the training set
        """
        return DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=transform_data)  # type: ignore

    def make_validate_dataloader(self, dataset: Dataset) -> DataLoader:
        """
        Define the DataLoader for the validation set, which can transform and preprocess the dataset.
        dataset: The validation set Dataset
        return: DataLoader for the validation set
        """
        return DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False, collate_fn=transform_data)  # type: ignore

    def train(self, dataloader: Iterable):
        """
        Training step
        dataloader: The dataloader corresponding to the training dataset
        return: None
        """
        for batch in dataloader:
            x, y = batch
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def validate(self, dataloader: Iterable) -> Dict[str, Any]:
        """
        Validation step, outputs validation metrics
        dataloader: The dataloader corresponding to the validation dataset
        return: Dictionary with metric names as keys (str) and their corresponding values (float)
        """
        total_loss = 0
        count = 0
        ys = []
        y_s = []
        for batch in dataloader:
            x, y = batch
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            total_loss += loss.item()
            count += 1
            y_ = torch.argmax(y_pred, dim=1)
            y_s.extend(y_.tolist())
            ys.extend(y.tolist())
        avg_loss = total_loss / count
        correct = len([1 for i in range(len(ys)) if ys[i] == y_s[i]])
        accuracy = correct / len(ys)  # fraction of correctly classified samples
        return {"loss": avg_loss, "accuracy": accuracy}

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """
        Model parameters that need training and updating.
        Only the parameters returned by state_dict will be updated and saved during aggregation and result saving.
        return: Dict mapping parameter names to torch.Tensor
        """
        return self.model.state_dict()
To detail the process of defining a horizontal federated learning task: we need to create a class that inherits from HorizontalLearning. HorizontalLearning is an abstract base class that declares a series of abstract methods tailored to horizontal federated learning tasks, which we are required to implement.
First, the constructor __init__ is where the task's foundational configuration is set. This includes the task name (name), the total number of training rounds (max_rounds), the validation frequency (validate_interval, meaning validation occurs after every validate_interval rounds), the proportion of the validation dataset (validate_frac), and the secure aggregation strategy (strategy). When implementing the constructor, you must call the superclass constructor via super().__init__().
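For instance, if fault tolerance is not required, the plain FedAvg strategy mentioned in the code comments above can presumably be used instead. A minimal sketch, assuming FedAvg accepts the same parameters as FaultTolerantFedAvg (not confirmed by this article):

from 2pm.task.learning import FedAvg

# Assumed: FedAvg takes the same configuration parameters as FaultTolerantFedAvg.
strategy = FedAvg(
    min_clients=2,
    max_clients=3,
    merge_epoch=1,
    wait_timeout=30,
    connection_timeout=10,
)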
Following this, the dataset method defines the dataset required for the task. It returns an instance of 2pm.dataset.Dataset, whose dataset parameter specifies the name of the required dataset. For details on the dataset format, please refer to this article. Currently, a horizontal federated learning task supports only one dataset, so the dataset method can only return a single instance of 2pm.dataset.Dataset.
Next are two methods for defining DataLoaders: make_train_dataloader for the training set and make_validate_dataloader for the validation set. The logic required in both is similar, so they are discussed together. Each method takes an instance of torch.utils.data.Dataset, corresponding to either the training or validation set, and can transform and preprocess the dataset as needed. Each method returns an instance of torch.utils.data.DataLoader, which serves as the input for the model training method.
The train method implements the entire training process of the model, including forward propagation, backward propagation, and parameter updates. Its input is the DataLoader defined in make_train_dataloader.
The validate method calculates the model's performance metrics on the validation set. Its input is the DataLoader defined in make_validate_dataloader, and its output is a dictionary where the keys are the names of the computed metrics and the values are the corresponding metric values.
Finally, the state_dict method defines all the model parameters that require training and updating. It returns a dictionary mapping parameter names to the corresponding parameter tensors.
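For the LeNet model above, for example, the returned dictionary contains one entry per trainable parameter tensor (the pooling layers have none). This quick standalone check prints the parameter names:

print(list(LeNet().state_dict().keys()))
# ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias',
#  'dense1.weight', 'dense1.bias', 'dense2.weight', 'dense2.bias']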
Specifying the API for the 2PM Node to Execute Tasks
Once the task is defined, we can begin preparations to execute it on the 2PM Node.
2PM task management interacts directly with the 2PM Node API to dispatch tasks to the 2PM Node for execution; simply specify the node's API address when initiating the task.
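The exact submission code depends on the version of the 2pm-task package; the following is only a minimal sketch of the general pattern. It assumes that 2PMNode (imported earlier) accepts the node's API address and exposes a create_task method for dispatching a task instance; both are assumptions here, not confirmed by this article:

if __name__ == "__main__":
    task = Example()  # the task defined above
    # Assumed constructor: takes the 2PM Node's API address (replace with your node's address).
    node = 2PMNode("http://127.0.0.1:6700")
    # Assumed method: dispatches the task to the 2PM Node for execution.
    node.create_task(task)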