联邦学习开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》代码
【代码】联邦学习开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》代码。
·
1 设置初值
# Record the start time so the total run time can be reported at the end.
start_time = time.time()

# Paths: absolute path of the parent directory, and a summary writer that
# logs under ../logs.
path_project = os.path.abspath('..')
logger = SummaryWriter('../logs')

# Parse the run parameters and pass them to exp_details (defined elsewhere;
# presumably prints the experiment configuration -- verify).
args = args_parser()
exp_details(args)

# Select the compute device.
# NOTE(review): the CUDA device is selected from `args.gpu_id`, while the
# device string below is keyed on `args.gpu` -- confirm that args_parser
# defines both options and that this split is intentional.
if args.gpu_id:
    torch.cuda.set_device(args.gpu_id)
device = 'cuda' if args.gpu else 'cpu'

# Load the train/test datasets and the user groups; `user_groups` is indexed
# by user id later in the script to pick each client's sample indices.
train_dataset, test_dataset, user_groups = get_dataset(args)
2 建立模型——CNN/MLP
# BUILD MODEL
# Pick the global model class from args.model (and, for CNNs, args.dataset).
if args.model == 'cnn':
    # Convolutional neural network: one architecture per supported dataset.
    if args.dataset == 'mnist':
        global_model = CNNMnist(args=args)
    elif args.dataset == 'fmnist':
        global_model = CNNFashion_Mnist(args=args)
    elif args.dataset == 'cifar':
        global_model = CNNCifar(args=args)
    else:
        # Without this branch `global_model` would stay undefined and the
        # script would fail later with a NameError instead of a clear message.
        exit('Error: unrecognized dataset')
elif args.model == 'mlp':
    # Multi-layer perceptron: the input width is the product of all
    # dimensions of the first training sample's shape.
    img_size = train_dataset[0][0].shape
    len_in = 1
    for x in img_size:
        len_in *= x
    global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
else:
    exit('Error: unrecognized model')
3 设置要训练的模型,并传递给device
# Send the global model to the selected device, switch it to training mode
# and print it. (The model classes are defined elsewhere; presumably
# torch.nn.Module subclasses -- verify.)
global_model.to(device)
global_model.train()
print(global_model)
# Keep the model's current state dict as the initial global weights.
global_weights = global_model.state_dict()
4 训练参数
# Training bookkeeping.
# History lists: the training loop appends to train_loss and train_accuracy
# once per global round.
train_loss = []
train_accuracy = []
# Additional (initially empty) containers for validation / model tracking.
val_acc_list = []
net_list = []
cv_loss = []
cv_acc = []
# Report aggregate statistics every `print_every` global rounds.
print_every = 2
# Scalar trackers, both starting at zero.
val_loss_pre = 0
counter = 0
5 训练过程——主循环:for epoch in tqdm(range(args.epochs)):
5.1 局部训练
# Body of the per-round loop (`for epoch in tqdm(range(args.epochs)):`).
# Reset the per-round collections of client weights and losses.
local_weights, local_losses = [], []
# Announce the current global round (1-based).
print(f'\n | Global Training Round : {epoch+1} |\n')
global_model.train()
# m: number of clients selected this round (at least one).
m = max(int(args.frac * args.num_users), 1)
# Indices of the selected clients, sampled without replacement.
idxs_users = np.random.choice(range(args.num_users), m, replace=False)
# Run the local update on every selected client.
for idx in idxs_users:
    # Local updater restricted to this client's sample indices.
    local_model = LocalUpdate(args=args, dataset=train_dataset,
                              idxs=user_groups[idx], logger=logger)
    # Update a deep copy of the global model so the global model itself is
    # not modified by the client's local training.
    w, loss = local_model.update_weights(
        model=copy.deepcopy(global_model), global_round=epoch)
    # Record this client's resulting weights and loss.
    local_weights.append(copy.deepcopy(w))
    local_losses.append(copy.deepcopy(loss))
5.2 全局更新
# Combine the collected client weights into the new global weights using
# average_weights (defined elsewhere; presumably an element-wise mean of the
# state dicts -- verify).
global_weights = average_weights(local_weights)
# Load the aggregated weights into the global model.
global_model.load_state_dict(global_weights)
# Mean of the local losses collected this round, appended to the loss history.
loss_avg = sum(local_losses) / len(local_losses)
train_loss.append(loss_avg)
5.3 计算每个epoch结果
# Calculate avg training accuracy over all users at every epoch
list_acc, list_loss = [], []
global_model.eval()
for c in range(args.num_users):
    # Evaluate the global model on user `c`'s own samples. The loop variable
    # must be used here: indexing with the stale `idx` left over from the
    # local-training loop would evaluate the same client on every iteration.
    local_model = LocalUpdate(args=args, dataset=train_dataset,
                              idxs=user_groups[c], logger=logger)
    acc, loss = local_model.inference(model=global_model)
    list_acc.append(acc)
    list_loss.append(loss)
# Per-round training accuracy: mean of the per-user accuracies.
train_accuracy.append(sum(list_acc)/len(list_acc))

# print global training loss after every 'i' rounds
if (epoch+1) % print_every == 0:
    print(f' \nAvg Training Stats after {epoch+1} global rounds:')
    # Mean over the whole loss history so far, not just the latest round.
    print(f'Training Loss : {np.mean(np.array(train_loss))}')
    print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))
6 测试模型
# Test inference after completion of training
test_acc, test_loss = test_inference(args, global_model, test_dataset)

print(f' \n Results after {args.epochs} global rounds of training:')
print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

# Saving the objects train_loss and train_accuracy:
# The file name encodes the dataset, model, number of rounds, client
# fraction, iid flag, local epochs and local batch size. The call is wrapped
# in parentheses instead of a backslash continuation, which is a syntax error
# once the line break after it is lost.
file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.format(
    args.dataset, args.model, args.epochs, args.frac, args.iid,
    args.local_ep, args.local_bs)

with open(file_name, 'wb') as f:
    pickle.dump([train_loss, train_accuracy], f)

print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)