Pytorch框架02(PyTorch 数据处理板块/Dataset和DataLoader数据加载/Transforms/TensorBoard)
Dataset类是 PyTorch 用于封装数据的基础类,通常通过继承:返回数据集的大小(即样本的数量)。:根据索引idx返回数据集中的某一项数据,通常返回(数据, 标签)。MyDatasetDataLoader是 PyTorch 中用于批量加载数据的工具,能够自动将Dataset中的数据分批,并支持多线程加载,极大提高了训练效率。DataLoaderbatch_size:每个批次加载多少数据。s
二、PyTorch 数据处理板块
1.1 Dataset
类:定义数据集
概述:
Dataset
类是 PyTorch 用于封装数据的基础类,通常通过继承 torch.utils.data.Dataset
来创建自定义数据集类,并实现两个关键方法:
__len__(self)
:返回数据集的大小(即样本的数量)。__getitem__(self, idx)
:根据索引idx
返回数据集中的某一项数据,通常返回(数据, 标签)
。
示例代码:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
"""
初始化数据集。
:param data: 数据,通常是一个张量或数组,大小为 (N, C, H, W) 或 (N, feature_size)
:param labels: 标签,通常是一个张量,大小为 (N,)
"""
self.data = data
self.labels = labels
def __len__(self):
"""
返回数据集大小,即数据集中的样本数
"""
return len(self.data)
def __getitem__(self, idx):
"""
根据索引 idx 返回数据和标签
:param idx: 数据索引
:return: 数据样本和对应标签
"""
sample = self.data[idx]
label = self.labels[idx]
return sample, label
案例:自定义数据集 MyDataset
获取图像数据和标签信息
在此示例中,我们通过自定义数据集来读取图片路径,并将其与标签关联。
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import os
class MyDataset(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.img_path = os.listdir(os.path.join(self.root_dir, self.label_dir))
def __getitem__(self, idx):
img_name = self.img_path[idx]
return img_name, self.label_dir
def __len__(self):
return len(self.img_path)
# 示例
md = MyDataset('hymenoptera_data_1/train', 'ants')
img, label = md[0]
ant_image = plt.imread(f"hymenoptera_data_1/train/ants/{img}")
plt.imshow(ant_image)
1.2 DataLoader
类:批量加载数据
概述:
DataLoader
是 PyTorch 中用于批量加载数据的工具,能够自动将 Dataset
中的数据分批,并支持多线程加载,极大提高了训练效率。DataLoader
的常见参数如下:
batch_size
:每个批次加载多少数据。shuffle
:是否在每个 epoch 结束时打乱数据集,通常用于训练数据。num_workers
:用于加载数据的子进程数量。如果计算机有多个 CPU 核心,可以增加此参数来加速数据加载。collate_fn
:指定如何将一个批次的数据聚合成一个批量,通常可以直接使用默认实现。drop_last
:当数据集大小不能被batch_size
整除时,是否舍弃最后一个小批次。
示例代码:
from torch.utils.data import DataLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter
# 下载CIFAR10数据集
test_set = torchvision.datasets.CIFAR10(
root='data-test',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True
)
# 使用DataLoader批量加载数据
dataLoader = DataLoader(
dataset=test_set,
batch_size=64,
shuffle=True,
num_workers=0,
drop_last=False
)
# 使用TensorBoard可视化数据
writer = SummaryWriter('logs')
step = 0
for epoch in range(2):
for data in dataLoader:
imgs, labels = data
writer.add_images(f'Epoch:{epoch}', imgs, step)
step += 1
step = 0
writer.close()
1.3 TensorBoard 可视化
概述:
TensorBoard 是一个强大的可视化工具,专门用于可视化深度学习训练过程中的各类信息。它帮助用户追踪损失、精度、学习率、梯度等参数的变化,甚至可以展示训练中的图像、音频等数据。
- 创建一个
SummaryWriter
实例:用来记录训练过程中的各种信息。 - 记录标量数据:例如损失、精度等,TensorBoard 会将它们绘制成曲线图。
- 启动 TensorBoard:通过命令行启动 TensorBoard,查看训练过程中的可视化结果。
启动 TensorBoard:
tensorboard --logdir=logs
示例代码:记录和可视化标量数据
from torch.utils.tensorboard import SummaryWriter
import numpy as np
# 创建SummaryWriter实例
writer = SummaryWriter('logs')
# 记录sin函数的变化曲线
for i in np.arange(-2 * np.pi, 2 * np.pi, 0.1):
writer.add_scalar(
tag='sin function', # 图表标签
global_step=i, # x坐标
scalar_value=np.sin(i) # y坐标
)
writer.close()
# 在Jupyter中查看结果
%load_ext tensorboard
%tensorboard --logdir=logs
示例代码:记录和可视化图像数据
# 创建随机图像数据
image_data = np.random.rand(256, 256, 3)
# 记录图像数据到TensorBoard
writer.add_image("random_image", image_data, dataformats='HWC')
writer.close()
1.4 Transforms:图像预处理
transforms
是 PyTorch 提供的一个模块,用于图像预处理。常用的预处理方法包括:转换为Tensor、图像标准化、调整图像大小等。以下是一些常用的 transforms
操作。
1.4.1 ToTensor
:将PIL图像或numpy数组转换为Tensor
from torchvision import transforms
from PIL import Image
import cv2
# 创建转换实例
to_tensor = transforms.ToTensor()
# 使用PIL加载图像
img_pil = Image.open('data/hymenoptera_data/train/ants_image/0013035.jpg')
print(to_tensor(img_pil))
# 使用Numpy加载图像
img_np = cv2.imread('data/hymenoptera_data/train/ants_image/0013035.jpg')
print(to_tensor(img_np))
1.4.2 Normalize
:图像标准化
标准化的作用是将图像的每个通道减去均值,再除以标准差。
# 创建标准化转换
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
# 使用标准化
img_norm = normalize(to_tensor(img_np))
1.4.3 Resize
:调整图像大小
Resize
可以将图像调整到指定的大小,可以传入一个宽高元组或者单一的整数来等比缩放图像。
# 创建Resize转换
resize = transforms.Resize((512, 512))
resized_img = resize(to_tensor(img_pil))
# 记录图像到TensorBoard
writer.add_image('resized_ants', resized_img)
1.5 torchvision
数据集的简单处理
在 torchvision
库中,PyTorch 提供了许多标准数据集(如 CIFAR-10、MNIST 等),并且这些数据集都可以直接与 DataLoader
一起使用。
示例代码:
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# 定义图像预处理操作
transform = transforms.Compose([transforms.ToTensor()])
# 下载并加载CIFAR-10数据集
train_set = torchvision.datasets.CIFAR10(
root='data-train',
train=True,
transform=transform,
download=True
)
# 可视化前10张图像
writer = SummaryWriter('logs')
for i in range(10):
img_tensor, label = train_set[i]
writer.add_image('CIFAR10', img_tensor, i)
writer.close()

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)