Horizontal Federated Learning Task

This is an example of a horizontal federated learning task written using the 2PM Node Framework.

The data consists of the MNIST dataset distributed across multiple nodes, with each node holding a portion of the samples. The task involves training a convolutional neural network model for handwritten digit recognition.

Import Necessary Packages

The computation logic is written in PyTorch, so we first import numpy and torch along with a few auxiliary tools. We then import the 2PM Framework components from the 2pm-task package: 2PMNode, which calls the node API to submit tasks; HorizontalLearning, the base class for horizontal federated learning tasks; and FaultTolerantFedAvg, which configures the secure aggregation strategy:

from typing import Any, Dict, Iterable, List, Tuple

import logging
import numpy as np
import torch
from PIL.Image import Image
from torch.utils.data import DataLoader, Dataset

from 2pm.2pm_node import 2PMNode
from 2pm.task.learning import HorizontalLearning, FaultTolerantFedAvg
import 2pm.dataset

Define the Neural Network Model

Next, let's define the neural network model. It follows the standard PyTorch pattern: a torch.nn.Module subclass that creates its layers in __init__ and applies them in forward:

class LeNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5, padding=2)
        self.pool1 = torch.nn.AvgPool2d(2, stride=2)
        self.conv2 = torch.nn.Conv2d(16, 16, 5)
        self.pool2 = torch.nn.AvgPool2d(2, stride=2)
        self.dense1 = torch.nn.Linear(400, 100)
        self.dense2 = torch.nn.Linear(100, 10)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool2(x)

        x = x.view(-1, 400)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        return x
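
The 400 input features of dense1 follow from the layer shapes, assuming the standard 28x28 single-channel MNIST input. The sketch below checks the arithmetic with the usual output-size formula floor((n + 2p - k) / s) + 1; the helper out_size is a hypothetical name introduced for illustration and is not part of the framework:

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution or pooling layer."""
    return (n + 2 * padding - kernel) // stride + 1


size = 28                                   # assumed MNIST input: 28x28 pixels
size = out_size(size, kernel=5, padding=2)  # conv1: 28 -> 28
size = out_size(size, kernel=2, stride=2)   # pool1: 28 -> 14
size = out_size(size, kernel=5)             # conv2: 14 -> 10
size = out_size(size, kernel=2, stride=2)   # pool2: 10 -> 5

flat_features = 16 * size * size            # conv2 produces 16 channels
print(flat_features)  # 400
```

This is why forward reshapes the tensor with x.view(-1, 400) before the first linear layer.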

Define Privacy Computing Tasks

Now we can define the horizontal federated learning task, which trains the neural network model defined above across multiple nodes.

A horizontal federated learning task has several components that the user must specify:

  • Task Configuration: Configure the task by passing its settings to super().__init__().

  • Dataset: Define the required dataset for the task in the dataset method.

  • Training Set DataLoader: Define the DataLoader for the training set in the make_train_dataloader method.

  • Validation Set DataLoader: Define the DataLoader for the validation set in the make_validate_dataloader method.

  • Model Training: Define the entire training process of the model, including both forward and backward propagation, in the train method. The input for this method is the training set's DataLoader.

  • Model Validation: Define the entire validation process of the model in the validate method. The input for this method is the validation set's DataLoader, and the output is a dictionary where the keys are the names of computed metrics and the values are the corresponding metric values.

  • Model Parameters: We need to define all model parameters that require training and updating in the state_dict method.

To define a horizontal federated learning task, we create a class that inherits from HorizontalLearning. HorizontalLearning is an abstract base class that declares the methods a horizontal federated learning task must provide, and our subclass is required to implement each of them.
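
As a rough illustration of what "abstract base class" means here, the sketch below uses Python's abc module with a hypothetical stand-in class, HorizontalLearningSketch. It is not the framework's HorizontalLearning, and whether the framework enforces its interface this way is an assumption; the method names are the ones listed above.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class HorizontalLearningSketch(ABC):
    """Hypothetical stand-in for HorizontalLearning; not the framework class."""

    def __init__(self, name: str, max_rounds: int):
        self.name = name
        self.max_rounds = max_rounds

    @abstractmethod
    def dataset(self) -> Any: ...

    @abstractmethod
    def make_train_dataloader(self, dataset: Any) -> Any: ...

    @abstractmethod
    def make_validate_dataloader(self, dataset: Any) -> Any: ...

    @abstractmethod
    def train(self, dataloader: Any) -> None: ...

    @abstractmethod
    def validate(self, dataloader: Any) -> Dict[str, float]: ...

    @abstractmethod
    def state_dict(self) -> Any: ...


class IncompleteTask(HorizontalLearningSketch):
    """Implements only one of the six required methods."""

    def dataset(self) -> Any:
        return "mnist"


try:
    IncompleteTask(name="example", max_rounds=2)
    instantiated = True
except TypeError:
    # Python refuses to instantiate a class with unimplemented abstract methods.
    instantiated = False
```

With this pattern, a forgotten method surfaces as an error when the task object is created rather than partway through a training round.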

First, the constructor, __init__, configures the basic settings of the task: the task name (name), the total number of training rounds (max_rounds), the validation frequency (validate_interval, meaning validation runs once every validate_interval rounds), the proportion of the data used for validation (validate_frac), and the secure aggregation strategy (strategy). The constructor must call the superclass constructor via super().__init__().
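
To make these settings concrete, here is a small sketch of how they might interact on one node. The values, the 1-based round numbering, and the local sample count num_samples are all assumptions chosen for illustration; the framework's exact scheduling and splitting logic may differ.

```python
# Hypothetical configuration values, for illustration only.
max_rounds = 6
validate_interval = 2
validate_frac = 0.1
num_samples = 1000  # assumed number of samples held by one node

# Rounds after which validation would run, assuming 1-based round numbers.
validation_rounds = [
    r for r in range(1, max_rounds + 1) if r % validate_interval == 0
]

# Split the local samples into validation and training portions.
num_validate = round(num_samples * validate_frac)
num_train = num_samples - num_validate

print(validation_rounds)        # [2, 4, 6]
print(num_train, num_validate)  # 900 100
```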

Following this, the dataset method defines the dataset required for the task. This method returns an instance of 2pm.dataset.Dataset, where the parameter dataset specifies the name of the dataset required. For details on the dataset format, please refer to this article. Currently, horizontal federated learning tasks support only one dataset, so the dataset method can only return a single instance of 2pm.dataset.Dataset.

Next are two methods for defining DataLoaders: make_train_dataloader for the training set and make_validate_dataloader for the validation set. Their logic is similar, so they are discussed together. Each method takes an instance of torch.utils.data.Dataset, corresponding to either the training or validation set, and can transform and preprocess the dataset as needed. Each method returns an instance of torch.utils.data.DataLoader, which serves as the input for the train or validate method, respectively.
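
A minimal sketch of such a method is shown below, written as a free function rather than a method on the task class. FakeDigits is a hypothetical stand-in for the dataset a node would supply, and the item format (a 28x28 image plus a label string), the batch size, and the scaling to [-0.5, 0.5] are assumptions, not the framework's documented behavior.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class FakeDigits(Dataset):
    """Hypothetical stand-in dataset: (28x28 uint8 image, label string) pairs."""

    def __init__(self, n: int):
        rng = np.random.default_rng(0)
        self.images = rng.integers(0, 256, size=(n, 28, 28), dtype=np.uint8)
        self.labels = [str(i % 10) for i in range(n)]

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int):
        return self.images[idx], self.labels[idx]


def collate(batch):
    """Stack a batch into tensors: images (N, 1, 28, 28) in [-0.5, 0.5], int labels."""
    xs = np.stack([np.asarray(x).reshape(1, 28, 28) for x, _ in batch])
    ys = [int(y) for _, y in batch]
    images = torch.from_numpy(xs).float() / 255 - 0.5
    labels = torch.tensor(ys, dtype=torch.long)
    return images, labels


def make_train_dataloader(dataset: Dataset) -> DataLoader:
    # Shuffle and drop the last partial batch for training.
    return DataLoader(
        dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=collate
    )


loader = make_train_dataloader(FakeDigits(128))
images, labels = next(iter(loader))
```

A validation DataLoader would typically use the same collate function with shuffle=False and drop_last=False, so every validation sample is evaluated exactly once.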

The train method is where the entire training process of the model is implemented, including forward propagation, backward propagation, and parameter updates. The input for this method is the DataLoader defined in make_train_dataloader.
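
The three steps named above can be sketched as an ordinary PyTorch loop. This is a standalone function under assumed names; in a real task, train would be a method that reads the model, optimizer, and loss function from attributes set in the constructor. The tiny linear model and random batches are stand-ins so the sketch runs on its own.

```python
import torch


def train(model, optimizer, loss_fn, dataloader) -> None:
    """Run one pass over the dataloader, updating the model in place."""
    model.train()
    for x, y in dataloader:
        optimizer.zero_grad()
        y_pred = model(x)          # forward propagation
        loss = loss_fn(y_pred, y)
        loss.backward()            # backward propagation
        optimizer.step()           # parameter update


# Stand-in model and data (hypothetical, for illustration only).
torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(5)]

before = model.weight.detach().clone()
train(model, optimizer, loss_fn, batches)
changed = not torch.equal(before, model.weight)
```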

Additionally, the validate method calculates various performance metrics of the model on the validation set. The input for this method is the DataLoader defined in make_validate_dataloader. The output is a dictionary where the keys are the names of the computed metrics and the values are the corresponding metric values.
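
A hedged sketch of such a method follows, again as a standalone function. The metric names "loss" and "accuracy" are illustrative choices, not names required by the framework. torch.nn.Identity serves as a stand-in model so that the inputs act directly as logits and the result is easy to verify by hand.

```python
from typing import Dict

import torch


def validate(model, loss_fn, dataloader) -> Dict[str, float]:
    """Return the average loss and accuracy over the dataloader."""
    model.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for x, y in dataloader:
            logits = model(x)
            total_loss += loss_fn(logits, y).item() * len(y)
            correct += (logits.argmax(dim=1) == y).sum().item()
            count += len(y)
    return {"loss": total_loss / count, "accuracy": correct / count}


# Stand-in model: Identity passes the inputs through as logits.
model = torch.nn.Identity()
loss_fn = torch.nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
targets = torch.tensor([0, 2])  # first prediction is right, second is wrong

metrics = validate(model, loss_fn, [(logits, targets)])
print(metrics["accuracy"])  # 0.5
```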

Finally, the state_dict method is used to define all the model parameters that require training and updating. The return value of this method is a list of these model parameters.

Specifying the API for the 2PM Node to Execute Tasks

Once the task is defined, we can begin preparations to execute it on the 2PM Node.

2PM Task Management interacts directly with the 2PM Node API to dispatch tasks to the 2PM Node for execution. To run a task, specify the API address of the 2PM Node when starting the task.
