Federated Learning (FL) is a machine learning paradigm where an aggregate Machine Learning (ML) model is collaboratively trained across multiple decentralized devices or servers holding private datasets. This approach is particularly advantageous for addressing model bias through diverse data, while maintaining privacy and security for participating data owners. In FL, sensitive data remains in the control of the data owner, reducing the risk of breaches and ensuring compliance with data protection regulations. Federated Learning therefore has the potential to become the foundation for a robust and secure data economy in the future.
To demonstrate this vision in practice, we assume the role of a data scientist (“Model Owner”) who designs an FL experiment to train a model across a federation of participating entities (“Data Owners”).
We are going to use OpenFL to define the experiment, starting from a familiar centralized training/validation Python script. We set ourselves the additional challenge of configuring the FL experiment without modifying the initial ML model definition.
You are encouraged to build the FL workspace by following each step of the tutorial, but in case of difficulties, you can always refer to the full sources on GitHub. Note also that some of the code listings presented in this article are trimmed down for conciseness.
So, let’s get started!
The Centralized ML Script
Before diving into federated learning, let’s consider a Python module that defines a DigitRecognizerCNN PyTorch model along with helper functions for training and validation. We will use this example as an illustration of how an arbitrary ML script that exposes a similar interface can evolve into a Federated Learning system.
The source can be downloaded from cnn_model.py:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class DigitRecognizerCNN(nn.Module):
    def __init__(self, **kwargs):
        super(DigitRecognizerCNN, self).__init__(**kwargs)
        self.conv1 = nn.Conv2d(1, 20, 2, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(800, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 800)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


def train(model, optimizer, loss_fn, dataloader, epochs=1, device="cpu"):
    for epoch in range(epochs):
        print(f"Starting epoch {epoch + 1}/{epochs}")
        average_loss = train_epoch(model, optimizer, loss_fn, dataloader, device)
        print(f"Completed epoch {epoch + 1}/{epochs} with average loss: {average_loss}")
    return average_loss


def validate(model, test_dataloader, device="cpu"):
    total_correct = 0
    total_samples = 0
    for data, target in test_dataloader:
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        output = model(data)
        pred = output.argmax(dim=1, keepdim=True)
        total_correct += pred.eq(target.view_as(pred)).sum().item()
        total_samples += len(target)
    return total_correct / total_samples


def train_epoch(model, optimizer, loss_fn, dataloader, device="cpu"):
    total_loss = 0
    num_batches = 0
    for data, target in dataloader:
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        total_loss += loss.detach().cpu().item()
        num_batches += 1
    average_loss = total_loss / num_batches
    return average_loss
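The 800 input features of fc1 follow from the layer arithmetic for a 28x28 input image (the MNIST image size used later in this tutorial). The sketch below traces that calculation. The helpers out_size() and flattened_features() are hypothetical names introduced for illustration only; they are not part of cnn_model.py, and they assume no padding, which matches the default of the layers used above.

```python
def out_size(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of a conv or pooling layer without padding."""
    return (size - kernel) // stride + 1


def flattened_features(input_size: int = 28) -> int:
    """Trace the spatial size through the layers of DigitRecognizerCNN."""
    s = out_size(input_size, kernel=2, stride=1)  # conv1: 28 -> 27
    s = out_size(s, kernel=2, stride=2)           # max_pool2d: 27 -> 13
    s = out_size(s, kernel=5, stride=1)           # conv2: 13 -> 9
    s = out_size(s, kernel=2, stride=2)           # max_pool2d: 9 -> 4
    return 50 * s * s                             # conv2 outputs 50 channels


print(flattened_features())  # 800
```

If the input resolution or a kernel size changes, the hard-coded 800 in fc1 and in x.view(-1, 800) would have to change accordingly.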
The DigitRecognizerCNN model can be trained at a central location (such as a laptop) that has direct access to the data, for example the full MNIST dataset. This can be done through a script such as train.py, shown in the code listing below:
import torch
import torchvision
import torch.nn.functional as F
from torchvision.transforms import ToTensor
import torch.optim as optim

from cnn_model import DigitRecognizerCNN, train, validate

if __name__ == '__main__':
    model = DigitRecognizerCNN()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = F.cross_entropy
    data_train = torchvision.datasets.MNIST('./data', download=True, train=True, transform=ToTensor())
    data_test = torchvision.datasets.MNIST('./data', download=True, train=False, transform=ToTensor())
    train_data_loader = torch.utils.data.DataLoader(data_train, batch_size=64)
    test_data_loader = torch.utils.data.DataLoader(data_test, batch_size=64)
    train(model, optimizer, loss_fn, train_data_loader, epochs=1)
    acc = validate(model, test_data_loader)
    print(f"Digit recognizer accuracy after centralized training: {acc}")
For the rest of the tutorial, we “throw away” the train.py script, as it assumes access to the full data set.
In Federated Learning, data resides in separate silos, so we will need to plug the cnn_model.py into OpenFL, enabling the model to be securely trained against the private data of participating entities.
Building an OpenFL Workspace from Scratch
OpenFL offers several APIs for building FL experiments, depending on the use case. We recommend starting with the TaskRunner API, which enables setting up a central Aggregator node (hosted by the Model Owner) and multiple Collaborator nodes (hosted by Data Owners). For simplicity, and to remain ML-framework agnostic, the only data exchanged between participant nodes in the federation are serialized numpy arrays (np.ndarray). As depicted in Figure 1, the Aggregator is responsible for coordinating the training and validation tasks on each collaborator in the federation. Crucially for privacy, the Aggregator does not have direct access to the training or validation data; it only works with the outputs of the computations that run on each collaborator (for instance, metrics or updated model weights).
Figure 1: Architecture of an OpenFL Federation
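To make the idea of exchanging "serialized numpy arrays" more concrete, here is a minimal round-trip sketch. It uses numpy's own .npy format purely as an illustration; this is an assumption, not a description of OpenFL's actual wire format, and the helper names to_bytes() and from_bytes() are hypothetical.

```python
import io

import numpy as np


def to_bytes(array: np.ndarray) -> bytes:
    """Serialize an array (dtype and shape included) into a byte string."""
    buffer = io.BytesIO()
    np.save(buffer, array, allow_pickle=False)
    return buffer.getvalue()


def from_bytes(payload: bytes) -> np.ndarray:
    """Rebuild the array from the byte string produced by to_bytes()."""
    return np.load(io.BytesIO(payload), allow_pickle=False)


weights = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = from_bytes(to_bytes(weights))
print(np.array_equal(weights, restored), restored.dtype, restored.shape)
# True float32 (2, 3)
```

The point of the sketch is that nothing framework-specific crosses the network: a PyTorch or TensorFlow tensor is first converted to a plain array, and only that array is transmitted.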
OpenFL abstracts away the complexities of setting up and securely running the FL experiment, allowing developers to focus on the model and the data. The developer interacts with OpenFL via a CLI, under the fx group of commands.
Installing OpenFL
On Ubuntu 20.04 or later, the OpenFL package can be installed from PyPI (v1.6 at the time of writing). It is recommended to do so in a dedicated Python virtual environment:
pip install virtualenv
mkdir ~/openfl-quickstart
virtualenv ~/openfl-quickstart/venv
source ~/openfl-quickstart/venv/bin/activate
pip install openfl==1.6
Creating the Workspace Folder
In OpenFL, a federation’s software artifacts reside in a workspace folder hierarchy, which can be packaged and distributed among the participating entities. Use the OpenFL CLI command fx workspace create to generate a new workspace for the federated learning project. To help streamline the process of developing PyTorch-based FL workspaces, OpenFL provides a dedicated torch_template:
cd ~/openfl-quickstart
fx workspace create --template torch_template --prefix fl_workspace
cd ~/openfl-quickstart/fl_workspace
The FL workspace has the following file structure:
fl_workspace
├── requirements.txt    # defines the required software packages
├── cert                # holds trusted certificates
├── data                # placeholder for each collaborator’s data set
├── save                # holds the serialized model
├── logs                # FL experiment logs
├── plan
│   ├── plan.yaml       # the Federated Learning plan declaration
│   ├── cols.yaml       # holds the list of authorized collaborators
│   ├── data.yaml       # holds the collaborator data set path
│   └── defaults        # path to the default values for the FL plan
└── src
    ├── __init__.py     # treat src as a Python package
    ├── dataloader.py   # data loader module
    └── taskrunner.py   # task runner module
It is recommended to set the following versions of torch and torchvision in the workspace’s requirements.txt, as they have been tested for compatibility with OpenFL 1.6:
torch==2.3.1
torchvision==0.18.1
Note: There is no need to explicitly install the updated requirements, as this will be done automatically during the workspace initialization step.
To customize the workspace for our specific model, we are going to modify the following files:
- src/dataloader.py: to define a data loader capable of iterating through the data at the collaborator nodes
- src/taskrunner.py: to define the training and validation tasks
- plan/plan.yaml: to define the components, pipeline, and the parameters of the FL experiment
Specifying the Data Layout
As the Model Owner, it is our responsibility to define in advance the expected data format, which has to be the same across all collaborator nodes. Complying with this specification and pre-processing the data accordingly is the Data Owner’s responsibility.
In our case, the local dataset is expected to be under the data/<N> folder of the Nth collaborator’s FL workspace, containing an MNIST-compliant directory hierarchy under mnist_images (with the label for each digit encoded in the respective sub-folder name):
data
└── N
    └── mnist_images
        ├── 0
        ├── 1
        ├── 2
        ├── 3
        ├── 4
        ├── 5
        ├── 6
        ├── 7
        ├── 8
        └── 9
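A Data Owner may want to verify that a local dataset complies with this layout before joining the federation. The following is a minimal sketch of such a check; missing_label_folders() is a hypothetical helper, not part of OpenFL or of the workspace template, and it only checks that the ten digit folders exist (it does not inspect the images themselves).

```python
from pathlib import Path

# One sub-folder per digit label, as required by the data layout.
EXPECTED_LABELS = [str(digit) for digit in range(10)]


def missing_label_folders(collaborator_data_path: str) -> list:
    """Return the digit sub-folders missing under <data_path>/mnist_images."""
    images_root = Path(collaborator_data_path) / "mnist_images"
    return [
        label for label in EXPECTED_LABELS
        if not (images_root / label).is_dir()
    ]
```

For a compliant dataset, a call such as missing_label_folders("data/1") returns an empty list; otherwise it lists the label folders that still need to be created.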
Defining the Data Loader
The data loader in OpenFL is responsible for batching and iterating through the dataset that will be used for local training and validation on each collaborator node. The TemplateDataLoader class in src/dataloader.py is designed to be a starting template for creating a data loader that is tailored to the FL experiment’s data format requirements.
To customize the TemplateDataLoader, we just need to implement the load_dataset() function to process the dataset available at data_path on the local file system. The data_path parameter comes from the data.yaml configuration file, which is populated when the collaborator’s identity is created via fx collaborator create.
Because each collaborator’s dataset is assumed to already exist on disk in the MNIST folder layout specified above, we define a custom MNISTDataset class, for example by extending the ImageFolder base class from torchvision. In addition to loading samples in batches, the data loader splits the collaborator’s dataset into a training and a validation subset (with a suggested default ratio of 0.8, i.e. 80% of the samples are used for training).
You can either try to implement the placeholders by yourself, or get the solution from dataloader.py:
from collections.abc import Iterable
from os import makedirs, path

import numpy as np
from openfl.federated import PyTorchDataLoader
from torch import manual_seed
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Grayscale, ToTensor


class TemplateDataLoader(PyTorchDataLoader):
    def __init__(self, data_path, batch_size, **kwargs):
        super().__init__(batch_size, **kwargs)
        # Load the dataset using the provided data_path and any additional kwargs.
        X_train, y_train, X_valid, y_valid = load_dataset(data_path, **kwargs)
        # Assign the loaded data to instance variables.
        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid


def load_dataset(data_path, train_split_ratio=0.8, **kwargs):
    dataset = MNISTDataset(
        root=data_path,
        transform=Compose([Grayscale(num_output_channels=1), ToTensor()])
    )
    n_train = int(train_split_ratio * len(dataset))
    n_valid = len(dataset) - n_train
    # A fixed seed keeps the train/validation split reproducible across runs.
    ds_train, ds_val = random_split(
        dataset, lengths=[n_train, n_valid], generator=manual_seed(0))
    X_train, y_train = list(zip(*ds_train))
    X_train, y_train = np.stack(X_train), np.array(y_train)
    X_valid, y_valid = list(zip(*ds_val))
    X_valid, y_valid = np.stack(X_valid), np.array(y_valid)
    return X_train, y_train, X_valid, y_valid


class MNISTDataset(ImageFolder):
    """Encapsulates the MNIST dataset"""
    FOLDER_NAME = "mnist_images"
    DEFAULT_PATH = path.join(path.expanduser('~'), '.openfl', 'data')

    def __init__(self, root: str = DEFAULT_PATH, **kwargs) -> None:
        """Initialize."""
        makedirs(root, exist_ok=True)
        super(MNISTDataset, self).__init__(
            path.join(root, MNISTDataset.FOLDER_NAME), **kwargs)

    def __getitem__(self, index):
        """Allow getting items by slice index."""
        if isinstance(index, Iterable):
            return [super(MNISTDataset, self).__getitem__(i) for i in index]
        else:
            return super(MNISTDataset, self).__getitem__(index)
Note that the train_split_ratio can also be parameterized via the data loader settings section in plan.yaml:
...
data_loader:
  template: src.dataloader.TemplateDataLoader
  settings:
    batch_size: 64
    train_split_ratio: 0.8  # ... or some other value between 0 and 1, determined by the Model Owner
...
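To see what a given train_split_ratio means in terms of sample counts, the arithmetic used by load_dataset() can be reproduced in isolation. The sketch below is only illustrative: split_sizes() is a hypothetical helper, and rejecting values outside the open interval (0, 1) is an assumption made here, not a check performed by the workspace code.

```python
def split_sizes(n_samples: int, train_split_ratio: float = 0.8) -> tuple:
    """Return (n_train, n_valid) the same way load_dataset() computes them."""
    if not 0.0 < train_split_ratio < 1.0:
        # Assumption: a ratio of 0 or 1 would leave one subset empty.
        raise ValueError("train_split_ratio must be strictly between 0 and 1")
    n_train = int(train_split_ratio * n_samples)  # rounded down
    n_valid = n_samples - n_train                 # remainder goes to validation
    return n_train, n_valid


print(split_sizes(1000))        # (800, 200)
print(split_sizes(1001, 0.75))  # (750, 251)
```

Because the training size is rounded down, any leftover sample ends up in the validation subset.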
Defining the Task Runner
The Task Runner class defines the actual computational tasks of the FL experiment (such as local training and validation). We can implement the placeholders of the TemplateTaskRunner class (src/taskrunner.py) by importing the DigitRecognizerCNN model, as well as the train_epoch() and validate() helper functions from the centralized ML script. The template also provides placeholders for providing custom optimizer and loss function objects.
Note! Don’t forget to copy the starting cnn_model.py module to the src folder of your local FL workspace, so that taskrunner.py can import it.
import numpy as np
from typing import Iterator, Tuple
from openfl.federated import PyTorchTaskRunner
from openfl.utilities import Metric
import torch.optim as optim
import torch.nn.functional as F

from src.cnn_model import DigitRecognizerCNN, train_epoch, validate


class TemplateTaskRunner(PyTorchTaskRunner):
    def __init__(self, device="cpu", **kwargs):
        super().__init__(device=device, **kwargs)
        # Define the model
        self.model = DigitRecognizerCNN()
        self.to(device)
        # Define the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
        # Define the loss function
        self.loss_fn = F.cross_entropy

    def forward(self, x):
        return self.model(x)

    def train_(
        self, train_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
    ) -> Metric:
        loss = train_epoch(self.model, self.optimizer, self.loss_fn, train_dataloader, self.device)
        return Metric(name="training loss", value=np.array(loss))

    def validate_(
        self, validation_dataloader: Iterator[Tuple[np.ndarray, np.ndarray]]
    ) -> Metric:
        accuracy = validate(self.model, validation_dataloader, self.device)
        return Metric(name="accuracy", value=np.array(accuracy))
The full source code of the task runner class can be downloaded from taskrunner.py.
Note that the forward() function remains unchanged, as it dynamically references the model object that has been set in the constructor. Notice also how the scalar loss and accuracy values returned by the original model’s train_epoch() and validate() functions are wrapped as numpy arrays in OpenFL’s Metric object for transfer over the network and further processing by the aggregator. In addition to metrics, upon completion of each training round, OpenFL automatically transmits the updated model weights (also in the form of numpy arrays) from the collaborators to the aggregator. Once the aggregator receives a quorum of updates from collaborators, it applies an aggregation algorithm (FedAvg by default) to produce the collectively trained model.
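The idea behind federated averaging can be illustrated in a few lines of numpy. The sketch below is not OpenFL's implementation: fed_avg() is a hypothetical function, the tensor name "fc2.bias" is just an example, and weighting each collaborator by its number of training samples is the textbook formulation of FedAvg, assumed here for illustration.

```python
import numpy as np


def fed_avg(updates, sample_counts):
    """Weighted average of per-collaborator weights, tensor by tensor.

    updates: list of dicts mapping a tensor name to an np.ndarray
    sample_counts: number of training samples behind each update
    """
    total = float(sum(sample_counts))
    return {
        name: sum(u[name] * (n / total) for u, n in zip(updates, sample_counts))
        for name in updates[0]
    }


# Two collaborators report different values for the same tensor.
col1 = {"fc2.bias": np.array([0.0, 2.0])}
col2 = {"fc2.bias": np.array([4.0, 6.0])}
print(fed_avg([col1, col2], sample_counts=[300, 100]))
# {'fc2.bias': array([1., 3.])}
```

In this example the first collaborator holds three times as much data as the second, so its weights contribute 75% of the aggregate.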
Defining the Federated Learning Plan
The FL plan is a YAML descriptor of the federation’s configuration, including the model architecture, hyperparameters, and the data pipeline. It specifies how the model should be trained across the federation and how the updates should be aggregated.
The PyTorch workspace template comes with a pre-populated FL plan descriptor under ./plan/plan.yaml. The only modification we need to make at this stage is to set the mandatory batch_size parameter in the data loader section of the FL plan, which indicates how many data samples are processed in a single step of the training process:
...
data_loader:
  template: src.dataloader.TemplateDataLoader
  settings:
    batch_size: 64
...
The full FL plan definition is available at plan.yaml.
Local Simulation of the Federated Learning Experiment
At this point, the FL workspace is ready to be tested in a locally simulated FL environment, before being distributed to all participating entities.
The fx plan initialize command bootstraps the FL workspace by first setting the initial weights of the aggregate model. It then parses the plan, updates the aggregator address if necessary, and produces a hash of the initialized plan for integrity and auditing purposes.
To help OpenFL calculate the initial model weights, we need to provide the shape of the input tensor as an additional parameter. For the MNIST data set of grayscale (single-channel) 28x28 pixel images, the input tensor shape is [1,28,28]. We will also use a locally deployed aggregator (localhost). Thus, the workspace initialization command for our local federation becomes:
cd ~/openfl-quickstart/fl_workspace
fx plan initialize --input_shape [1,28,28] --aggregator_address localhost
For a faithful simulation of a federated learning setting where data resides in separate silos, we use a pre-sharded version of the standard MNIST dataset (divided among two data owners). The pre-sharded dataset can be downloaded from mnist_data_shards.tar.gz. Copy the dataset bundle to the root of the FL workspace and unpack it:
cp mnist_data_shards.tar.gz ~/openfl-quickstart/fl_workspace
cd ~/openfl-quickstart/fl_workspace
tar -xvf mnist_data_shards.tar.gz
rm mnist_data_shards.tar.gz
This will populate the data folder of the FL workspace with two shards (data/1 and data/2) of labeled MNIST images of digits (the 0–9 labels being encoded in the sub-folder names). Note that in a real-world federation each of the collaborator nodes would only hold one shard, given the decentralized nature of Federated Learning. To facilitate the local testing of the FL workspace, both shards are made available under the local data/ folder:
data
├── 1
│   └── mnist_images
│       ├── 0
│       ├── 1
│       ├── 2
│       ├── 3
│       ├── 4
│       ├── 5
│       ├── 6
│       ├── 7
│       ├── 8
│       └── 9
└── 2
    └── mnist_images
        ├── 0
        ├── 1
        ├── 2
        ├── 3
        ├── 4
        ├── 5
        ├── 6
        ├── 7
        ├── 8
        └── 9
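How the downloadable shards were produced is not described here. As one simple possibility, a dataset can be divided into disjoint shards by assigning sample indices round-robin. The sketch below illustrates that assumption only; shard_indices() is a hypothetical helper and not the procedure used to build mnist_data_shards.tar.gz.

```python
def shard_indices(n_samples: int, n_shards: int) -> list:
    """Assign sample indices to shards round-robin, so shards are disjoint."""
    return [list(range(shard, n_samples, n_shards)) for shard in range(n_shards)]


shards = shard_indices(10, 2)
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

Each index appears in exactly one shard, which mirrors the federated setting where no sample is shared between data owners.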
We can now perform a test run with the following commands for creating a local PKI setup and starting the aggregator and the collaborators on the same machine:
cd ~/openfl-quickstart/fl_workspace
# This will create a local certificate authority (CA), so the participants communicate over a secure TLS Channel
fx workspace certify
#################################################################
# Step 1: Setup the Aggregator #
#################################################################
# Generate a Certificate Signing Request (CSR) for the Aggregator
fx aggregator generate-cert-request --fqdn localhost
# The CA signs the aggregator's request, which is now available in the workspace
fx aggregator certify --fqdn localhost --silent
################################
# Step 2: Setup Collaborator 1 #
################################
# Create a collaborator named "collaborator1" that will use data path "data/1"
# This command adds the collaborator1,data/1 entry in data.yaml
fx collaborator create -n collaborator1 -d data/1
# Generate a CSR for collaborator1
fx collaborator generate-cert-request -n collaborator1
# The CA signs collaborator1's certificate, adding an entry to the authorized cols.yaml
fx collaborator certify -n collaborator1 --silent
################################
# Step 3: Setup Collaborator 2 #
################################
# Create a collaborator named "collaborator2" that will use data path "data/2"
# This command adds the collaborator2,data/2 entry in data.yaml
fx collaborator create -n collaborator2 -d data/2
# Generate a CSR for collaborator2
fx collaborator generate-cert-request -n collaborator2
# The CA signs collaborator2's certificate, adding an entry to the authorized cols.yaml
fx collaborator certify -n collaborator2 --silent
##############################
# Step 4. Run the Federation #
##############################
fx aggregator start & fx collaborator start -n collaborator1 & fx collaborator start -n collaborator2
A successful local simulation of the FL workspace involves the aggregator and collaborators completing a round of training, saving the best-performing model under save/best.pbuf, and exiting with a unanimous “End of Federation reached…”:
INFO Round: 1, Collaborators that have completed all tasks: ['collaborator2', 'collaborator1']
METRIC {'metric_origin': 'aggregator', 'task_name': 'aggregated_model_validation', 'metric_name': 'accuracy', 'metric_value': 0.8915090382660382, 'round': 1}
METRIC Round 1: saved the best model with score 0.891509
METRIC {'metric_origin': 'aggregator', 'task_name': 'train', 'metric_name': 'training loss', 'metric_value': 0.2952194180338876, 'round': 1}
METRIC {'metric_origin': 'aggregator', 'task_name': 'locally_tuned_model_validation', 'metric_name': 'accuracy', 'metric_value': 0.9181734901767464, 'round': 1}
INFO Saving round 1 model...
INFO Experiment Completed. Cleaning up...
INFO Waiting for tasks...
INFO Sending signal to collaborator collaborator1 to shutdown...
INFO End of Federation reached. Exiting...
INFO Waiting for tasks...
INFO Sending signal to collaborator collaborator2 to shutdown...
INFO End of Federation reached. Exiting...
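The METRIC records in this output are formatted like Python dictionary literals, which makes them easy to post-process, for example to plot accuracy per round. The following is a minimal sketch under the assumption that each record sits on a single line in exactly the format shown in the listing; parse_metric_line() is a hypothetical helper, not an OpenFL API.

```python
import ast


def parse_metric_line(line: str):
    """Return the metric dict from a 'METRIC {...}' log line, else None."""
    start = line.find("{")
    if not line.lstrip().startswith("METRIC") or start == -1:
        return None  # not a structured metric record
    # literal_eval only accepts Python literals, so no code is executed.
    return ast.literal_eval(line[start:])


line = ("METRIC {'metric_origin': 'aggregator', "
        "'task_name': 'aggregated_model_validation', "
        "'metric_name': 'accuracy', "
        "'metric_value': 0.8915090382660382, 'round': 1}")
metric = parse_metric_line(line)
print(metric["metric_name"], metric["round"])  # accuracy 1
```

Lines without a dictionary payload, such as "METRIC Round 1: saved the best model with score 0.891509" or the INFO lines, yield None and can simply be skipped.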
Towards a Fully Distributed Deployment
Now that the core logic of the FL experiment has been defined and locally tested, we are ready to begin preparations for a fully distributed deployment. This process involves setting up the PKI between the aggregator and collaborator nodes, exporting a distributable workspace package, and importing it on each collaborator node. The instructions in Step 2: Configure the Federation of the OpenFL documentation outline the main steps of this process.
Conclusion and Next Steps
By following this tutorial, you have learned how to transform a typical ML script that relies on locally available data into a federated learning system using OpenFL. This approach can unlock the vast potential of private data sets, which would otherwise remain inaccessible due to confidentiality or security constraints.
With relative ease, the examples provided here can be adapted to build federated learning experiments from a wide range of ML models and frameworks. Thanks to OpenFL’s generic numpy-based interfaces, a number of ML frameworks are already supported, including PyTorch and TensorFlow, as well as GaNDLF, which specializes in the healthcare space. You can explore all currently supported FL workspace templates in the OpenFL GitHub repository.
Join the OpenFL Community
The OpenFL community is growing, and we invite you to be a part of it. Join the Slack channel to connect with fellow enthusiasts, share insights, and contribute to the future of federated learning.