Original article: PyTorch Image Augmentation Tutorial - Zhihu

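Data augmentation increases the diversity of the images in a dataset, which can improve a model's performance and its ability to generalize. The main image augmentation techniques are: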

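  • Resizing

  • Grayscale conversion

  • Normalization

  • Random rotation

  • Center crop

  • Random crop

  • Gaussian blur

  • Brightness, contrast and saturation adjustment

  • Horizontal flip

  • Vertical flip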

Resizing

Before resizing, we first load the data (a fundus image is used as the example).

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/000001.tif'))
torch.manual_seed(0)  # seed the CPU random number generator so results can be reproduced
print(np.asarray(orig_img).shape)  # (800, 800, 3)

# Resize the image; an int size scales the shorter edge, so this square image becomes 128x128 and 256x256
resized_imgs = [T.Resize(size=size)(orig_img) for size in [128, 256]]

ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('resize:128*128')
ax2.imshow(resized_imgs[0])

ax3 = plt.subplot(133)
ax3.set_title('resize:256*256')
ax3.imshow(resized_imgs[1])

plt.show()
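The output-size rule that T.Resize documents can be sketched in plain Python. The helper `resized_hw` is hypothetical and only mirrors the documented behavior: an int scales the shorter edge to `size` and keeps the aspect ratio, while an (h, w) pair is used as-is. Treat it as an illustration, not torchvision's code.

```python
# Illustrative sketch; resized_hw is a hypothetical helper, not a torchvision API.
def resized_hw(h, w, size):
    """Return the (height, width) a Resize-style `size` argument would produce."""
    if isinstance(size, int):
        # Scale the shorter edge to `size`; the longer edge keeps the aspect ratio (truncated to int).
        short, long = (w, h) if w <= h else (h, w)
        new_short, new_long = size, int(size * long / short)
        return (new_long, new_short) if w <= h else (new_short, new_long)
    # A sequence is interpreted as (height, width) and used directly.
    return tuple(size)


print(resized_hw(800, 800, 128))         # square input, like the 800x800 fundus image
print(resized_hw(600, 800, 128))         # landscape input: only the shorter edge becomes 128
print(resized_hw(600, 800, (128, 128)))  # explicit (h, w)
```

So `T.Resize(128)` yields 128x128 only for a square input such as the fundus image; pass `size=(128, 128)` when a fixed shape is needed regardless of aspect ratio.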

 

Grayscale conversion

This operation converts an RGB image into a grayscale image.

gray_img = T.Grayscale()(orig_img)  # num_output_channels defaults to 1, giving a single-channel 'L' image

ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('gray')
ax2.imshow(gray_img, cmap='gray')

plt.show()
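For a PIL image, T.Grayscale uses PIL's "L" conversion, which PIL documents as the ITU-R 601-2 luma transform L = R * 299/1000 + G * 587/1000 + B * 114/1000. Below is a NumPy sketch of that weighted sum. `to_grayscale` is a hypothetical helper. PIL uses fixed-point arithmetic, so its results can differ from this floating-point version by one gray level.

```python
# Illustrative sketch; to_grayscale is a hypothetical helper, not a PIL or torchvision API.
import numpy as np

# ITU-R 601-2 luma weights for R, G, B
LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114])


def to_grayscale(rgb):
    """Convert an HxWx3 uint8 RGB array into an HxW uint8 grayscale array."""
    gray = rgb.astype(np.float64) @ LUMA_WEIGHTS  # weighted sum over the channel axis
    return np.clip(np.rint(gray), 0, 255).astype(np.uint8)


pixels = np.array([[[255, 255, 255], [255, 0, 0], [0, 255, 0]]], dtype=np.uint8)  # white, red, green
print(to_grayscale(pixels))
```

Green carries the largest weight, so a pure green pixel comes out much brighter than a pure red one.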

Normalization

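Normalization tends to make training of neural network models more stable and helps them converge faster. For each input channel, it:

  • subtracts the channel mean from every value in that channel, and

  • divides the result by the channel standard deviation.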

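# ToTensor scales pixels to [0, 1]; Normalize then computes (x - 0.5) / 0.5 per channel, giving values in [-1, 1]
normalized_img = T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(T.ToTensor()(orig_img))
print(normalized_img.min().item(), normalized_img.max().item())

ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('normalize')
# ToPILImage expects floats in [0, 1] (it multiplies by 255 and casts to uint8), so negative values
# would not convert correctly; instead show the CxHxW tensor as HxWxC, clipped to imshow's float range [0, 1]
ax2.imshow(normalized_img.permute(1, 2, 0).clamp(0, 1).numpy())

plt.show()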
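The two steps above amount to (x - mean) / std per channel, applied after ToTensor has scaled pixels from [0, 255] to [0, 1]. The NumPy sketch below shows why mean = std = 0.5 maps every channel to [-1, 1], and how to undo the operation for display. The helper names are hypothetical. (0.5, 0.5, 0.5) is a generic choice; per-dataset channel statistics are also commonly used.

```python
# Illustrative sketch; the helpers below are hypothetical, not torchvision APIs.
import numpy as np


def to_tensor_like(img_hwc):
    """Mimic ToTensor's scaling: HxWxC uint8 in [0, 255] -> CxHxW float32 in [0, 1]."""
    return img_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0


def normalize(chw, mean, std):
    """Per-channel (x - mean) / std, broadcasting mean and std over height and width."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (chw - mean) / std


def denormalize(chw, mean, std):
    """Invert normalize(): x * std + mean."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return chw * std + mean


img = np.array([[[0, 0, 0], [128, 128, 128], [255, 255, 255]]], dtype=np.uint8)  # 1x3 RGB image
x = to_tensor_like(img)
z = normalize(x, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(z.min(), z.max())
```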

Random rotation

Rotate the image by an angle drawn at random from a given range.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))
torch.manual_seed(0)  # make the sampled angle reproducible

# degrees=90 samples an angle uniformly from [-90, 90] on every call
rotated_imgs = [T.RandomRotation(degrees=90)(orig_img)]

plt.figure('RandomRotation')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('random angle in [-90°, 90°]')
ax2.imshow(np.array(rotated_imgs[0]))

plt.show()
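The RandomRotation documentation says that a single number d for `degrees` means the range (-d, +d), while a (min, max) pair is used directly. An angle is then drawn uniformly from that range on each call. Below is a NumPy sketch of just that sampling step. The helpers are hypothetical, and the actual image rotation is left to torchvision.

```python
# Illustrative sketch; rotation_range and sample_angle are hypothetical helpers, not torchvision APIs.
import numpy as np


def rotation_range(degrees):
    """Turn a RandomRotation-style `degrees` argument into a (min, max) range."""
    if np.isscalar(degrees):
        if degrees < 0:
            raise ValueError("a single `degrees` value must be non-negative")
        return (-float(degrees), float(degrees))
    lo, hi = degrees
    return (float(lo), float(hi))


def sample_angle(degrees, rng):
    """Draw one angle uniformly from the range described by `degrees`."""
    lo, hi = rotation_range(degrees)
    return rng.uniform(lo, hi)


rng = np.random.default_rng(0)
angles = [sample_angle(90, rng) for _ in range(5)]
print(rotation_range(90), [round(a, 1) for a in angles])
```

This is why `degrees=90` does not produce a 90° rotation. A range that contains a single value, such as `degrees=(90, 90)`, fixes the angle.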

Random crop

Crop a region of the given size at a random position in the image.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))
torch.manual_seed(0)  # make the crop positions reproducible

# An int size gives a square crop; without padding, the image must be at least that large in both dimensions
random_crops = [T.RandomCrop(size=size)(orig_img) for size in (400, 300)]

plt.figure('RandomCrop')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('400*400')
ax2.imshow(np.array(random_crops[0]))

ax3 = plt.subplot(133)
ax3.set_title('300*300')
ax3.imshow(np.array(random_crops[1]))

plt.show()
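Center crop is listed in the overview but not demonstrated. It cuts a fixed region from the middle of the image rather than from a random position. The NumPy sketch below contrasts how the two crops choose their top-left corner. The helpers are hypothetical illustrations of the behavior described in the torchvision docs, not its source. Unlike this sketch, T.CenterCrop pads an image that is smaller than the crop instead of raising an error.

```python
# Illustrative sketch; center_crop and random_crop are hypothetical helpers, not torchvision APIs.
import numpy as np


def center_crop(img, size):
    """Take a size x size window from the middle of an HxW(xC) array."""
    h, w = img.shape[:2]
    if size > h or size > w:
        raise ValueError("crop size larger than image")  # T.CenterCrop would pad instead
    top = int(round((h - size) / 2.0))
    left = int(round((w - size) / 2.0))
    return img[top:top + size, left:left + size]


def random_crop(img, size, rng):
    """Take a size x size window whose top-left corner is chosen uniformly at random."""
    h, w = img.shape[:2]
    if size > h or size > w:
        raise ValueError("crop size larger than image")
    top = int(rng.integers(0, h - size + 1))   # any row that still leaves room for the crop
    left = int(rng.integers(0, w - size + 1))  # any column that still leaves room for the crop
    return img[top:top + size, left:left + size]


img = np.arange(36).reshape(6, 6)  # pixel value = 6 * row + column
print(center_crop(img, 2))
print(random_crop(img, 4, np.random.default_rng(0)))
```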

Gaussian blur

Blur the image by convolving it with a Gaussian kernel.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# sigma is the standard deviation of the Gaussian; kernel_size must be odd and positive.
# With a 3x3 kernel the weights are already almost uniform at sigma=3, so sigma=7 looks very similar;
# a larger kernel_size makes the difference between the two sigmas visible.
blurred_imgs = [T.GaussianBlur(kernel_size=(3, 3), sigma=sigma)(orig_img) for sigma in (3, 7)]

plt.figure('GaussianBlur')
ax1 = plt.subplot(131)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(132)
ax2.set_title('sigma=3')
ax2.imshow(np.array(blurred_imgs[0]))

ax3 = plt.subplot(133)
ax3.set_title('sigma=7')
ax3.imshow(np.array(blurred_imgs[1]))

plt.show()
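The kernel weights show why a 3x3 kernel with sigma=3 and with sigma=7 can produce nearly identical images. The sketch below builds a standard separable Gaussian kernel: exp(-x²/(2σ²)) is sampled at integer offsets and normalized to sum to 1, and the 1-D kernel's outer product with itself gives the 2-D kernel. The function names are hypothetical, and this is an illustration rather than torchvision's exact code.

```python
# Illustrative sketch; the kernel helpers are hypothetical, not torchvision APIs.
import numpy as np


def gaussian_kernel1d(kernel_size, sigma):
    """Sample exp(-x^2 / (2 sigma^2)) at integer offsets around 0 and normalize to sum 1."""
    if kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError("kernel_size must be a positive odd number")
    half = (kernel_size - 1) / 2
    x = np.linspace(-half, half, kernel_size)
    pdf = np.exp(-0.5 * (x / sigma) ** 2)
    return pdf / pdf.sum()


def gaussian_kernel2d(kernel_size, sigma):
    """The 2-D kernel is the outer product of two 1-D kernels (the Gaussian is separable)."""
    k = gaussian_kernel1d(kernel_size, sigma)
    return np.outer(k, k)


for sigma in (3, 7):
    print(f"size=3,  sigma={sigma}:", np.round(gaussian_kernel1d(3, sigma), 3))
    print(f"size=15, sigma={sigma}: center weight", round(gaussian_kernel1d(15, sigma)[7], 3))
```

With three taps, both kernels are close to a plain 1/3 average. With 15 taps, the smaller sigma puts noticeably more weight on the center pixel.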

Brightness, contrast and saturation adjustment

Adjust the brightness, contrast and saturation of the image with T.ColorJitter.

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))
torch.manual_seed(0)  # ColorJitter applies its adjustments in a random order

# (min, max) ranges with min == max fix the factors: brightness x2, contrast x0.5, saturation x0.5
colorjitter_img = [T.ColorJitter(brightness=(2, 2), contrast=(0.5, 0.5), saturation=(0.5, 0.5))(orig_img)]

plt.figure('ColorJitter')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('colorjitter_img')
ax2.imshow(np.array(colorjitter_img[0]))

plt.show()
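Each of the three adjustments can be viewed as blending the image with a reference. Brightness uses black, contrast uses the mean gray level, and saturation uses the grayscale version of the image. A factor of 1 leaves the image unchanged. Below is a NumPy sketch on float RGB values in [0, 1]. The helper names are hypothetical, and the exact references that torchvision and PIL use may differ in detail.

```python
# Illustrative sketch; the adjust_* helpers are hypothetical, not torchvision APIs.
import numpy as np

LUMA_WEIGHTS = np.array([0.299, 0.587, 0.114])


def blend(img, reference, factor):
    """factor * img + (1 - factor) * reference, clipped to the valid [0, 1] range."""
    return np.clip(factor * img + (1.0 - factor) * reference, 0.0, 1.0)


def adjust_brightness(img, factor):
    return blend(img, np.zeros_like(img), factor)  # reference: black


def adjust_contrast(img, factor):
    mean_gray = (img @ LUMA_WEIGHTS).mean()  # reference: the mean gray level of the image
    return blend(img, mean_gray, factor)


def adjust_saturation(img, factor):
    gray = (img @ LUMA_WEIGHTS)[..., np.newaxis]  # reference: grayscale copy, one value per pixel
    return blend(img, gray, factor)


img = np.array([[[0.4, 0.4, 0.4], [1.0, 0.0, 0.0]]])  # 1x2 RGB image: a gray pixel and a red pixel
print(adjust_brightness(img, 2.0))
print(adjust_saturation(img, 0.0))
```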

Horizontal flip

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# p=1 means the flip is always applied
HorizontalFlip_img = [T.RandomHorizontalFlip(p=1)(orig_img)]

plt.figure('HorizontalFlip')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('HorizontalFlip')
ax2.imshow(np.array(HorizontalFlip_img[0]))

plt.show()
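On an array, a horizontal flip reverses the width axis. Which axis that is depends on the layout. An RGB PIL image converted with np.array is HxWxC, while ToTensor produces CxHxW. Here is a small NumPy sketch with hypothetical helper names.

```python
# Illustrative sketch; hflip_hwc and hflip_chw are hypothetical helpers, not torchvision APIs.
import numpy as np


def hflip_hwc(img):
    """Mirror an HxWxC array left-right by reversing the width axis (axis 1)."""
    return img[:, ::-1]


def hflip_chw(img):
    """Mirror a CxHxW array left-right; in this layout width is the last axis."""
    return img[..., ::-1]


img = np.arange(12).reshape(2, 2, 3)  # 2x2 RGB image in HxWxC layout
flipped = hflip_hwc(img)
print(img[0, 0], "->", flipped[0, 1])  # the top-left pixel ends up at the top-right
```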

Vertical flip

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as T
from PIL import Image

plt.rcParams["savefig.bbox"] = 'tight'
orig_img = Image.open(Path('image/2.png'))

# p=1 means the flip is always applied
VerticalFlip_img = [T.RandomVerticalFlip(p=1)(orig_img)]

plt.figure('VerticalFlip')
ax1 = plt.subplot(121)
ax1.set_title('original')
ax1.imshow(orig_img)

ax2 = plt.subplot(122)
ax2.set_title('VerticalFlip')
ax2.imshow(np.array(VerticalFlip_img[0]))

plt.show()
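In both flip transforms, `p` is the probability that the flip is applied; torchvision's default is p=0.5. Setting p=1, as in the flip examples, makes the output deterministic. During training, a smaller p means only some samples are flipped. Below is a NumPy sketch of that coin flip, with the vertical flip implemented by reversing the height axis. The helper is hypothetical.

```python
# Illustrative sketch; random_vertical_flip is a hypothetical helper, not a torchvision API.
import numpy as np


def random_vertical_flip(img, p, rng):
    """Reverse the height axis (axis 0) with probability p; p=1 always flips, p=0 never does."""
    if rng.random() < p:  # rng.random() is uniform on [0, 1)
        return img[::-1]
    return img


rng = np.random.default_rng(0)
img = np.arange(6).reshape(3, 2)
trials = 10_000
flips = sum(int(random_vertical_flip(img, 0.5, rng)[0, 0] != img[0, 0]) for _ in range(trials))
print(f"flipped in {flips / trials:.1%} of calls with p=0.5")
```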
