话不多说,直接干货,计算一个自定义数据集的三个通道的标准差和方差,这里使用了两个文件,getStat.py和ImageDataset.py,还有两个数据集的文件(夹),一个文件夹下面放所有的数据(图片),另一个记录图片和类train.csv,这里放了整个数据集的标签信息,并不仅仅是训练集的(其实这个我直接拿我的来用的–把所有数据的名称标签都放进去了,然后通过ImageDataset.py取出,感兴趣的小伙伴可以自己再研究,其实可以不用标签)
下面是目录结构:
在这里插入图片描述
下面是.csv文件数据 前面是文件名后面是标签(可以是任意的标签形式,不必要是数字)
在这里插入图片描述

getStat.py代码如下:

# -*- coding: utf-8 -*-
# @Time : 2020/11/8 7:40 下午
# @Author : ligang
# @FileName: getStat.py
# @Email   : ucasligang@163.com
# @Software: PyCharm
import torch

from ImageDataset import ImageDataset


def getStat(train_data):
    '''
    Compute mean and variance for training data
    :param train_data: 自定义类Dataset(或ImageFolder即可)
    :return: (mean, std)
    '''
    print('Compute mean and variance for training data.')
    print(len(train_data))
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())


if __name__ == '__main__':

    train_dataset = ImageDataset('', 'data')
    print(getStat(train_dataset))


ImageDataSet.py文件如下:

import os.path as osp
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms

class ImageDataset(Dataset):

    def __init__(self, ROOT_PATH, setname):
        csv_path = osp.join(ROOT_PATH, setname + '.csv')
        lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]
        data = []
        label = []
        lb = -1

        self.wnids = []

        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(ROOT_PATH, 'images', name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)

        self.data = data
        self.label = label

        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label


Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐