本文代码参照肖智清老师的《神经网络与PyTorch实战》一书。

import torch
import torch.utils.data
import torch.nn
import torch.optim
import torchvision.datasets
import torchvision.transforms

# MNIST train/test splits, stored under ./data/mnist and downloaded on first
# use. ToTensor turns each image into a tensor the network can consume.
train_dataset = torchvision.datasets.MNIST(
    root='./data/mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data/mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

batch_size = 100
# Reshuffle the training samples at every epoch so the optimizer does not see
# the exact same sequence of mini-batches each time.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Evaluation only accumulates totals, so the default sequential order is fine.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
)

class Net(torch.nn.Module):
    """Small convolutional classifier that outputs 10 scores per image.

    ``conv1`` applies two 3x3 convolutions with padding 1 (1 -> 64 -> 128
    channels), each followed by ReLU, then a 2x2 max-pool with stride 2.
    ``dense`` maps the flattened features to 1024 hidden units (ReLU,
    dropout with p=0.5) and finally to 10 outputs.

    The first linear layer is sized for ``128 * 14 * 14`` features, which is
    what ``conv1`` produces for single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(stride=2, kernel_size=2)
        )
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(128 * 14 * 14, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10)
        )

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: Input batch of shape ``(batch, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)`` holding unnormalized scores.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield the
                ``128 * 14 * 14`` features the dense head expects.
        """
        x = self.conv1(x)
        # Flatten everything except the batch dimension. Reshaping with
        # ``view(-1, 128 * 14 * 14)`` would fold surplus features of a
        # wrongly sized input into extra batch rows instead of failing.
        x = torch.flatten(x, start_dim=1)
        return self.dense(x)


# Model, loss function and optimizer (Adam with its default settings).
net = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())
num_epochs = 5

# Train for num_epochs passes over train_loader, one optimizer step per batch.
for epoch in range(num_epochs):
    for idx, (inputs, targets) in enumerate(train_loader):
        # Forward pass and loss for the current mini-batch.
        logits = net(inputs)
        loss = criterion(logits, targets)

        # Clear gradients left from the previous step, then backpropagate
        # and update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress only on every 100th batch.
        if idx % 100:
            continue
        print(f'epoch {epoch}, batch{idx}, loss = {loss.item():g}')

# Measure classification accuracy over test_loader.
# Switch to evaluation mode first: the network contains a Dropout layer, which
# would otherwise keep randomly zeroing activations and distort the result.
net.eval()
correct = 0
total = 0
# No gradients are needed for evaluation; disabling tracking saves memory
# and time.
with torch.no_grad():
    for images, labels in test_loader:
        preds = net(images)
        # The predicted class is the index of the highest score in each row.
        predicted = torch.argmax(preds, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'The accuracy in test dataset is {accuracy:.1%}')

运行结果:模型在测试集上的正确率为98.8%。
在这里插入图片描述

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐