pytorch数据增强
pytorch数据增强,以及transforms操作
数据增强
数据增强是在数据量比较少的情况下,通过对原有的数据进行灰度、裁切、旋转、镜像、明度、色调、饱和度变化的一系列过程,用来增加数据量。
import os
from PIL import Image
import torchvision.transforms as transforms
# 数据增强
def DataEnhance(sourth_path,aim_dir,size):
name=0
#得到源文件的文件夹
file_list=os.listdir(sourth_path)
#创建目标文件的文件夹
if not os.path.exists(aim_dir):
os.mkdir(aim_dir)
for i in file_list:
img=Image.open('%s\%s.png'%(sourth_path,i))
print(img.size)
name+=1
transform1=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.Resize(size),
])
img1=transform1(img)
img1.save('%s/%s.png'%(aim_dir,name))
name+=1
transform2=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)
])
img2 = transform1(img)
img2.save('%s/%s.png' % (aim_dir, name))
name+=1
transform3=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop(227,pad_if_needed=True),
transforms.Resize(size)
])
img3 = transform1(img)
img3.save('%s/%s.png' % (aim_dir, name))
name+=1
transform4=transforms.Compose([
transforms.Compose(),
transforms.ToPILImage(),
transforms.RandomRotation(60),
transforms.Resize(size),
])
img4 = transform1(img)
img4.save('%s/%s.png' % (aim_dir, name))
补充
直接利用PIL进行数据增强
import os
from PIL import Image
import argparse
import shutil
parser = argparse.ArgumentParser(description="数据增强")
parser.add_argument("--input-path", type=str, default='./hhh', help="地址")
parser.add_argument("--out-path", type=str, default='./hh', help="地址")
opt=parser.parse_args()
pathlist=os.listdir(opt.input_path)
if not os.path.exists(opt.out_path):
os.mkdir(opt.out_path)
##尺度缩放
scale = [1]
##旋转角度
angle = [0,90]
##翻转 0代表不翻转 1代表水平翻转 2代表垂直翻转
flip=[0,1,2]
count=0
for image_name in pathlist:
image=Image.open(os.path.join(opt.input_path,image_name)).convert('RGB')
for i in range(len(scale)):
for j in range(len(angle)):
for c in range(len(flip)):
image=image.rotate(angle[j])
image=image.resize((int(image.size[0]*scale[i]), int(image.size[1]*scale[i])), Image.BICUBIC)
if flip[c]==1: ##水平翻转
image=image.transpose(Image.FLIP_LEFT_RIGHT)
elif flip[c]==2: ###垂直翻转
image.transpose(Image.FLIP_TOP_BOTTOM)
else:
image=image ##不翻转
count+=1
aa=image_name.split('.')[0]+'{}_{}_{}.png'.format(i,j,c)
image.save(os.path.join(opt.out_path,aa))
# image.show()
数据增强并不是真的增多了样本的数量,因为在训练的过程中,数据增强函数会随机对图片进行处理,比如对图片进行裁剪,将图片翻转,这样在多训练几轮之后就相当于增加了原本不属于数据集的图片,也就是实现了数据增强。
常见的图像变换
裁剪
中心裁剪(transforms.CenterCrop)
作用:Crops the given PIL Image at the center
from PIL import Image
import torchvision.transforms as transforms
img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.CenterCrop((224,224)),
])
img=transform1(img)
img.show()
原始图片

裁剪后图片

随机裁剪(transforms.RandomCrop)
from PIL import Image
import torchvision.transforms as transforms
img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((224,224)),
])
img=transform1(img)
img.show()

翻转和旋转
依概率p水平翻转transforms.RandomHorizontalFlip
from PIL import Image
import torchvision.transforms as transforms
img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
# transforms.RandomCrop((224,224)),
transforms.RandomHorizontalFlip(p=0.999999)
])
img=transform1(img)
img.show()

依概率p垂直翻转transforms.RandomVerticalFlip
以给定的概率随机垂直翻转给定的PIL图像
torchvision.transforms.RandomVerticalFlip(p=0.5)
随机旋转:transforms.RandomRotation
torchvision.transforms.RandomRotation(degrees)
图像变换
resize:transforms.Resize()
Resize the input PIL Image to the given size.注意这些函数的输入都是PIL格式
from PIL import Image
import torchvision.transforms as transforms
img=Image.open('./Set14/baboon.png')
transform1=transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.Resize((224,224),interpolation=2)
])
img=transform1(img)
print(img.size)
img.show()
标准化:transforms.Normalize()
torchvision.transforms.Normalize(mean, std)
转为tensor:transforms.ToTensor()
将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1]
将数据转换为PILImage:transforms.ToPILImage(mode=None)
将tensor 或者 ndarray的数据转换为 PIL Image 类型数据 参数: mode- 为None时,为1通道, mode=3通道默认转换为RGB,4通道默认转换为RGBA。
对transforms操作,使数据增强更灵活
transforms.RandomChoice(transforms)
从给定的一系列transforms中选一个进行操作
参考博客:
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)