二、PyTorch 数据处理板块

1.1 Dataset 类:定义数据集

概述:

Dataset 类是 PyTorch 用于封装数据的基础类,通常通过继承 torch.utils.data.Dataset 来创建自定义数据集类,并实现两个关键方法:

  • __len__(self):返回数据集的大小(即样本的数量)。
  • __getitem__(self, idx):根据索引 idx 返回数据集中的某一项数据,通常返回 (数据, 标签)
示例代码:
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        """
        初始化数据集。
        :param data: 数据,通常是一个张量或数组,大小为 (N, C, H, W) 或 (N, feature_size)
        :param labels: 标签,通常是一个张量,大小为 (N,)
        """
        self.data = data
        self.labels = labels

    def __len__(self):
        """
        返回数据集大小,即数据集中的样本数
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        根据索引 idx 返回数据和标签
        :param idx: 数据索引
        :return: 数据样本和对应标签
        """
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label
案例:自定义数据集 MyDataset 获取图像数据和标签信息

在此示例中,我们通过自定义数据集来读取图片路径,并将其与标签关联。

from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import os

class MyDataset(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.img_path = os.listdir(os.path.join(self.root_dir, self.label_dir))
        
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        return img_name, self.label_dir
    
    def __len__(self):
        return len(self.img_path)
    
# 示例
md = MyDataset('hymenoptera_data_1/train', 'ants')
img, label = md[0]
ant_image = plt.imread(f"hymenoptera_data_1/train/ants/{img}")
plt.imshow(ant_image)

1.2 DataLoader 类:批量加载数据

概述:

DataLoader 是 PyTorch 中用于批量加载数据的工具,能够自动将 Dataset 中的数据分批,并支持多线程加载,极大提高了训练效率。DataLoader 的常见参数如下:

  • batch_size:每个批次加载多少数据。
  • shuffle:是否在每个 epoch 结束时打乱数据集,通常用于训练数据。
  • num_workers:用于加载数据的子进程数量。如果计算机有多个 CPU 核心,可以增加此参数来加速数据加载。
  • collate_fn:指定如何将一个批次的数据聚合成一个批量,通常可以直接使用默认实现。
  • drop_last:当数据集大小不能被 batch_size 整除时,是否舍弃最后一个小批次。
示例代码:
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

# 下载CIFAR10数据集
test_set = torchvision.datasets.CIFAR10(
    root='data-test',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

# 使用DataLoader批量加载数据
dataLoader = DataLoader(
    dataset=test_set,
    batch_size=64,
    shuffle=True,
    num_workers=0,
    drop_last=False
)

# 使用TensorBoard可视化数据
writer = SummaryWriter('logs')

step = 0
for epoch in range(2):
    for data in dataLoader:
        imgs, labels = data
        writer.add_images(f'Epoch:{epoch}', imgs, step)
        step += 1
    step = 0
writer.close()

1.3 TensorBoard 可视化

概述:

TensorBoard 是一个强大的可视化工具,专门用于可视化深度学习训练过程中的各类信息。它帮助用户追踪损失、精度、学习率、梯度等参数的变化,甚至可以展示训练中的图像、音频等数据。

  1. 创建一个 SummaryWriter 实例:用来记录训练过程中的各种信息。
  2. 记录标量数据:例如损失、精度等,TensorBoard 会将它们绘制成曲线图。
  3. 启动 TensorBoard:通过命令行启动 TensorBoard,查看训练过程中的可视化结果。

启动 TensorBoard:

tensorboard --logdir=logs
示例代码:记录和可视化标量数据
from torch.utils.tensorboard import SummaryWriter
import numpy as np

# 创建SummaryWriter实例
writer = SummaryWriter('logs')

# 记录sin函数的变化曲线
for i in np.arange(-2 * np.pi, 2 * np.pi, 0.1):
    writer.add_scalar(
        tag='sin function',  # 图表标签
        global_step=i,  # x坐标
        scalar_value=np.sin(i)  # y坐标
    )

writer.close()

# 在Jupyter中查看结果
%load_ext tensorboard
%tensorboard --logdir=logs
示例代码:记录和可视化图像数据
# 创建随机图像数据
image_data = np.random.rand(256, 256, 3)

# 记录图像数据到TensorBoard
writer.add_image("random_image", image_data, dataformats='HWC')
writer.close()

1.4 Transforms:图像预处理

transforms 是 PyTorch 提供的一个模块,用于图像预处理。常用的预处理方法包括:转换为Tensor、图像标准化、调整图像大小等。以下是一些常用的 transforms 操作。

1.4.1 ToTensor:将PIL图像或numpy数组转换为Tensor
from torchvision import transforms
from PIL import Image
import cv2

# 创建转换实例
to_tensor = transforms.ToTensor()

# 使用PIL加载图像
img_pil = Image.open('data/hymenoptera_data/train/ants_image/0013035.jpg')
print(to_tensor(img_pil))

# 使用Numpy加载图像
img_np = cv2.imread('data/hymenoptera_data/train/ants_image/0013035.jpg')
print(to_tensor(img_np))
1.4.2 Normalize:图像标准化

标准化的作用是将图像的每个通道减去均值,再除以标准差。

# 创建标准化转换
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# 使用标准化
img_norm = normalize(to_tensor(img_np))
1.4.3 Resize:调整图像大小

Resize 可以将图像调整到指定的大小,可以传入一个宽高元组或者单一的整数来等比缩放图像。

# 创建Resize转换
resize = transforms.Resize((512, 512))
resized_img = resize(to_tensor(img_pil))

# 记录图像到TensorBoard
writer.add_image('resized_ants', resized_img)

1.5 torchvision 数据集的简单处理

torchvision 库中,PyTorch 提供了许多标准数据集(如 CIFAR-10、MNIST 等),并且这些数据集都可以直接与 DataLoader 一起使用。

示例代码:
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# 定义图像预处理操作
transform = transforms.Compose([transforms.ToTensor()])

# 下载并加载CIFAR-10数据集
train_set = torchvision.datasets.CIFAR10(
    root='data-train',
    train=True,
    transform=transform,
    download=True
)

# 可视化前10张图像
writer = SummaryWriter('logs')
for i in range(10):
    img_tensor, label = train_set[i]
    writer.add_image('CIFAR10', img_tensor, i)
writer.close()
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐