Processing Image Data and Building Your Own Dataset with transforms, Dataset, and DataLoader
1. torchvision.transforms
In computer vision (CV) tasks, this module is used to preprocess images and to apply data augmentation.
1.1 Transforms on Image
import torchvision.transforms as transforms
from PIL import Image
img = Image.open('lena.png')
img = img.convert("RGB")
img

width, height = img.size
print(width, height)
132 193
1.1.1 transforms.Resize
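Resizes the given image to the given size. Passing a tuple (h, w) sets both sides exactly, so the aspect ratio may change.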
size = (100, 100)
transform = transforms.Resize(size=size)
resize_img = transform(img)
resize_img

1.1.2 transforms.CenterCrop
Crops the image around its center.
size = (100, 100)
transform = transforms.CenterCrop(size=size)
centercrop_img = transform(img)
centercrop_img

1.1.3 transforms.RandomCrop
Crops the image at a randomly chosen location.
size = (100, 100)
transform = transforms.RandomCrop(size=size)
randomcrop_img = transform(img)
randomcrop_img

1.1.4 transforms.RandomHorizontalFlip
Flips the given image horizontally with probability p.
transform = transforms.RandomHorizontalFlip(p=0.5)
rpf_img = transform(img)
rpf_img

1.1.5 transforms.RandomVerticalFlip
Flips the given image vertically with probability p.
transform = transforms.RandomVerticalFlip(p=0.5)
rvf_img = transform(img)
rvf_img

1.1.6 transforms.ColorJitter
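Randomly changes the brightness, contrast, saturation, and hue of an image; commonly used for data augmentation. Each argument below is a (min, max) range from which the adjustment factor is sampled.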
brightness = (1, 10)
contrast = (1, 10)
saturation = (1, 10)
hue = (0.2, 0.4)
transform = transforms.ColorJitter(brightness, contrast, saturation, hue)
colorjitter_img = transform(img)
colorjitter_img

1.1.7 transforms.Grayscale
Converts the image to grayscale.
transform = transforms.Grayscale()
gray_img = transform(img)
gray_img

1.1.8 transforms.RandomGrayscale
Converts the image to grayscale with probability p.
transform = transforms.RandomGrayscale(p=0.5)
rg_img = transform(img)
rg_img

1.2 Transforms on Tensor
1.2.1 transforms.ToTensor()
Converts a PIL Image to a Tensor of shape (C, H, W), scaling pixel values from [0, 255] to [0.0, 1.0].
transform = transforms.ToTensor()
tensor_img = transform(img)
tensor_img
tensor([[[0.7176, 0.7294, 0.7255, ..., 0.6627, 0.6549, 0.6627],
[0.7137, 0.7176, 0.7176, ..., 0.6510, 0.6510, 0.6549],
[0.7137, 0.7176, 0.7137, ..., 0.6392, 0.6431, 0.6353],
...,
[0.9922, 1.0000, 0.9725, ..., 0.6863, 0.6902, 0.7059],
[1.0000, 1.0000, 0.9961, ..., 0.6745, 0.6824, 0.6902],
[1.0000, 0.9961, 0.9882, ..., 0.6745, 0.6745, 0.6863]],
[[0.3843, 0.3922, 0.3922, ..., 0.3529, 0.3451, 0.3529],
[0.3765, 0.3804, 0.3804, ..., 0.3412, 0.3412, 0.3412],
[0.3765, 0.3804, 0.3804, ..., 0.3294, 0.3412, 0.3333],
...,
[0.8745, 0.8941, 0.8863, ..., 0.3294, 0.3490, 0.3647],
[0.9098, 0.9176, 0.9176, ..., 0.3216, 0.3373, 0.3490],
[0.9294, 0.9255, 0.9255, ..., 0.3216, 0.3294, 0.3412]],
[[0.2745, 0.2863, 0.2784, ..., 0.2353, 0.2235, 0.2353],
[0.2784, 0.2745, 0.2745, ..., 0.2353, 0.2353, 0.2314],
[0.2784, 0.2745, 0.2706, ..., 0.2275, 0.2392, 0.2353],
...,
[0.8706, 0.8824, 0.8627, ..., 0.2510, 0.2706, 0.2863],
[0.9216, 0.9176, 0.9059, ..., 0.2392, 0.2588, 0.2706],
[0.9451, 0.9333, 0.9255, ..., 0.2392, 0.2510, 0.2588]]])
1.2.2 transforms.Normalize
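Normalizes each channel of a tensor image with the given mean and standard deviation:
output[channel] = (input[channel] - mean[channel]) / std[channel]
With mean 0.5 and std 0.5, values in [0, 1] are mapped to [-1, 1].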
transform = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img_normal = transform(tensor_img)
img_normal
tensor([[[ 0.4353, 0.4588, 0.4510, ..., 0.3255, 0.3098, 0.3255],
[ 0.4275, 0.4353, 0.4353, ..., 0.3020, 0.3020, 0.3098],
[ 0.4275, 0.4353, 0.4275, ..., 0.2784, 0.2863, 0.2706],
...,
[ 0.9843, 1.0000, 0.9451, ..., 0.3725, 0.3804, 0.4118],
[ 1.0000, 1.0000, 0.9922, ..., 0.3490, 0.3647, 0.3804],
[ 1.0000, 0.9922, 0.9765, ..., 0.3490, 0.3490, 0.3725]],
[[-0.2314, -0.2157, -0.2157, ..., -0.2941, -0.3098, -0.2941],
[-0.2471, -0.2392, -0.2392, ..., -0.3176, -0.3176, -0.3176],
[-0.2471, -0.2392, -0.2392, ..., -0.3412, -0.3176, -0.3333],
...,
[ 0.7490, 0.7882, 0.7725, ..., -0.3412, -0.3020, -0.2706],
[ 0.8196, 0.8353, 0.8353, ..., -0.3569, -0.3255, -0.3020],
[ 0.8588, 0.8510, 0.8510, ..., -0.3569, -0.3412, -0.3176]],
[[-0.4510, -0.4275, -0.4431, ..., -0.5294, -0.5529, -0.5294],
[-0.4431, -0.4510, -0.4510, ..., -0.5294, -0.5294, -0.5373],
[-0.4431, -0.4510, -0.4588, ..., -0.5451, -0.5216, -0.5294],
...,
[ 0.7412, 0.7647, 0.7255, ..., -0.4980, -0.4588, -0.4275],
[ 0.8431, 0.8353, 0.8118, ..., -0.5216, -0.4824, -0.4588],
[ 0.8902, 0.8667, 0.8510, ..., -0.5216, -0.4980, -0.4824]]])
1.2.3 transforms.Compose
Chains several transforms together; they are applied in the order listed.
img = Image.open('lena.png')
img = img.convert('RGB')
transform = transforms.Compose([
transforms.Resize(100),
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])
img_compose = transform(img)
img_compose.size()
torch.Size([3, 64, 64])
2. torchvision.datasets
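This module is used to load data. The example below uses the CIFAR-10 dataset; the transform argument specifies the preprocessing applied to each sample and corresponds to the transforms described above.
import torchvision
trainset = torchvision.datasets.CIFAR10(
    root='./dataset',  # directory where the dataset is downloaded and stored
    train=True,        # True builds the training set; False builds the test set
    download=True,     # download the dataset from the Internet; skipped if it is already present
    transform=None     # preprocessing applied to each image; None means no processing, so samples stay PIL images
)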
3. torch.utils.data.DataLoader
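import torch
from torch.utils.data.sampler import SubsetRandomSampler
# Example subset of sample indices (an assumption for illustration; the original leaves `indices` unspecified)
indices = list(range(100))
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,  # a torch.utils.data.Dataset object, e.g. one from torchvision.datasets
    batch_size=1,      # number of samples per batch
    shuffle=False,     # whether to shuffle the data
    sampler=SubsetRandomSampler(indices=indices),  # samples only from the given indices; if a sampler is specified, shuffle must be False
    drop_last=False,   # when the dataset size is not divisible by batch_size: False keeps the smaller last batch, True drops it
    num_workers=0      # number of worker subprocesses used for loading; 0 loads in the main process
)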
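The effect of `batch_size`, `drop_last`, and `sampler` is easiest to see on a toy dataset. The following sketch uses a `TensorDataset` with 10 samples, which is an assumption chosen so the batch counts are easy to verify.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Toy dataset (assumption): 10 samples with 2 features each, labels 0..9
dataset = TensorDataset(torch.arange(20.0).view(10, 2), torch.arange(10))

# 10 samples with batch_size=3 -> batches of 3, 3, 3 and a final batch of 1
keep_last = DataLoader(dataset, batch_size=3, drop_last=False)
drop_last = DataLoader(dataset, batch_size=3, drop_last=True)
print(len(keep_last), len(drop_last))  # 4 3

# A sampler restricts loading to the given indices, in random order
even_indices = [0, 2, 4, 6, 8]
subset_loader = DataLoader(dataset, batch_size=2,
                           sampler=SubsetRandomSampler(even_indices))
seen = sorted(int(y) for _, labels in subset_loader for y in labels)
print(seen)  # [0, 2, 4, 6, 8]
```

A common use of `SubsetRandomSampler` is to split one dataset into training and validation parts by giving each loader a disjoint index list.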
4. torch.utils.data.Dataset
from torch.utils.data import Dataset
# Basic skeleton
class CustomDataset(Dataset):
    def __init__(self):
        """
        Put initialization code here.
        """
        # TODO
        # 1. Initialize file paths or a list of file names.
        pass
    def __getitem__(self, index):
        """
        Return one sample and its label. It can also be called explicitly:
        img, label = dataset.__getitem__(index)  # same as dataset[index]
        """
        # TODO
        # 1. Read one sample from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
        """
        Return the total number of samples.
        """
        # You should change 9 to the total size of your dataset.
        return 9  # e.g. 9 is size of dataset
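To make the skeleton concrete, here is a minimal sketch with the three methods filled in. `SquaresDataset` is a hypothetical in-memory dataset invented for this illustration; a real one would read files in `__getitem__` instead.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical in-memory dataset: sample i is (i, i squared)."""

    def __init__(self, n):
        # Initialization: here the "file list" is just a range of numbers
        self.values = torch.arange(n, dtype=torch.float32)

    def __getitem__(self, index):
        x = self.values[index]
        return x, x ** 2         # (data, label) pair

    def __len__(self):
        return len(self.values)  # total number of samples

dataset = SquaresDataset(9)
print(len(dataset))              # 9
x, y = dataset[3]
print(x.item(), y.item())        # 3.0 9.0

for xs, ys in DataLoader(dataset, batch_size=4):
    print(xs.shape, ys.shape)    # batches of 4, 4 and 1 samples
```

The DataLoader only needs `__len__` and `__getitem__`; batching and shuffling are handled outside the dataset class.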
Suppose we have an image classification problem whose data is organized as follows:

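There is one training folder and one test folder. The task has 6 classes, and each folder contains many images.
How to build a Custom Dataset
- Create one DataFrame for the training set and one for the test set. Each DataFrame has two columns: the image file name and the label.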
| Images | Labels |
|---|---|
| 0.jpg | 0 |
| 99.jpg | 5 |
- Build the Custom Dataset
import os
import torch
from PIL import Image
from torch.utils.data import Dataset

class INTELDataset(Dataset):
    def __init__(self, img_data, img_path, transform=None):
        self.img_path = img_path    # root directory of the images
        self.transform = transform
        self.img_data = img_data    # DataFrame with 'Images' and 'Labels' columns
    def __getitem__(self, index):
        label = self.img_data.loc[index, 'Labels']
        # Images are assumed to be stored in one subfolder per label: <img_path>/<label>/<image name>
        img_name = os.path.join(self.img_path, str(label),
                                self.img_data.loc[index, 'Images'])  # image path
        image = Image.open(img_name)  # load the image
        image = image.convert('RGB')
        label = torch.tensor(label)   # label as a tensor
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    def __len__(self):
        return len(self.img_data)  # number of samples