Horizontal Federated Learning Task

This is an example of a horizontal federated learning task written using the 2PM Node Framework.

The data consists of the MNIST dataset distributed across multiple nodes, with each node holding a portion of the samples. The task involves training a convolutional neural network model for handwritten digit recognition.

Import Necessary Packages:

The computation logic is written in PyTorch, so we first import numpy and torch along with a few auxiliary tools. We then import the relevant parts of the 2PM Framework from the 2pm-task package: 2PMNode, which calls the node's API to submit horizontal federated learning tasks, and FaultTolerantFedAvg, which configures the secure aggregation strategy:

from typing import Any, Dict, Iterable, List, Tuple

import logging
import numpy as np
import torch
from PIL.Image import Image
from torch.utils.data import DataLoader, Dataset

from 2pm.2pm_node import 2PMNode
from 2pm.task.learning import HorizontalLearning, FaultTolerantFedAvg
import 2pm.dataset

Define the Neural Network Model

Next, let's define the neural network model, following the standard structure of a PyTorch module:

class LeNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5, padding=2)
        self.pool1 = torch.nn.AvgPool2d(2, stride=2)
        self.conv2 = torch.nn.Conv2d(16, 16, 5)
        self.pool2 = torch.nn.AvgPool2d(2, stride=2)
        self.dense1 = torch.nn.Linear(400, 100)
        self.dense2 = torch.nn.Linear(100, 10)

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool2(x)

        x = x.view(-1, 400)
        x = self.dense1(x)
        x = torch.relu(x)
        x = self.dense2(x)
        return x
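
As a quick sanity check of the layer dimensions (illustrative only, not part of the task definition), we can push a random batch through the network and confirm the output contains one logit per digit class:

import torch

model = LeNet()
x = torch.rand((8, 1, 28, 28))  # a batch of 8 single-channel 28x28 images
y = model(x)
print(y.shape)  # torch.Size([8, 10]): one logit per digit class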

Define Privacy Computing Tasks

Now we can define our horizontal federated learning task, which trains the neural network defined above across multiple nodes.

When defining the horizontal federated learning task, there are several components that users need to specify themselves:

  • Task Configuration: Configure the task by passing arguments to super().__init__() in the constructor.

  • Dataset: Define the required dataset for the task in the dataset method.

  • Training Set DataLoader: Define the DataLoader for the training set in the make_train_dataloader method.

  • Validation Set DataLoader: Define the DataLoader for the validation set in the make_validate_dataloader method.

  • Model Training: Define the entire training process of the model, including both forward and backward propagation, in the train method. The input for this method is the training set's DataLoader.

  • Model Validation: Define the entire validation process of the model in the validate method. The input for this method is the validation set's DataLoader, and the output is a dictionary whose keys are the names of the computed metrics and whose values are the corresponding metric values.

  • Model Parameters: We need to define all model parameters that require training and updating in the state_dict method.

def transform_data(data: List[Tuple[Image, str]]):
    """
    As a dataloader's collate_fn, this function preprocesses the input.
    It reshapes and normalizes the MNIST images before converting them into torch.Tensor format for processing.
    """
    xs, ys = [], []
    for x, y in data:
        xs.append(np.array(x).reshape((1, 28, 28)))
        ys.append(int(y))

    imgs = torch.tensor(np.array(xs))  # stack into a single (N, 1, 28, 28) array before tensor conversion
    labels = torch.tensor(ys)
    imgs = imgs / 255 - 0.5  # scale pixel values from [0, 255] to [-0.5, 0.5]
    return imgs, labels


class Example(HorizontalLearning):
    def __init__(self) -> None:
        super().__init__(
            name="example",  # Task name
            max_rounds=2,  # Total number of training rounds, each round represents an update aggregation
            validate_interval=1,  # Interval of validation rounds, 1 indicates validation after each round
            validate_frac=0.1,  # Proportion of the validation set, within the range (0,1)
            strategy=FaultTolerantFedAvg(  # Secure aggregation strategy, currently includes FedAvg and FaultTolerantFedAvg located under 2pm.task.learning
                min_clients=2,  # Minimum required number of clients, at least 2
                max_clients=3,  # Maximum supported number of clients, must be greater than or equal to min_clients
                merge_epoch=1,  # Aggregation interval: the number of local epochs between each weight aggregation
                wait_timeout=30,  # Timeout for waiting, used to control the timeout of a computational round
                connection_timeout=10  # Connection timeout, used to control the timeout of each phase in the process
            )
        )
        self.model = LeNet()
        self.loss_func = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=0.1,
            momentum=0.9,
            weight_decay=1e-3,
            nesterov=True,
        )

    def dataset(self) -> 2pm.dataset.Dataset:
        """
        Define the dataset required for the task.
        return: A 2pm.dataset.Dataset
        """
        return 2pm.dataset.Dataset(dataset="mnist")

    def make_train_dataloader(self, dataset: Dataset) -> DataLoader:
        """
        Define the DataLoader for the training set, which can perform various transformations and preprocessing on the dataset.
        dataset: The training set Dataset
        return: DataLoader for the training set
        """
        return DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True, collate_fn=transform_data)  # type: ignore

    def make_validate_dataloader(self, dataset: Dataset) -> DataLoader:
        """
        Define the DataLoader for the validation set, which can perform various transformations and preprocessing on the dataset.
        dataset: The validation set Dataset
        return: DataLoader for the validation set
        """
        return DataLoader(dataset, batch_size=64, shuffle=False, drop_last=False, collate_fn=transform_data)  # type: ignore

    def train(self, dataloader: Iterable):
        """
        Training step
        dataloader: The dataloader corresponding to the training dataset
        return: None
        """
        for batch in dataloader:
            x, y = batch
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def validate(self, dataloader: Iterable) -> Dict[str, Any]:
        """
        Validation step, outputs validation metrics
        dataloader: The dataloader corresponding to the validation dataset
        return: Dictionary with metric names as keys (str) and their corresponding values (float)
        """
        total_loss = 0
        count = 0
        ys = []
        y_s = []
        for batch in dataloader:
            x, y = batch
            y_pred = self.model(x)
            loss = self.loss_func(y_pred, y)
            total_loss += loss.item()
            count += 1

            y_ = torch.argmax(y_pred, dim=1)
            y_s.extend(y_.tolist())
            ys.extend(y.tolist())
        avg_loss = total_loss / count
        correct = sum(1 for truth, pred in zip(ys, y_s) if truth == pred)  # number of correctly classified samples
        accuracy = correct / len(ys)

        return {"loss": avg_loss, "accuracy": accuracy}

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """
        Model parameters that need training and updating
        Only the parameters returned by state_dict will be updated and saved during aggregation update and result saving.
        return: Dict[str, torch.Tensor] mapping parameter names to the model parameters
        """
        return self.model.state_dict()

To walk through the process of defining a horizontal federated learning task in detail, we need to create a class that inherits from HorizontalLearning. HorizontalLearning is an abstract base class that declares a set of methods tailored to horizontal federated learning tasks, which we are required to implement.

First, the constructor __init__ is where the foundational setup for the task is configured. This configuration includes the task name (name), the total number of training rounds (max_rounds), the validation frequency (validate_interval, the number of rounds between validations), the proportion of the validation dataset (validate_frac), and the secure aggregation strategy (strategy). When implementing the constructor, it is mandatory to call the superclass constructor using super().__init__().
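
Since FedAvg is also available under 2pm.task.learning, the strategy could be swapped when fault tolerance is not required. The sketch below assumes FedAvg accepts the same configuration parameters as FaultTolerantFedAvg; check the framework documentation to confirm:

from 2pm.task.learning import FedAvg

strategy = FedAvg(
    min_clients=2,          # minimum required number of clients
    max_clients=3,          # maximum supported number of clients
    merge_epoch=1,          # number of local epochs between aggregations
    wait_timeout=30,        # timeout for each computation round
    connection_timeout=10,  # timeout for each phase in the process
)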

Following this, the dataset method defines the dataset required for the task. This method returns an instance of 2pm.dataset.Dataset, where the parameter dataset specifies the name of the dataset required. For details on the dataset format, please refer to this article. Currently, horizontal federated learning tasks support only one dataset, so the dataset method can only return a single instance of 2pm.dataset.Dataset.

Next are two methods for defining DataLoaders: make_train_dataloader for the training set and make_validate_dataloader for the validation set. The logic required in both methods is similar, hence they are discussed together. Each method takes an instance of torch.utils.data.Dataset, corresponding to either the training or validation set, and can transform and preprocess the dataset as needed. Ultimately, each method returns an instance of torch.utils.data.DataLoader, which serves as the input for the model training method.
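
To see what the collate function produces, the illustrative snippet below builds a toy batch of blank PIL images and runs it through transform_data; the resulting shapes match what the model's forward pass expects:

from PIL import Image

batch = [(Image.new("L", (28, 28)), "3"), (Image.new("L", (28, 28)), "7")]  # blank placeholder MNIST samples
imgs, labels = transform_data(batch)
print(imgs.shape)    # torch.Size([2, 1, 28, 28])
print(labels.shape)  # torch.Size([2])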

The train method is where the entire training process of the model is implemented, including forward propagation, backward propagation, and parameter updates. The input for this method is the DataLoader defined in make_train_dataloader.

Additionally, the validate method calculates various performance metrics of the model on the validation set. The input for this method is the DataLoader defined in make_validate_dataloader. The output is a dictionary where the keys are the names of the computed metrics and the values are the corresponding metric values.
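
One standard PyTorch refinement not shown in the example is to disable gradient tracking during validation, since no backward pass happens there. A sketch of the change (the metric loop itself stays the same):

def validate(self, dataloader: Iterable) -> Dict[str, Any]:
    self.model.eval()      # switch layers such as dropout/batchnorm to eval mode (a no-op for this LeNet)
    with torch.no_grad():  # skip gradient bookkeeping while computing metrics
        ...                # same metric loop as in the example above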

Finally, the state_dict method defines all the model parameters that require training and updating. It returns a dictionary mapping parameter names to the corresponding torch.Tensor values; only these parameters are aggregated during updates and saved in the results.
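
For the LeNet defined above, that dictionary maps each layer's parameter names to its weight and bias tensors, which is exactly what the aggregation strategy averages across nodes:

for name, tensor in LeNet().state_dict().items():
    print(name, tuple(tensor.shape))
# conv1.weight (16, 1, 5, 5)
# conv1.bias (16,)
# ...
# dense2.weight (10, 100)
# dense2.bias (10,)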

Specifying the API for the 2PM Node to Execute Tasks

Once the task is defined, we can begin preparations to execute it on the 2PM Node.

2PM Task Management allows direct interaction with the 2PM Node API to dispatch tasks to the 2PM Node for execution. Simply specify the API address of the 2PM Node when initiating the task.
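
As a minimal sketch of that dispatch step (the address below is a placeholder, and the create_task method name is an assumption based on the description above; consult the 2PM Node documentation for the exact call):

from 2pm.2pm_node import 2PMNode

API_ADDRESS = "http://127.0.0.1:6704"  # placeholder: your 2PM Node's API address

task = Example()             # the horizontal federated learning task defined above
node = 2PMNode(API_ADDRESS)  # assumption: the constructor takes the node's API address
node.create_task(task)       # assumption: submits the task to the node for execution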
