使用flower进行联邦学习 CNN模型 MNIST数据集
联邦学习 cnn模型
·
目录
1. 导入依赖库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import flwr as fl
import multiprocessing as mp
from torch.utils.data import DataLoader, Subset
import numpy as np
from torchvision import datasets
# 如需使用,联系Q 596520206
torch
,torch.nn
,torch.optim
: PyTorch 的核心库,用于创建神经网络、定义优化器和计算损失。torchvision
: 用于加载常用的计算机视觉数据集(如 MNIST、CIFAR-10)以及预处理图像。flwr
(Flower): 一个用于联邦学习的框架,提供了客户端和服务器的通信接口。multiprocessing
: 用于在多个进程中并行启动 Flower 客户端和服务器。DataLoader
,Subset
: 用于加载数据集和切割训练集为不同客户端使用。numpy
: 用于数值计算。
2. 设备设置和超参数定义
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
EPOCHS = 3
LEARNING_RATE = 0.001
DEVICE
: 判断当前是否有可用的 GPU 设备,如果有则使用 GPU,否则使用 CPU。BATCH_SIZE
: 每次训练的批大小。EPOCHS
: 训练的轮数。LEARNING_RATE
: 优化器的学习率。
3. CNN模型定义
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10) # 输出为10个类别
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7) # 展平
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
SimpleCNN
是一个简单的卷积神经网络(CNN),由两层卷积层、池化层和两层全连接层构成。conv1
和conv2
:卷积层,conv1
的输入为单通道(灰度图像),输出为32个特征图;conv2
的输入为32个特征图,输出为64个特征图。pool
:最大池化层,用于降采样。fc1
和fc2
:全连接层,将特征图展平后通过全连接层输出10个类别。
4. 联邦学习客户端定义
class CNNClient(fl.client.NumPyClient):
def __init__(self, model, train_data, val_data):
self.model = model
self.train_data = train_data
self.val_data = val_data
self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
self.criterion = nn.CrossEntropyLoss()
CNNClient
继承自fl.client.NumPyClient
,这是 Flower 框架提供的接口,代表了一个联邦学习的客户端。train_data
和val_data
:训练数据和验证数据。optimizer
:使用 Adam 优化器进行训练。criterion
:损失函数使用交叉熵损失函数(适用于分类问题)。
4.1 get_parameters
和 set_parameters
方法
def get_parameters(self, config=None):
return [val.cpu().numpy() for val in self.model.state_dict().values()]
def set_parameters(self, parameters):
state_dict = dict(zip(self.model.state_dict().keys(), parameters))
self.model.load_state_dict({k: torch.tensor(v) for k, v in state_dict.items()})
get_parameters
: 获取模型的参数,将其转换为 NumPy 数组返回。set_parameters
: 设置模型的参数,从传入的 NumPy 数组恢复模型权重。
4.2 fit
和 evaluate
方法
def fit(self, parameters, config):
self.set_parameters(parameters)
self.model.train()
total_loss = 0.0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(self.train_data):
data, target = data.to(DEVICE), target.to(DEVICE)
self.optimizer.zero_grad()
output = self.model(data)
loss = self.criterion(output, target)
loss.backward()
self.optimizer.step()
total_loss += loss.item()
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
avg_loss = total_loss / len(self.train_data)
accuracy = correct / total
print(f"Training loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
return self.get_parameters(), len(self.train_data.dataset), {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
self.model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in self.val_data:
data, target = data.to(DEVICE), target.to(DEVICE)
output = self.model(data)
_, predicted = torch.max(output.data, 1)
correct += (predicted == target).sum().item()
total += target.size(0)
accuracy = correct / total
print(f"Evaluation Accuracy: {accuracy:.4f}")
return accuracy, len(self.val_data.dataset), {}
fit
: 在客户端上进行本地训练,更新模型权重,计算并返回损失和准确率。evaluate
: 在客户端上进行模型评估,计算并返回准确率。
5. 数据加载和划分
def get_mnist_loaders():
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
file_path = r"E:\t5\mnist"
trainset = datasets.MNIST(root=file_path, train=True, download=True, transform=transform)
testset = datasets.MNIST(root=file_path, train=False, download=True, transform=transform)
# 将训练集划分为两个部分,模拟两个客户端
trainloader_1 = DataLoader(Subset(trainset, range(0, len(trainset)//2)), batch_size=BATCH_SIZE, shuffle=True)
trainloader_2 = DataLoader(Subset(trainset, range(len(trainset)//2, len(trainset))), batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)
return trainloader_1, trainloader_2, testloader
get_mnist_loaders
: 加载 MNIST 数据集,并将训练集划分为两个子集,模拟两个客户端。
6. 启动 Flower 服务器
def start_flower_server():
strategy = fl.server.strategy.FedAvg()
config = fl.server.ServerConfig(num_rounds=3)
print("Starting Flower server...")
fl.server.start_server(server_address="127.0.0.1:8080", strategy=strategy, config=config)
print("Flower server started successfully.")
- 启动 Flower 服务器并使用
FedAvg
策略进行参数聚合。 - 服务器配置为 3 轮训练。
7. 启动 Flower 客户端
def start_flower_client(client_id, train_data, val_data):
model = SimpleCNN().to(DEVICE)
client = CNNClient(model, train_data, val_data)
fl.client.start_client(server_address="127.0.0.1:8080", client=client.to_client())
- 启动 Flower 客户端,客户端会连接到 Flower 服务器并开始训练和评估。
8. 主程序
if __name__ == "__main__":
# 加载数据
trainloader_1, trainloader_2, testloader = get_mnist_loaders()
# 启动服务器
server_process = mp.Process(target=start_flower_server)
server_process.start()
# 启动客户端
client_process_1 = mp.Process(target=start_flower_client, args=(1, trainloader_1, testloader))
client_process_2 = mp.Process(target=start_flower_client, args=(2, trainloader_2, testloader))
client_process_1.start()
client_process_2.start()
# 等待所有进程完成
client_process_1.join()
client_process_2.join()
# 终止服务器
server_process

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)