前言

以下是一个使用MATLAB进行深度学习网络训练的完整示例,包含数据准备、模型构建、训练和评估等关键步骤:


1. 数据准备与加载

% 下载并准备示例数据集(手写数字识别)
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imageDatastore = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true, 'LabelSource','foldernames');

% 划分训练集和测试集
[imdsTrain,imdsTest] = splitEachLabel(imageDatastore,0.7,'randomized');

2. 定义深度学习网络架构

% 创建卷积神经网络 (CNN)
layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    maxPooling2dLayer(2,'Stride',2)
    
    convolution2dLayer(3,64,'Padding','same')
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% 设置训练选项
options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'MaxEpochs',10, ...
    'MiniBatchSize',128, ...
    'Shuffle','every-epoch', ...
    'ValidationData',imdsTest, ...
    'ValidationFrequency',30, ...
    'Verbose',false, ...
    'Plots','training-progress');

3. 训练网络

% 训练网络
net = trainNetwork(imdsTrain,layers,options);

4. 评估模型性能

% 在测试集上进行预测
YPred = classify(net,imdsTest);
YTest = imdsTest.Labels;

% 计算分类准确率
accuracy = mean(YPred == YTest);
fprintf('分类准确率: %.2f%%\n', accuracy*100);

% 创建混淆矩阵
cm = confusionmat(YTest, YPred);
figure
imagesc(cm)
colorbar
title('混淆矩阵')
xlabel('预测标签')
ylabel('真实标签')

% 显示一些预测结果示例
figure
numImages = 4;
for i = 1:numImages
    subplot(2,2,i)
    I = readimage(imdsTest,i);
    imshow(I)
    title(sprintf('预测: %d', YPred(i)))
end

5. 模型保存与加载

% 保存训练好的模型
save('digitClassifier.mat', 'net');

% 加载模型(在其他脚本中使用)
loadedNet = load('digitClassifier.mat');
net = loadedNet.net;

代码说明:

  • 数据准备:使用MATLAB内置的手写数字数据集,自动划分训练集和测试集
  • 网络架构:设计了一个包含卷积层、池化层和全连接层的典型CNN结构
  • 训练配置:使用随机梯度下降法(SGDM)优化器,设置学习率、批次大小和训练轮数
  • 模型评估:通过准确率和混淆矩阵评估模型性能,并可视化预测结果
  • 模型部署:提供了保存和加载模型的方法,方便后续使用

这个示例展示了MATLAB深度学习工具箱的基本用法,你可以根据自己的需求修改网络结构、训练参数或替换为自定义数据集。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐