Data Loading for Semi-Supervised Learning
Image classification datasets
Semi-supervised data loading: set the label of every sample that should be treated as unlabeled to -1, so the cross-entropy loss can be told to ignore label -1:
class_criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)
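To see the effect, here is a minimal, self-contained sketch (not part of the original post) showing that entries labeled -1 contribute nothing to the summed loss:

import torch
import torch.nn as nn

NO_LABEL = -1
criterion = nn.CrossEntropyLoss(reduction='sum', ignore_index=NO_LABEL)

logits = torch.randn(4, 10)                          # batch of 4 samples, 10 classes
targets = torch.tensor([3, NO_LABEL, 7, NO_LABEL])   # two samples are unlabeled

# Only the two labeled samples (indices 0 and 2) contribute to the sum;
# the entries equal to ignore_index are skipped entirely.
loss = criterion(logits, targets)
print(loss.item())

The full data-loading code follows.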
import itertools
from functools import reduce
from operator import __or__

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler
from torchvision import datasets
import torchvision.transforms as transforms

def load_data(path, args, NO_LABEL=-1):
    """Build semi-supervised train/eval loaders; unlabeled targets are set to NO_LABEL."""
    # Per-channel normalization statistics for each dataset.
    if args.dataset == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif args.dataset == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif args.dataset == 'svhn':
        mean = [x / 255 for x in [127.5, 127.5, 127.5]]
        std = [x / 255 for x in [127.5, 127.5, 127.5]]
    elif args.dataset == 'mnist':
        mean = (0.5,)
        std = (0.5,)
    elif args.dataset == 'stl10':
        assert False, 'stl10 support is not implemented yet'
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented yet'
    else:
        assert False, 'Unknown dataset: {}'.format(args.dataset)

    # Data augmentation. For the CIFAR datasets, TransformTwice returns two
    # augmented views of each image, as consistency-based methods expect.
    if args.dataset == 'svhn':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])
    elif args.dataset == 'mnist':
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])
    else:
        train_transform = TransformTwice(transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=2),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ]))
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    if args.dataset == 'cifar10':
        train_data = datasets.CIFAR10(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR10(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'cifar100':
        train_data = datasets.CIFAR100(path, train=True, transform=train_transform, download=True)
        test_data = datasets.CIFAR100(path, train=False, transform=test_transform, download=True)
        num_classes = 100
    elif args.dataset == 'svhn':
        train_data = datasets.SVHN(path, split='train', transform=train_transform, download=True)
        test_data = datasets.SVHN(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'mnist':
        train_data = datasets.MNIST(path, train=True, transform=train_transform, download=True)
        test_data = datasets.MNIST(path, train=False, transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'stl10':
        train_data = datasets.STL10(path, split='train', transform=train_transform, download=True)
        test_data = datasets.STL10(path, split='test', transform=test_transform, download=True)
        num_classes = 10
    elif args.dataset == 'imagenet':
        assert False, 'imagenet support is not implemented yet'
    else:
        assert False, 'Unsupported dataset: {}'.format(args.dataset)

    # Split the training set into labeled and unlabeled indices.
    labeled_idxs, unlabeled_idxs = spilt_l_u(args.dataset, train_data, args.num_labels)

    # Alternatively, a TwoStreamBatchSampler can guarantee a fixed number of
    # labeled samples per batch:
    # if args.labeled_batch_size:
    #     batch_sampler = TwoStreamBatchSampler(
    #         unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    # else:
    #     assert False, "labeled batch size {}".format(args.labeled_batch_size)

    # Mark the unlabeled samples with NO_LABEL so the loss can ignore them.
    if args.dataset == 'svhn':
        train_data.labels = np.array(train_data.labels)
        train_data.labels[unlabeled_idxs] = NO_LABEL
    else:
        train_data.targets = np.array(train_data.targets)
        train_data.targets[unlabeled_idxs] = NO_LABEL

    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)
    eval_loader = DataLoader(test_data,
                             batch_size=args.eval_batch_size,
                             shuffle=False,
                             num_workers=args.workers,
                             pin_memory=True,
                             drop_last=False)
    return train_loader, eval_loader

def spilt_l_u(dataset, train_data, num_labels, num_val=400, classes=10):
    """Split the training indices into a class-balanced labeled subset of size
    `num_labels` and the remaining unlabeled indices.

    `classes` must match the dataset (e.g. 100 for CIFAR-100); `num_val` is
    currently unused.
    """
    if dataset == 'mnist':
        labels = train_data.targets.numpy()
    elif dataset == 'svhn':
        labels = np.array(train_data.labels)
    else:
        labels = np.array(train_data.targets)
    n = int(num_labels / classes)  # labeled samples kept per class
    (indices,) = np.where(reduce(__or__, [labels == i for i in np.arange(classes)]))
    # Shuffle so the labeled subset is drawn uniformly at random within each class.
    np.random.shuffle(indices)
    indices_train = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[:n] for i in range(classes)])
    indices_unlabelled = np.hstack(
        [list(filter(lambda idx: labels[idx] == i, indices))[n:] for i in range(classes)])
    indices_train = torch.from_numpy(indices_train)
    indices_unlabelled = torch.from_numpy(indices_unlabelled)
    return indices_train, indices_unlabelled

class TransformTwice:
    """Apply the same transform pipeline twice, returning two augmented views of one image."""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2

class TwoStreamBatchSampler(Sampler):
    """Iterate two sets of indices to combine labeled and unlabeled data in each batch.

    An 'epoch' is one iteration through the primary indices.
    During the epoch, the secondary indices are iterated through
    as many times as needed.
    """
    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_iter = iterate_once(self.primary_indices)
        secondary_iter = iterate_eternally(self.secondary_indices)
        return (
            primary_batch + secondary_batch
            for (primary_batch, secondary_batch)
            in zip(grouper(primary_iter, self.primary_batch_size),
                   grouper(secondary_iter, self.secondary_batch_size))
        )

    def __len__(self):
        return len(self.primary_indices) // self.primary_batch_size

def iterate_once(iterable):
    """A single random permutation of the indices."""
    return np.random.permutation(iterable)


def iterate_eternally(indices):
    """An endless stream of reshuffled index permutations."""
    def infinite_shuffles():
        while True:
            yield np.random.permutation(indices)
    return itertools.chain.from_iterable(infinite_shuffles())


def grouper(iterable, n):
    """Collect data into fixed-length chunks, e.g. grouper('ABCDEFG', 3) --> ABC DEF."""
    args = [iter(iterable)] * n
    return zip(*args)
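
Finally, a hypothetical usage sketch. The fields on args (dataset, num_labels, batch_size, eval_batch_size, workers) and the data path are assumptions inferred from the code above, not an interface defined by the original post:

# Hypothetical example of calling load_data(); adjust paths and arguments to your setup.
from argparse import Namespace

args = Namespace(dataset='cifar10', num_labels=4000,
                 batch_size=128, eval_batch_size=256, workers=4)
train_loader, eval_loader = load_data('./data', args)

for (view1, view2), target in train_loader:
    # TransformTwice yields two augmented views of each image;
    # target == -1 marks the unlabeled samples, which the ignore_index loss skips.
    labeled_mask = target != -1
    break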
References
- https://blog.csdn.net/Z609834342/article/details/106863690
