数据集说明

MNIST数据集

MNIST手写数字数据集是基础的数字分类数据集,MNIST的训练集中包含了60000个数字,测试集中包含10000个数字。每张图都是灰度图,尺寸都为固定的28*28。
在这里插入图片描述

SVHN数据集

SVHN 是一个从谷歌街景图像中的门牌号获得的真实世界的图像数据集,用于开发机器学习和对象识别算法。

包含73257 个用于训练的数字,26032 个用于测试的数字,以及 531131 个额外的、难度稍低的样本,用作额外的训练数据。

有两种格式:

  1. 具有字符级边界框的原始图像。
  2. 以单个字符为中心的类似 MNIST 的 32×32 图像(许多图像的侧面包含一些干扰物)

格式一:

格式二:

通过下面的方式下载导入的是第二种格式的图片,若要自行下载数据集可到官网:http://ufldl.stanford.edu/housenumbers/

配对

利用torch.utils.data中的Dataset可以灵活的实现MNIST和SVHN的随机配对:

from torch.utils.data import Dataset
from torchvision import transforms
from torchvison.datasets import MNIST, SVHN
import random


class MNIST_SVHN(Dataset):
    data_transform = {
        'mnist': transforms.Compose([
            transforms.Resize([32, 32]),
            transforms.ToTensor(),
        ]),
        'svhn': transforms.ToTensor()
    }

    def __init__(self,
                 data_path: str,
                 split: str
                 ):
        if split == 'train':
            self.mnist = MNIST(data_path, train=True, download=False,
                               transform=self.data_transform['mnist'])
            self.svhn = SVHN(data_path, split='train', download=False,
                             transform=self.data_transform['svhn'])
        elif split == 'test':
            self.mnist = MNIST(data_path, train=False, download=False,
                               transform=self.data_transform['mnist'])
            self.svhn = SVHN(data_path, split='test', download=False,
                             transform=self.data_transform['svhn'])

    def __len__(self):
        return len(self.svhn)

    def __getitem__(self, idx):
        svhn_data, svhn_label = self.svhn[idx]
        mnist_l, mnist_l_idx = self.mnist.targets.sort()
        cor_label_list = mnist_l_idx[mnist_l == svhn_label]
        len_list = len(cor_label_list)
        random_idx = random.randrange(len_list)
        cor_minst_idx = cor_label_list[random_idx]
        mnist_data = self.mnist[cor_minst_idx][0]

        return [mnist_data, svhn_data], svhn_label
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐