数据增强

数据增强是在数据量比较少的情况下,通过对原有的数据进行灰度、裁切、旋转、镜像、明度、色调、饱和度变化的一系列过程,用来增加数据量。

import os
from PIL import Image
import torchvision.transforms as transforms

# 数据增强
def DataEnhance(sourth_path,aim_dir,size):
    name=0
    #得到源文件的文件夹
    file_list=os.listdir(sourth_path)
    #创建目标文件的文件夹
    if not os.path.exists(aim_dir):
        os.mkdir(aim_dir)

    for i in file_list:
        img=Image.open('%s\%s.png'%(sourth_path,i))
        print(img.size)

        name+=1
        transform1=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Resize(size),
        ])
        img1=transform1(img)
        img1.save('%s/%s.png'%(aim_dir,name))

        name+=1
        transform2=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)
        ])
        img2 = transform1(img)
        img2.save('%s/%s.png' % (aim_dir, name))

        name+=1
        transform3=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop(227,pad_if_needed=True),
            transforms.Resize(size)
        ])
        img3 = transform1(img)
        img3.save('%s/%s.png' % (aim_dir, name))

        name+=1
        transform4=transforms.Compose([
            transforms.Compose(),
            transforms.ToPILImage(),
            transforms.RandomRotation(60),
            transforms.Resize(size),
        ])
        img4 = transform1(img)
        img4.save('%s/%s.png' % (aim_dir, name))


补充 

直接利用PIL进行数据增强

import os
from PIL import Image
import argparse
import shutil

parser = argparse.ArgumentParser(description="数据增强")
parser.add_argument("--input-path", type=str, default='./hhh', help="地址")
parser.add_argument("--out-path", type=str, default='./hh', help="地址")

opt=parser.parse_args()

pathlist=os.listdir(opt.input_path)

if not os.path.exists(opt.out_path):
    os.mkdir(opt.out_path)

##尺度缩放
scale = [1]
##旋转角度
angle = [0,90]
##翻转  0代表不翻转 1代表水平翻转 2代表垂直翻转
flip=[0,1,2]
count=0

for image_name in pathlist:
    image=Image.open(os.path.join(opt.input_path,image_name)).convert('RGB')

    for i in range(len(scale)):
        for j in range(len(angle)):
            for c in range(len(flip)):

                image=image.rotate(angle[j])
                image=image.resize((int(image.size[0]*scale[i]), int(image.size[1]*scale[i])), Image.BICUBIC)

                if flip[c]==1:  ##水平翻转
                    image=image.transpose(Image.FLIP_LEFT_RIGHT)
                elif flip[c]==2:  ###垂直翻转
                    image.transpose(Image.FLIP_TOP_BOTTOM)
                else:
                    image=image   ##不翻转

                count+=1
                aa=image_name.split('.')[0]+'{}_{}_{}.png'.format(i,j,c)
                image.save(os.path.join(opt.out_path,aa))
                # image.show()




数据增强并不是真的增多了样本的数量,因为在训练的过程中,数据增强函数会随机对图片进行处理,比如对图片进行裁剪,将图片翻转,这样在多训练几轮之后就相当于增加了原本不属于数据集的图片,也就是实现了数据增强。 

常见的图像变换

裁剪

中心裁剪(transforms.CenterCrop)

作用:Crops the given PIL Image at the center

from PIL import Image
import torchvision.transforms as transforms

img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.CenterCrop((224,224)),
        ])

img=transform1(img)
img.show()

原始图片

 裁剪后图片

随机裁剪(transforms.RandomCrop)

from PIL import Image
import torchvision.transforms as transforms

img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.RandomCrop((224,224)),
        ])

img=transform1(img)
img.show()

翻转和旋转

依概率p水平翻转transforms.RandomHorizontalFlip

from PIL import Image
import torchvision.transforms as transforms

img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            # transforms.RandomCrop((224,224)),
            transforms.RandomHorizontalFlip(p=0.999999)
        ])

img=transform1(img)
img.show()

依概率p垂直翻转transforms.RandomVerticalFlip

以给定的概率随机垂直翻转给定的PIL图像

torchvision.transforms.RandomVerticalFlip(p=0.5)

随机旋转:transforms.RandomRotation

torchvision.transforms.RandomRotation(degrees)

图像变换

resize:transforms.Resize()

Resize the input PIL Image to the given size.注意这些函数的输入都是PIL格式
from PIL import Image
import torchvision.transforms as transforms

img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
            transforms.ToTensor(),
            transforms.ToPILImage(),
            transforms.Resize((224,224),interpolation=2)
        ])
img=transform1(img)
print(img.size)
img.show()

标准化:transforms.Normalize()

torchvision.transforms.Normalize(mean, std)

转为tensor:transforms.ToTensor()

将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1] 

将数据转换为PILImage:transforms.ToPILImage(mode=None)

将tensor 或者 ndarray的数据转换为 PIL Image 类型数据 参数: mode- 为None时,为1通道, mode=3通道默认转换为RGB,4通道默认转换为RGBA。

对transforms操作,使数据增强更灵活

transforms.RandomChoice(transforms)

从给定的一系列transforms中选一个进行操作

参考博客:

pytorch实现AlexNet(含完整代码)_不会水的鱼o的博客-CSDN博客_alexnet pytorch

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐