1. 导入依赖库

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import flwr as fl
import multiprocessing as mp
from torch.utils.data import DataLoader, Subset
import numpy as np
from torchvision import datasets
# 如需使用,联系Q 596520206
  • torch, torch.nn, torch.optim: PyTorch 的核心库,用于创建神经网络、定义优化器和计算损失。
  • torchvision: 用于加载常用的计算机视觉数据集(如 MNIST、CIFAR-10)以及预处理图像。
  • flwr (Flower): 一个用于联邦学习的框架,提供了客户端和服务器的通信接口。
  • multiprocessing: 用于在多个进程中并行启动 Flower 客户端和服务器。
  • DataLoader, Subset: 用于加载数据集和切割训练集为不同客户端使用。
  • numpy: 用于数值计算。

2. 设备设置和超参数定义

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
EPOCHS = 3
LEARNING_RATE = 0.001
  • DEVICE: 判断当前是否有可用的 GPU 设备,如果有则使用 GPU,否则使用 CPU。
  • BATCH_SIZE: 每次训练的批大小。
  • EPOCHS: 训练的轮数。
  • LEARNING_RATE: 优化器的学习率。

3. CNN模型定义

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 输出为10个类别

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # 展平
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
  • SimpleCNN 是一个简单的卷积神经网络(CNN),由两层卷积层、池化层和两层全连接层构成。
    • conv1conv2:卷积层,conv1 的输入为单通道(灰度图像),输出为32个特征图;conv2 的输入为32个特征图,输出为64个特征图。
    • pool:最大池化层,用于降采样。
    • fc1fc2:全连接层,将特征图展平后通过全连接层输出10个类别。

4. 联邦学习客户端定义

class CNNClient(fl.client.NumPyClient):
    def __init__(self, model, train_data, val_data):
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.criterion = nn.CrossEntropyLoss()
  • CNNClient 继承自 fl.client.NumPyClient,这是 Flower 框架提供的接口,代表了一个联邦学习的客户端。
  • train_dataval_data:训练数据和验证数据。
  • optimizer:使用 Adam 优化器进行训练。
  • criterion:损失函数使用交叉熵损失函数(适用于分类问题)。

4.1 get_parametersset_parameters 方法

    def get_parameters(self, config=None):
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        state_dict = dict(zip(self.model.state_dict().keys(), parameters))
        self.model.load_state_dict({k: torch.tensor(v) for k, v in state_dict.items()})
  • get_parameters: 获取模型的参数,将其转换为 NumPy 数组返回。
  • set_parameters: 设置模型的参数,从传入的 NumPy 数组恢复模型权重。

4.2 fitevaluate 方法

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (data, target) in enumerate(self.train_data):
            data, target = data.to(DEVICE), target.to(DEVICE)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            correct += (predicted == target).sum().item()
            total += target.size(0)
        avg_loss = total_loss / len(self.train_data)
        accuracy = correct / total
        print(f"Training loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
        return self.get_parameters(), len(self.train_data.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in self.val_data:
                data, target = data.to(DEVICE), target.to(DEVICE)
                output = self.model(data)
                _, predicted = torch.max(output.data, 1)
                correct += (predicted == target).sum().item()
                total += target.size(0)
        accuracy = correct / total
        print(f"Evaluation Accuracy: {accuracy:.4f}")
        return accuracy, len(self.val_data.dataset), {}
  • fit: 在客户端上进行本地训练,更新模型权重,计算并返回损失和准确率。
  • evaluate: 在客户端上进行模型评估,计算并返回准确率。

5. 数据加载和划分

def get_mnist_loaders():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    file_path = r"E:\t5\mnist"
    trainset = datasets.MNIST(root=file_path, train=True, download=True, transform=transform)
    testset = datasets.MNIST(root=file_path, train=False, download=True, transform=transform)

    # 将训练集划分为两个部分,模拟两个客户端
    trainloader_1 = DataLoader(Subset(trainset, range(0, len(trainset)//2)), batch_size=BATCH_SIZE, shuffle=True)
    trainloader_2 = DataLoader(Subset(trainset, range(len(trainset)//2, len(trainset))), batch_size=BATCH_SIZE, shuffle=True)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

    return trainloader_1, trainloader_2, testloader
  • get_mnist_loaders: 加载 MNIST 数据集,并将训练集划分为两个子集,模拟两个客户端。

6. 启动 Flower 服务器

def start_flower_server():
    strategy = fl.server.strategy.FedAvg()
    config = fl.server.ServerConfig(num_rounds=3)
    print("Starting Flower server...")
    fl.server.start_server(server_address="127.0.0.1:8080", strategy=strategy, config=config)
    print("Flower server started successfully.")
  • 启动 Flower 服务器并使用 FedAvg 策略进行参数聚合。
  • 服务器配置为 3 轮训练。

7. 启动 Flower 客户端

def start_flower_client(client_id, train_data, val_data):
    model = SimpleCNN().to(DEVICE)
    client = CNNClient(model, train_data, val_data)
    fl.client.start_client(server_address="127.0.0.1:8080", client=client.to_client())
  • 启动 Flower 客户端,客户端会连接到 Flower 服务器并开始训练和评估。

8. 主程序

if __name__ == "__main__":
    # 加载数据
    trainloader_1, trainloader_2, testloader = get_mnist_loaders()

    # 启动服务器
    server_process = mp.Process(target=start_flower_server)
    server_process.start()

    # 启动客户端
    client_process_1 = mp.Process(target=start_flower_client, args=(1, trainloader_1, testloader))
    client_process_2 = mp.Process(target=start_flower_client, args=(2, trainloader_2, testloader))
    client_process_1.start()
    client_process_2.start()

    # 等待所有进程完成
    client_process_1.join()
    client_process_2.join()

    # 终止服务器
    server_process
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐