What is federated averaging (FedAvg)?

Federated averaging (FedAvg) is a technique for training machine learning models when the data is spread across many different servers or devices. It preserves data privacy, security, and locality by enabling model training without sharing the raw data.

In conventional centralized machine learning, data is collected and stored on a central server, and a single model is trained on this consolidated dataset. In federated averaging, by contrast, the data remains decentralized, residing on numerous client or edge devices, and training is carried out locally on each device.

Workflow mechanism

The process of federated averaging involves the following steps:

Federated averaging workflow mechanism
  1. Initialization: The central server (the coordinating entity in federated learning) initializes a global model, i.e., a shared model that captures the clients' collective knowledge while preserving privacy.

  2. Client selection: A subset of clients is chosen to take part in the training round. The selection may be random or based on specific criteria.

  3. Model distribution: The global model is sent to the chosen clients; each client receives its own copy of the model.

  4. Local training: Each client trains its copy of the model on its local data. This training may run for multiple iterations or epochs to improve the model's performance.

  5. Model aggregation: The updated models from each client are sent back to the central server after local training.

  6. Model averaging: The central server aggregates the models received from the clients by averaging their parameters. This averaging lets the global model benefit from the knowledge learned on different clients while preserving privacy. (The original FedAvg algorithm weights this average by each client's dataset size; see the sketch after the note below.)

  7. Repeat: Steps 2 (Client selection) to 6 (Model averaging) are repeated for multiple training rounds until convergence, i.e., the point at which the iterative process reaches a stable solution, or a desired level of performance is achieved.

Note: The flowchart above provides a general workflow. The actual implementation of federated averaging may vary based on the specific federated learning framework or algorithm used.
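
The simple average in step 6 treats every client equally. The original FedAvg algorithm (McMahan et al.) instead weights each client's parameters by the size of its local dataset, so clients holding more data contribute proportionally more. A minimal sketch of that weighted average, assuming each client reports its sample count alongside its weights:

import numpy as np

def weighted_average(client_weights, client_sizes):
    # FedAvg-style aggregation: each client's parameters are weighted
    # by its share of the total training data
    total = sum(client_sizes)
    return sum(w * (n / total) for w, n in zip(client_weights, client_sizes))

# Example: three clients with 100, 50, and 10 local samples
weights = [np.full((10, 1), v) for v in (1.0, 2.0, 3.0)]
print(weighted_average(weights, [100, 50, 10]))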

Code example

import copy

import numpy as np

# Define a sample model for demonstration
class Model:
    def __init__(self):
        # A single fully connected layer: 10 inputs, 1 output
        self.fc = np.random.randn(10, 1)

    def forward(self, x):
        return np.dot(x, self.fc)

# Client-side training function
def train_local_model(client_data, model):
    num_epochs = 10
    learning_rate = 0.1
    for epoch in range(num_epochs):
        inputs, labels = client_data
        inputs = inputs.reshape(1, -1)
        outputs = model.forward(inputs)
        loss = np.mean((outputs - labels) ** 2)  # mean squared error
        # Gradient of the MSE loss with respect to the weights
        grad = 2 * np.dot(inputs.T, outputs - labels) / inputs.shape[0]
        model.fc -= learning_rate * grad
    return model

# Server-side federated averaging
def federated_averaging(global_model, client_data, num_rounds):
    for round_num in range(num_rounds):
        # Client selection: pick 3 client indices at random
        selected_clients = np.random.choice(len(client_data), size=3, replace=False)
        # Model distribution: each client gets its own copy of the global model,
        # so local updates cannot overwrite one another
        client_models = [copy.deepcopy(global_model) for _ in selected_clients]
        # Local training
        for i, client_index in enumerate(selected_clients):
            client_models[i] = train_local_model(client_data[client_index], client_models[i])
        # Model aggregation: sum the clients' weights, starting from zeros
        aggregated_weights = np.zeros_like(global_model.fc)
        for client_model in client_models:
            aggregated_weights += client_model.fc
        # Model averaging: the global model becomes the mean of the client models
        global_model.fc = aggregated_weights / len(client_models)
        print(f"Round {round_num + 1} - Global Model Parameters:")
        print(global_model.fc)
        print()

# Example usage
global_model = Model()
client_data = [
    (np.random.randn(10), np.random.randn(1)),
    (np.random.randn(10), np.random.randn(1)),
    (np.random.randn(10), np.random.randn(1)),
]  # Dummy client data
num_rounds = 2
federated_averaging(global_model, client_data, num_rounds)

Code explanation

  • The Model class defines a sample model for demonstration. It has a single fully connected layer, represented by the fc attribute and initialized with random values. The forward method performs a forward pass by multiplying the input x with the fc weights using a dot product.

  • The train_local_model function performs local training on the client side.

    • Takes client_data, which consists of input samples and corresponding labels, and a model as inputs.

    • Trains the model for a fixed number of epochs using gradient descent with a fixed learning rate.

    • Computes the mean squared error between the model's outputs and the labels, then updates the model's weights using the gradient of that loss with respect to the weights.

  • The federated_averaging function performs server-side federated averaging. It takes global_model as the initial model, client_data as a list of client datasets, and num_rounds as the number of federated learning rounds. In each round, it:

    • Randomly selects a subset of clients from the available client data.

    • Gives each selected client its own deep copy of the global model, so that one client's updates cannot overwrite another's.

    • Trains each selected client's copy using the train_local_model function.

    • Sums the clients' weights into an array initialized to zeros.

    • Sets the global model's weights to the average of the summed weights across all clients.

    • Prints the global model's weights for the round.

  • The example usage at the bottom initializes a global_model, creates a list of dummy client_data consisting of input samples and labels, defines the number of federated learning rounds, and calls federated_averaging with these inputs. A quick sanity check of the result follows below.
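
Since the weights are randomly initialized, the printed parameters will differ from run to run. One way to check that the rounds are doing something useful is to measure the global model's mean squared error across all clients before and after calling federated_averaging; a minimal sketch reusing the objects defined above (global_mse is an illustrative helper, not part of the example):

def global_mse(model, client_data):
    # Average MSE of the global model over every client's local data
    losses = []
    for inputs, labels in client_data:
        outputs = model.forward(inputs.reshape(1, -1))
        losses.append(np.mean((outputs - labels) ** 2))
    return np.mean(losses)

print("Global MSE:", global_mse(global_model, client_data))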

Benefits

The federated averaging technique has the following benefits:

  • Allows collaboration without providing raw data while maintaining privacy.

  • Reduces communication costs because only model updates are sent between clients and the server.

  • Scalable for large-scale machine learning applications.

  • Minimizes the need for data transfer, reducing network bandwidth requirements and latency.

  • Optimizes resource utilization, as it distributes the computational load across multiple devices or servers.

Drawbacks

Federated averaging has the following drawbacks:

  • Data distribution across devices can be heterogeneous (non-IID), which can slow or destabilize convergence (see the sketch after this list).

  • The server has limited control over how training is carried out on individual clients.

  • Limited access to the data kept on specific servers or devices.

  • Potential threats to privacy and security remain if necessary precautions are not taken, since model updates themselves can leak information.
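
The first drawback, data heterogeneity, is easy to reproduce with the dummy setup above: draw each client's inputs from a differently centered distribution so that no single client's data is representative of the whole. A minimal sketch (the shift values are arbitrary):

# Non-IID dummy data: each client's features are centered at a different mean
non_iid_client_data = [
    (np.random.randn(10) + shift, np.random.randn(1))
    for shift in (-2.0, 0.0, 2.0)
]
federated_averaging(Model(), non_iid_client_data, num_rounds=2)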

Use cases

The federated averaging technique is used in many different fields. It applies to:

  • Healthcare: Trains models on distributed patient data while maintaining privacy.

  • Finance: Lets financial firms collaborate on shared models without disclosing confidential customer data.

  • Edge computing: Lets devices with limited connectivity (IoT devices, smartphones, or edge servers) contribute to model training.


Conclusion

Federated averaging is a promising way to address privacy and data-decentralization concerns in machine learning. By allowing models to be trained locally on individual devices or servers while keeping user data private, it strikes a careful balance between data protection and model performance.
