CNN实现MNIST数据集分类
参照肖智清老师的《神经网络与PyTorch实战》,使用 PyTorch 实现卷积神经网络(CNN),对 MNIST 手写数字数据集进行分类。
·
参照肖智清老师的《神经网络与PyTorch实战》。
import torch
import torch.utils.data
import torch.nn
import torch.optim
import torchvision.datasets
import torchvision.transforms
# MNIST train/test splits, downloaded to ./data/mnist on first run.
# ToTensor() turns each image into a float tensor so it can be fed to the CNN.
train_dataset = torchvision.datasets.MNIST(
    root='./data/mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_dataset = torchvision.datasets.MNIST(
    root='./data/mnist',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

batch_size = 100
# Reshuffle the training set every epoch so mini-batches are not presented in
# the same fixed order each time. Test order does not affect accuracy, so the
# test loader keeps the default sequential order.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=batch_size
)
class Net(torch.nn.Module):
    """Small CNN classifier for single-channel 28x28 images producing 10 logits.

    Feature extractor: two 3x3 convolutions (1 -> 64 -> 128 channels, padding=1
    so the 28x28 spatial size is preserved) with ReLU, then a 2x2 max-pool with
    stride 2 that halves the spatial size to 14x14.
    Classifier head: Linear(128*14*14 -> 1024) + ReLU + Dropout(0.5) +
    Linear(1024 -> 10). No softmax is applied; the outputs are raw scores.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(stride=2, kernel_size=2),
        )
        # Input size = 128 channels * 14 * 14 positions left after pooling a
        # 28x28 image.
        self.dense = torch.nn.Sequential(
            torch.nn.Linear(128 * 14 * 14, 1024),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(1024, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10).

        x: float tensor of shape (batch, 1, 28, 28); the channel count is fixed
        by the first Conv2d and the 28x28 size by the first Linear layer.
        """
        x = self.conv1(x)
        # Flatten each sample while keeping the batch dimension. Unlike
        # view(-1, 128*14*14), this never regroups features into a different
        # number of rows, so a wrong input size raises in the Linear layer
        # instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x
# Model, loss and optimizer. CrossEntropyLoss expects raw class scores, which
# is what Net returns (its last layer is Linear, no softmax).
net = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

num_epochs = 5
for epoch in range(num_epochs):
    # Training mode: Dropout in Net.dense is active.
    net.train()
    for idx, (images, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        preds = net(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()
        # Report the current mini-batch loss every 100 batches.
        if idx % 100 == 0:
            print(f'epoch {epoch}, batch{idx}, loss = {loss.item():g}')

# Evaluate on the test set. eval() disables Dropout so predictions use the full
# network deterministically; no_grad() avoids building the autograd graph.
net.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        preds = net(images)
        # Predicted class = index of the highest score per sample.
        predicted = torch.argmax(preds, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'The accuracy in test dataset is {accuracy:.1%}')
训练 5 个 epoch 后,模型在测试集上的正确率为 98.8%。

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)