统计数据集的标准差和方差
话不多说,直接干货,计算一个自定义数据集的三个通道的标准差和方差,这里使用了两个文件,getStat.py和ImageDataset.py,还有两个数据集的文件(夹),一个文件夹下面放所有的数据(图片),另一个记录图片和类train.csv,这里放了整个数据集的标签信息,并不仅仅是训练集的(其实这个我直接拿我的来用的–把所有数据的名称标签都放进去了,然后通过ImageDataset.py取出,感兴
·
话不多说,直接干货,计算一个自定义数据集的三个通道的标准差和方差,这里使用了两个文件,getStat.py和ImageDataset.py,还有两个数据集的文件(夹),一个文件夹下面放所有的数据(图片),另一个记录图片和类train.csv,这里放了整个数据集的标签信息,并不仅仅是训练集的(其实这个我直接拿我的来用的–把所有数据的名称标签都放进去了,然后通过ImageDataset.py取出,感兴趣的小伙伴可以自己再研究,其实可以不用标签)
下面是目录结构:
下面是.csv文件数据 前面是文件名后面是标签(可以是任意的标签形式,不必要是数字)
getStat.py代码如下:
# -*- coding: utf-8 -*-
# @Time : 2020/11/8 7:40 下午
# @Author : ligang
# @FileName: getStat.py
# @Email : ucasligang@163.com
# @Software: PyCharm
import torch
from ImageDataset import ImageDataset
def getStat(train_data):
'''
Compute mean and variance for training data
:param train_data: 自定义类Dataset(或ImageFolder即可)
:return: (mean, std)
'''
print('Compute mean and variance for training data.')
print(len(train_data))
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _ in train_loader:
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_dataset = ImageDataset('', 'data')
print(getStat(train_dataset))
ImageDataSet.py文件如下:
import os.path as osp
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class ImageDataset(Dataset):
def __init__(self, ROOT_PATH, setname):
csv_path = osp.join(ROOT_PATH, setname + '.csv')
lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
data = []
label = []
lb = -1
self.wnids = []
for l in lines:
name, wnid = l.split(',')
path = osp.join(ROOT_PATH, 'images', name)
if wnid not in self.wnids:
self.wnids.append(wnid)
lb += 1
data.append(path)
label.append(lb)
self.data = data
self.label = label
self.transform = transforms.Compose([
transforms.ToTensor(),
])
def __len__(self):
return len(self.data)
def __getitem__(self, i):
path, label = self.data[i], self.label[i]
image = self.transform(Image.open(path).convert('RGB'))
return image, label

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)