Horizontal Federated Learning Task
This is an example of a horizontal federated learning task written using the 2PM Node Framework.
The data consists of the MNIST dataset distributed across multiple nodes, with each node holding a portion of the samples. The task involves training a convolutional neural network model for handwritten digit recognition.
Import Necessary Packages:
The computation logic is written in PyTorch, so we first import numpy and torch, along with some auxiliary tools. We then import the 2PM Framework components from the 2pm-task package: 2PMNode, which calls the node API to submit horizontal federated learning tasks, and FaultTolerantFedAvg, which configures the secure aggregation strategy.
Define the Neural Network Model
Next, we define the neural network model. It follows the usual structure of a neural network definition in PyTorch.
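As an illustration of such a definition, here is a minimal sketch of a small convolutional network for 28x28 grayscale MNIST images. The class name DigitCNN and the layer sizes are assumptions chosen for this example, not the model from the original task.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DigitCNN(nn.Module):
    """Hypothetical small CNN for 28x28 grayscale digit images."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # padding=2 with a 5x5 kernel keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, padding=2)
        # Two 2x2 poolings reduce 28x28 to 7x7.
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # (N, 16, 14, 14)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # (N, 32, 7, 7)
        x = torch.flatten(x, start_dim=1)           # (N, 32 * 7 * 7)
        x = F.relu(self.fc1(x))
        return self.fc2(x)                          # raw logits, one per class


# Usage: a batch of four blank images yields four rows of ten logits.
logits = DigitCNN()(torch.zeros(4, 1, 28, 28))
```

The network returns raw logits rather than probabilities, which is the form expected by torch.nn.CrossEntropyLoss.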
Define Privacy Computing Tasks
Now we can define the horizontal federated learning task, which trains the neural network model defined above across multiple nodes.
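For intuition about what happens across nodes in each round, plain federated averaging can be sketched as below. This is an assumption-laden illustration: the function name federated_average is hypothetical, and the sketch does not model the secure aggregation that FaultTolerantFedAvg is described as configuring. It only shows the weighted average of per-node parameters.

```python
from typing import Dict, List

import torch


def federated_average(
    node_states: List[Dict[str, torch.Tensor]],
    node_sample_counts: List[int],
) -> Dict[str, torch.Tensor]:
    """Average per-node parameters, weighting each node by its sample count."""
    if not node_states or len(node_states) != len(node_sample_counts):
        raise ValueError("need at least one node and one sample count per node")
    total = float(sum(node_sample_counts))
    averaged = {}
    for key in node_states[0]:
        # Each node contributes in proportion to the data it trained on.
        averaged[key] = sum(
            state[key].float() * (count / total)
            for state, count in zip(node_states, node_sample_counts)
        )
    return averaged


# Usage: node B holds three times as many samples as node A.
node_a = {"w": torch.tensor([0.0, 0.0])}
node_b = {"w": torch.tensor([4.0, 8.0])}
merged = federated_average([node_a, node_b], [1, 3])
```

Weighting by sample count means a node with more data pulls the shared model further toward its local update.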
When defining the horizontal federated learning task, users need to specify several components themselves:
Task Configuration: Configure the task in the constructor by calling super().__init__().
Dataset: Define the dataset required for the task in the dataset method.
Training Set DataLoader: Define the DataLoader for the training set in the make_train_dataloader method.
Validation Set DataLoader: Define the DataLoader for the validation set in the make_validate_dataloader method.
Model Training: In the train method, define the entire training process of the model, including both forward and backward propagation. The input for this method is the training set's DataLoader.
Model Validation: In the validate method, define the entire validation process of the model. The input for this method is the validation set's DataLoader, and the output is a dictionary where the keys are the names of computed metrics and the values are the corresponding metric values.
Model Parameters: Define all model parameters that require training and updating in the state_dict method.
To define a horizontal federated learning task, we create a class that inherits from HorizontalLearning. The HorizontalLearning class is an abstract base class that declares a set of abstract methods for horizontal federated learning tasks, which we are required to implement.
First, the constructor, __init__, configures the basic settings of the task: the task name (name), the total number of training rounds (max_rounds), the validation frequency (validate_interval, the number of rounds between validations), the proportion of the validation dataset (validate_frac), and the strategy for secure aggregation (strategy). The constructor must call the superclass constructor using super().__init__().
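The constructor pattern can be sketched as follows. Because the framework's own classes are not reproduced here, HorizontalLearningStub is a local stand-in written only for this example, not the real HorizontalLearning class. The parameter names come from the description above. The concrete values, the placeholder strategy string, and the validation_rounds helper are assumptions.

```python
from typing import Any, List


class HorizontalLearningStub:
    """Local stand-in for the framework base class (NOT the real 2PM class)."""

    def __init__(self, name: str, max_rounds: int, validate_interval: int,
                 validate_frac: float, strategy: Any):
        if max_rounds < 1 or validate_interval < 1:
            raise ValueError("max_rounds and validate_interval must be >= 1")
        if not 0.0 <= validate_frac < 1.0:
            raise ValueError("validate_frac must be in [0, 1)")
        self.name = name
        self.max_rounds = max_rounds
        self.validate_interval = validate_interval
        self.validate_frac = validate_frac
        self.strategy = strategy

    def validation_rounds(self) -> List[int]:
        # Assumed semantics: validate after every validate_interval rounds.
        return [r for r in range(1, self.max_rounds + 1)
                if r % self.validate_interval == 0]


class MnistTask(HorizontalLearningStub):
    def __init__(self) -> None:
        # The superclass constructor must be called with the task settings.
        super().__init__(
            name="mnist-example",            # illustrative task name
            max_rounds=10,                   # illustrative value
            validate_interval=2,             # illustrative value
            validate_frac=0.1,               # illustrative value
            strategy="strategy-placeholder"  # stands in for a strategy object
        )


task = MnistTask()
```

In a real task the strategy argument would be a strategy object such as FaultTolerantFedAvg. Its constructor arguments are not covered here, so the sketch uses a placeholder string.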
Next, the dataset method defines the dataset required for the task. This method returns an instance of 2pm.dataset.Dataset, where the parameter dataset specifies the name of the dataset required. For details on the dataset format, please refer to this article. Currently, horizontal federated learning tasks support only one dataset, so the dataset method can return only a single instance of 2pm.dataset.Dataset.
Next are two methods for defining DataLoaders: make_train_dataloader for the training set and make_validate_dataloader for the validation set. The logic required in both methods is similar, so they are discussed together. Each method takes an instance of torch.utils.data.Dataset, corresponding to either the training or validation set, and can transform and preprocess the dataset as needed. Each method returns an instance of torch.utils.data.DataLoader, which serves as the input for the corresponding training or validation method.
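A minimal sketch of the two DataLoader builders is shown below, written as free functions rather than task methods so that it runs on its own. The Normalized wrapper, the batch size, and the random stand-in data are assumptions. The sketch also assumes that each dataset item is an (image tensor, label) pair. The mean and standard deviation are the values commonly used for MNIST.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class Normalized(Dataset):
    """Hypothetical wrapper that normalizes images from an (image, label) dataset."""

    def __init__(self, base: Dataset, mean: float = 0.1307, std: float = 0.3081):
        self.base, self.mean, self.std = base, mean, std

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int):
        image, label = self.base[index]
        return (image - self.mean) / self.std, label


def make_train_dataloader(dataset: Dataset) -> DataLoader:
    # Shuffle for training and drop the last incomplete batch.
    return DataLoader(Normalized(dataset), batch_size=64, shuffle=True, drop_last=True)


def make_validate_dataloader(dataset: Dataset) -> DataLoader:
    # Keep order and keep every sample for validation.
    return DataLoader(Normalized(dataset), batch_size=64, shuffle=False, drop_last=False)


# Usage with random stand-in data shaped like MNIST.
toy = TensorDataset(torch.rand(200, 1, 28, 28), torch.randint(0, 10, (200,)))
train_loader = make_train_dataloader(toy)
val_loader = make_validate_dataloader(toy)
first_images, first_labels = next(iter(val_loader))
```

With 200 samples and a batch size of 64, the training loader yields three full batches, while the validation loader yields four batches, the last one partial.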
The train method implements the entire training process of the model, including forward propagation, backward propagation, and parameter updates. The input for this method is the DataLoader defined in make_train_dataloader.
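The training pass described here can be sketched as a standalone function. In the task class the model and optimizer would normally be attributes set in the constructor. Passing them as arguments, the choice of cross-entropy loss and SGD, and the returned mean loss are assumptions made for this example.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train(model: nn.Module, optimizer: torch.optim.Optimizer,
          dataloader: DataLoader) -> float:
    """Run one pass over the DataLoader and return the mean batch loss."""
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    total_loss, batches = 0.0, 0
    for images, labels in dataloader:
        optimizer.zero_grad()                  # clear gradients from the last step
        loss = loss_fn(model(images), labels)  # forward propagation
        loss.backward()                        # backward propagation
        optimizer.step()                       # parameter update
        total_loss += loss.item()
        batches += 1
    return total_loss / max(batches, 1)


# Usage with a tiny linear model and random stand-in data.
torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = TensorDataset(torch.rand(32, 1, 28, 28), torch.randint(0, 10, (32,)))
weights_before = model[1].weight.detach().clone()
mean_loss = train(model, optimizer, DataLoader(data, batch_size=8))
```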
Additionally, the validate method calculates various performance metrics of the model on the validation set. The input for this method is the DataLoader defined in make_validate_dataloader. The output is a dictionary where the keys are the names of the computed metrics and the values are the corresponding metric values.
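A validation function that returns such a metrics dictionary might look like the following sketch. The metric names "loss" and "accuracy", and passing the model as an argument instead of reading it from the task object, are assumptions made for this example.

```python
from typing import Dict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def validate(model: nn.Module, dataloader: DataLoader) -> Dict[str, float]:
    """Compute mean loss and accuracy over the validation DataLoader."""
    model.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="sum")  # sum so we can average per sample
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for images, labels in dataloader:
            logits = model(images)
            total_loss += loss_fn(logits, labels).item()
            correct += (logits.argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    if count == 0:
        raise ValueError("validation DataLoader is empty")
    return {"loss": total_loss / count, "accuracy": correct / count}


# Usage with an untrained linear model and random stand-in data.
torch.manual_seed(0)
eval_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
eval_data = TensorDataset(torch.rand(40, 1, 28, 28), torch.randint(0, 10, (40,)))
metrics = validate(eval_model, DataLoader(eval_data, batch_size=16))
```

Summing the loss per batch and dividing by the sample count keeps the mean correct even when the last batch is smaller than the others.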
Finally, the state_dict method defines all the model parameters that require training and updating. The return value of this method is a list of these model parameters.
Specifying the API for the 2PM Node to Execute Tasks
Once the task is defined, we can begin preparations to execute it on the 2PM Node.
2PM Task Management interacts directly with the 2PM Node API to dispatch tasks to the 2PM Node for execution. Simply specify the API address of the 2PM Node when starting the task.