【MAE】如何读取BraTS数据集做自监督训练
对于读取BraTS数据集,我原本是这样写的。
·
错误姿势
对于读取BraTS数据集,我原本是这样写的
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import nibabel as nib
class BraTSLGGMultiModalDataset(Dataset):
def __init__(self, data_dir,modalities = ['t1','t2','flair','t1ce'] , transform=None):
self.data_dir = data_dir
self.modalities = modalities
self.transform = transform
self.samples = []
for patient in os.listdir(data_dir):
patient_path = os.path.join(data_dir, patient)
if not os.path.isdir(patient_path):
continue
valid = True
paths = {}
for mod in modalities:
mod_path = os.path.join(patient_path, f"{patient}_{mod}.nii")
if not os.path.isfile(mod_path):
valid = False
break
paths[mod] = mod_path
if valid:
self.samples.append(paths)
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
data = []
for mod in self.modalities:
nii_img = nib.load(self.samples[idx][mod])
img = nii_img.get_fdata()
# 取中间切片
img = img[:,:,img.shape[2]//2]
# 归一化(Z-score)
img = (img - np.mean(img)) / (np.std(img) + 1e-8)
# 转为tensor并添加batch维度以便resize
img = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0) # [1,1,H,W]
# Resize到模型输入大小
img = torch.nn.functional.interpolate(img, (224, 224), mode='bilinear', align_corners=False)
# 去掉多余的维度
img = img.squeeze()
data.append(img)
# 拼接成4-channels
image_4c = torch.stack(data, dim=0)
if self.transform:
image_4c = self.transform(image_4c)
return image_4c, 0
正确姿势
但我转念一想,对于医学MRI数据集,如果每个病人只取中间的某一个切片,未免太过可惜。所以尝试了加载所有切片(z轴方向)。
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import nibabel as nib
from torch.nn.functional import interpolate
class BraTSLGGMultiModalDataset(Dataset):
def __init__(self, data_dir, modalities=None, transform=None, slice_range=None):
"""
Args:
data_dir: 数据根目录,包含多个病人文件夹
modalities: 模态列表,默认为 ['t1', 't2', 'flair', 't1ce']
transform: 数据增强
slice_range: 要使用的切片范围,例如 (20, 130) 排除头尾无信息切片
None 表示使用所有切片
"""
self.data_dir = data_dir
self.modalities = modalities or ['t1', 't2', 'flair', 't1ce']
self.transform = transform
self.samples = [] # 存储 (patient_path, z_index) 元组
# 遍历每个病人
for patient in os.listdir(data_dir):
patient_path = os.path.join(data_dir, patient)
if not os.path.isdir(patient_path):
continue
# 检查 4 个模态是否存在
mod_files = {}
valid = True
for mod in self.modalities:
mod_filename = f"{patient}_{mod}.nii" # 注意:有些数据集是 .nii.gz,根据实际情况调整
mod_path = os.path.join(patient_path, mod_filename)
if not os.path.isfile(mod_path):
valid = False
break
mod_files[mod] = mod_path
if not valid:
continue
# 加载一个模态以获取 shape(假设所有模态尺寸一致)
sample_img = nib.load(mod_files[self.modalities[0]])
_, _, depth = sample_img.shape # (H, W, Z)
# 确定切片范围
if slice_range is not None:
start_z, end_z = slice_range
start_z = max(0, start_z)
end_z = min(depth, end_z)
else:
start_z, end_z = 0, depth
# 将每个 z_index 添加为独立样本
for z in range(start_z, end_z):
self.samples.append({
'patient': patient,
'paths': mod_files,
'z_index': z
})
print(f"Total number of 2D slices: {len(self.samples)}")
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
sample = self.samples[idx]
data = []
for mod in self.modalities:
nii_img = nib.load(sample['paths'][mod])
img_3d = nii_img.get_fdata() # (H, W, Z)
img = img_3d[:, :, sample['z_index']] # (H, W)
# 归一化(Z-score)
mean = np.mean(img)
std = np.std(img)
img = (img - mean) / (std + 1e-8)
# 转为 tensor 并插值到 224x224
img = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
img = interpolate(img, size=(224, 224), mode='bilinear', align_corners=False)
img = img.squeeze(0).squeeze(0) # (224, 224)
data.append(img)
# 堆叠成 (4, 224, 224)
image_4c = torch.stack(data, dim=0) # (C, H, W)
if self.transform:
image_4c = self.transform(image_4c)
# 返回图像和 dummy label(自监督不需要 label)
return image_4c, 0
背景回顾
-
原始代码:每个病人 → 只取 一个中间切片 →
Dataset size = 病人数 -
现在代码:每个病人 → 取 所有 z 轴切片 →
Dataset size = 病人数 × 切片数
修改目标
-
从每个病人中读取 4 个模态的 完整 3D 体积
-
遍历每个 z 轴切片(
z=0到z=Z-1) -
每个
(病人, z_index)作为一个 2D 样本 -
总样本数从 75 → 75 × 100~150 = 7,500~11,250
详细代码对比
①self.samples 存储结构改变:存储单位从病人变成了切片
原始代码:
self.samples = []
for patient in os.listdir(data_dir):
...
paths = {}
for mod in modalities:
mod_path = ...
paths[mod] = mod_path
self.samples.append(paths) # 只存路径字典
当前代码:
self.samples = []
for patient in os.listdir(data_dir):
...
for z in range(start_z, end_z):
self.samples.append({
'patient': patient,
'paths': mod_files,
'z_index': z
})
self.samples现在存储的每一个元素,都对应一个独立的2D切片样本
每个 self.samples[i] 包含 |
含义 |
|---|---|
'patient' |
是哪个病人(如 TCGA-XX-XXXX) |
'paths' |
四个模态(T1, T2, FLAIR, T1ce)的 NIfTI 文件路径 |
'z_index' |
要从 3D 图像中提取的 第 z 层切片 |
因此
len(self.samples)
就等于
所有病人 x 每个病人的有效切片数
②新增slice_range参数
def __init__(..., slice_range=None):
...
if slice_range is not None:
start_z, end_z = slice_range
start_z = max(0, start_z)
end_z = min(depth, end_z)
else:
start_z, end_z = 0, depth
-
可以排除头尾无信息的切片(如只有空气或颅骨)
-
例如
slice_range=(10, 140),只保留中间 130 层 -
提高训练效率和质量
③__getitem__:从“取中间切片” —> “取指定z切片”
原始代码:
img = img[:,:,img.shape[2]//2] # 固定中间切片
当前代码:
img = img_3d[:, :, sample['z_index']] # 动态取 z
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)