Image Classification Datasets

Semi-supervised data loading: set the labels of the samples that should be treated as unlabeled to -1, so that the cross-entropy loss can be configured to ignore the label -1:

class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)
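A minimal runnable sketch of this masking behavior (the tensor values here are made up for illustration): any target equal to NO_LABEL contributes neither to the loss value nor to its gradient.

import torch
import torch.nn as nn

NO_LABEL = -1
class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

logits = torch.randn(4, 10)                         # 4 samples, 10 classes
targets = torch.tensor([3, NO_LABEL, 7, NO_LABEL])  # samples 1 and 3 are unlabeled
loss = class_criterion(logits, targets)             # only samples 0 and 2 contribute

The full data-loading code follows.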
from torchvision import datasets
import torchvision.transforms as transforms
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from functools import reduce
from operator import __or__
import itertools
import numpy as np
 
def load_data(path, args, NO_LABEL=-1):
    """Build the train/eval DataLoaders; the targets of unlabeled
    training samples are overwritten with NO_LABEL."""
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'mnist':
        mean = (0.5, )
        std = (0.5, )
    elif args.dataset == 'stl10':
        assert False, 'stl10 is not implemented yet'
    elif args.dataset == 'imagenet':
        assert False, 'imagenet is not implemented yet'
    else:
        assert False, "Unknown dataset : {}".format(args.dataset)
 
    if args.dataset == 'svhn':
        train_transform = transforms.Compose([
             transforms.RandomCrop(32, padding=2),
             transforms.ToTensor(),
             transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean, std)
        ])
    elif args.dataset == 'mnist':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    else:
        # TransformTwice produces two independently augmented views of the
        # same image, which consistency-regularization methods (e.g. the
        # Mean Teacher model) compare against each other.
        train_transform = TransformTwice(transforms.Compose([
             transforms.RandomHorizontalFlip(),
             transforms.RandomCrop(32, padding=2),
             transforms.ToTensor(),
             transforms.Normalize(mean, std)
        ]))
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
 
    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = datasets.CIFAR100(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = datasets.SVHN(path, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'mnist':
        train_data = datasets.MNIST(path, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(path, train=False, transform=test_transform, download=True)
        num_classes = 10
 
    elif args.dataset == 'stl10':
        train_data = datasets.STL10(path, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet is not implemented yet'
    else:
        assert False, 'Unsupported dataset : {}'.format(args.dataset)
 
 
    labeled_idxs, unlabeled_idxs = split_l_u(args.dataset, train_data, args.num_labels, classes=num_classes)
 
    # Alternative: draw mixed labeled/unlabeled batches with the
    # TwoStreamBatchSampler defined below:
    # if args.labeled_batch_size:
    #     batch_sampler = TwoStreamBatchSampler(
    #         unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    # else:
    #     assert False, "labeled batch size {}".format(args.labeled_batch_size)
 
    # Overwrite the targets of the unlabeled split with NO_LABEL so that
    # CrossEntropyLoss(ignore_index=NO_LABEL) skips them.
    if args.dataset == 'svhn':
        train_data.labels = np.array(train_data.labels)
        train_data.labels[unlabeled_idxs] = NO_LABEL
    else:
        train_data.targets = np.array(train_data.targets)
        train_data.targets[unlabeled_idxs] = NO_LABEL
 
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)
 
    eval_loader = DataLoader(
        test_data,
        batch_size=args.eval_batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False)
 
    return train_loader, eval_loader
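
A hypothetical invocation (the args fields here simply mirror what load_data reads; SimpleNamespace stands in for an argparse result):

from types import SimpleNamespace

args = SimpleNamespace(dataset='cifar10', num_labels=4000,
                       batch_size=128, eval_batch_size=256, workers=4)
train_loader, eval_loader = load_data('./data', args)

# For CIFAR, TransformTwice makes each batch yield two augmented views.
(inputs1, inputs2), targets = next(iter(train_loader))
print(inputs1.shape, (targets == -1).sum())  # most targets are -1 (unlabeled)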
 
def split_l_u(dataset, train_data, num_labels, classes=10):
    """Split the training indices into a class-balanced labeled subset of
    num_labels samples and an unlabeled subset holding the rest."""
    if dataset == 'mnist':
        labels = train_data.targets.numpy()
    elif dataset == 'svhn':
        labels = train_data.labels
    else:
        # CIFAR stores its targets as a plain Python list
        labels = np.array(train_data.targets)

    n = num_labels // classes  # labeled samples per class
    (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(classes)]))
    # Shuffle so that the per-class picks below are random
    np.random.shuffle(indices)

    # For each class, take the first n shuffled indices as labeled data
    # and leave the remainder unlabeled
    indices_train = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(classes)])
    indices_unlabelled = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[n:] for i in range(classes)])

    indices_train = torch.from_numpy(indices_train)
    indices_unlabelled = torch.from_numpy(indices_unlabelled)

    return indices_train, indices_unlabelled
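
A quick sanity check of the split (assuming train_data is the CIFAR-10 training set from load_data above):

labeled_idxs, unlabeled_idxs = split_l_u('cifar10', train_data, num_labels=4000)
assert len(labeled_idxs) == 4000                                   # 400 per class
assert len(labeled_idxs) + len(unlabeled_idxs) == len(train_data)  # nothing lost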
 
class TransformTwice:
    """Apply the same stochastic transform twice to one input, returning two views."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2
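
As a quick illustration (img here stands for any PIL image), every call re-samples the random augmentation, so the two returned views differ:

aug = TransformTwice(transforms.Compose([
    transforms.RandomCrop(32, padding=2),
    transforms.ToTensor(),
]))
view1, view2 = aug(img)  # two independently cropped tensors of the same image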
 
 
class TwoStreamBatchSampler(Sampler):
    """
    Labeled + unlabeled data in a batch
    Iterate two sets of indices
    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
 
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0
 
    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )
 
    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size
 
def iterate_once(iterable):
    return np.random.permutation(iterable)
 
 
def iterate_eternally(indices):
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())
 
 
def grouper(iterable, n):
    """Collect data into fixed-length chunks; leftover items that do not
    fill a complete chunk are dropped (zip truncates)."""
    # grouper('ABCDEFG', 3) --> ABC DEF
    args = [iter(iterable)] * n
    return zip(*args)
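
A sketch of wiring the sampler into a DataLoader, mirroring the commented-out branch in load_data (the batch sizes here are illustrative):

batch_sampler = TwoStreamBatchSampler(
    unlabeled_idxs, labeled_idxs,  # primary = unlabeled, secondary = labeled
    batch_size=128, secondary_batch_size=32)
train_loader = DataLoader(train_data,
                          batch_sampler=batch_sampler,
                          num_workers=4,
                          pin_memory=True)
# Every batch holds 96 unlabeled and 32 labeled samples; one epoch makes a
# single pass over the unlabeled indices, recycling the labeled ones.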
