错误姿势

读取 BraTS 数据集时,我最初是这样写的:

import os

import numpy as np
import torch
from torch.utils.data import Dataset
import nibabel as nib

class BraTSLGGMultiModalDataset(Dataset):
    """Multi-modal BraTS dataset yielding one 2D slice (the middle z slice) per patient.

    A patient folder ``data_dir/<patient>`` is kept only if it contains
    ``<patient>_<mod>.nii`` for every requested modality. ``__getitem__`` returns
    the middle slice of each modality, z-score normalized, resized to 224x224 and
    stacked into a ``(C, 224, 224)`` tensor, together with a dummy label ``0``.
    """

    def __init__(self, data_dir, modalities=None, transform=None):
        """
        Args:
            data_dir: root directory holding one sub-folder per patient.
            modalities: modality name suffixes to load; ``None`` selects
                ['t1', 't2', 'flair', 't1ce']. A fresh list is built per call so
                instances never share (and mutate) a common default list.
            transform: optional callable applied to the stacked tensor.
        """
        if modalities is None:
            modalities = ['t1', 't2', 'flair', 't1ce']
        self.data_dir = data_dir
        self.modalities = modalities
        self.transform = transform
        self.samples = []  # one {modality: file path} dict per usable patient

        # os.listdir returns entries in arbitrary order; sort for a reproducible index.
        for patient in sorted(os.listdir(data_dir)):
            patient_path = os.path.join(data_dir, patient)
            if not os.path.isdir(patient_path):
                continue

            paths = {
                mod: os.path.join(patient_path, f"{patient}_{mod}.nii")
                for mod in modalities
            }
            # Keep the patient only if every requested modality file exists.
            if all(os.path.isfile(path) for path in paths.values()):
                self.samples.append(paths)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` where ``image`` has shape ``(len(self.modalities), 224, 224)``."""
        data = []
        for mod in self.modalities:
            nii_img = nib.load(self.samples[idx][mod])
            # Middle slice along the third axis. Slicing the lazy ``dataobj`` proxy
            # avoids materializing the entire float64 volume the way get_fdata()
            # does; the slice is cast to float64 to match get_fdata()'s dtype.
            mid = nii_img.shape[2] // 2
            img = np.asarray(nii_img.dataobj[:, :, mid], dtype=np.float64)

            # Z-score normalization; the epsilon avoids dividing by zero on constant slices.
            img = (img - np.mean(img)) / (np.std(img) + 1e-8)

            # (H, W) -> (1, 1, H, W): interpolate expects batch and channel dims.
            img = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0)
            img = torch.nn.functional.interpolate(
                img, (224, 224), mode='bilinear', align_corners=False
            )
            # Drop exactly the two dims added above -> (224, 224).
            img = img.squeeze(0).squeeze(0)
            data.append(img)

        # Stack modalities as channels -> (C, 224, 224).
        image_4c = torch.stack(data, dim=0)

        if self.transform:
            image_4c = self.transform(image_4c)

        return image_4c, 0

正确姿势

但我转念一想,对于医学MRI数据集,如果每个病人只取中间的某一个切片,未免太过可惜。所以尝试了加载所有切片(z轴方向)。

import os
import numpy as np
import torch
from torch.utils.data import Dataset
import nibabel as nib
from torch.nn.functional import interpolate


class BraTSLGGMultiModalDataset(Dataset):
    """Multi-modal BraTS dataset that yields every z slice as an independent 2D sample.

    Each (patient, z_index) pair is one sample: ``__getitem__`` reads slice
    ``z_index`` from every requested modality, z-score normalizes it, resizes it
    to 224x224 and stacks the modalities into a ``(C, 224, 224)`` tensor,
    returned together with a dummy label ``0``.
    """

    def __init__(self, data_dir, modalities=None, transform=None, slice_range=None,
                 file_suffix=".nii"):
        """
        Args:
            data_dir: root directory containing one sub-folder per patient.
            modalities: modality name suffixes to load; defaults to
                ['t1', 't2', 'flair', 't1ce'] (an empty list also falls back to it).
            transform: optional callable applied to the stacked tensor.
            slice_range: half-open ``(start_z, end_z)`` range of slices to keep,
                e.g. ``(20, 130)`` to drop near-empty slices at both ends; clamped
                to ``[0, depth]`` per patient. ``None`` keeps all slices.
            file_suffix: extension of the modality files, ``".nii"`` by default;
                pass ``".nii.gz"`` for datasets stored compressed.
        """
        self.data_dir = data_dir
        self.modalities = modalities or ['t1', 't2', 'flair', 't1ce']
        self.transform = transform
        # One dict per 2D sample: {'patient', 'paths' (modality -> file), 'z_index'}.
        self.samples = []

        # os.listdir returns entries in arbitrary order; sort for a reproducible index.
        for patient in sorted(os.listdir(data_dir)):
            patient_path = os.path.join(data_dir, patient)
            if not os.path.isdir(patient_path):
                continue

            mod_files = {
                mod: os.path.join(patient_path, f"{patient}_{mod}{file_suffix}")
                for mod in self.modalities
            }
            # Skip patients missing any requested modality.
            if not all(os.path.isfile(path) for path in mod_files.values()):
                continue

            # Use the smallest depth across modalities so that every z_index is
            # valid for all of them (nib.load reads the header; data stays lazy).
            depth = min(nib.load(path).shape[2] for path in mod_files.values())

            if slice_range is not None:
                start_z, end_z = slice_range
                start_z = max(0, start_z)
                end_z = min(depth, end_z)
            else:
                start_z, end_z = 0, depth

            # Register each z_index as an independent sample.
            for z in range(start_z, end_z):
                self.samples.append({
                    'patient': patient,
                    'paths': mod_files,
                    'z_index': z
                })

        print(f"Total number of 2D slices: {len(self.samples)}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, 0)``: a ``(len(self.modalities), 224, 224)`` float tensor and a dummy label."""
        sample = self.samples[idx]
        data = []

        for mod in self.modalities:
            nii_img = nib.load(sample['paths'][mod])
            # Slice the lazy array proxy instead of calling get_fdata(): get_fdata()
            # would decode the entire float64 volume for every single 2D sample,
            # i.e. each volume would be decoded ~depth times per epoch.
            img_3d = nii_img.dataobj
            img = np.asarray(img_3d[:, :, sample['z_index']], dtype=np.float64)  # (H, W)

            # Z-score normalization; the epsilon avoids dividing by zero on
            # constant (e.g. all-background) slices.
            mean = np.mean(img)
            std = np.std(img)
            img = (img - mean) / (std + 1e-8)

            # (H, W) -> (1, 1, H, W) for interpolate, resize, then back to (224, 224).
            img = torch.FloatTensor(img).unsqueeze(0).unsqueeze(0)
            img = interpolate(img, size=(224, 224), mode='bilinear', align_corners=False)
            img = img.squeeze(0).squeeze(0)

            data.append(img)

        # Stack modalities as channels -> (C, 224, 224).
        image_4c = torch.stack(data, dim=0)

        if self.transform:
            image_4c = self.transform(image_4c)

        # Dummy label: self-supervised training needs no target.
        return image_4c, 0

背景回顾

  • 原始代码:每个病人 → 只取 一个中间切片 → Dataset size = 病人数

  • 现在代码:每个病人 → 取 所有 z 轴切片 → Dataset size = 各病人切片数之和(各病人切片数相同时即 病人数 × 切片数)

修改目标

  • 从每个病人中读取 4 个模态的 完整 3D 体积

  • 遍历每个 z 轴切片(z=0 到 z=Z-1)

  • 每个 (病人, z_index) 作为一个 2D 样本

  • 总样本数从 75 → 75 × 100~150 = 7,500~11,250(按 75 个病人、每人保留 100~150 个有效切片估算)

详细代码对比

①self.samples 存储结构改变:存储单位从病人变成了切片

原始代码:

self.samples = []
for patient in os.listdir(data_dir):
    ...
    paths = {}
    for mod in modalities:
        mod_path = ... 
        paths[mod] = mod_path
    self.samples.append(paths)  # 只存路径字典

当前代码:

self.samples = []
for patient in os.listdir(data_dir):
    ...
    for z in range(start_z, end_z):
        self.samples.append({
            'patient': patient,
            'paths': mod_files,
            'z_index': z
        })

self.samples现在存储的每一个元素,都对应一个独立的2D切片样本

| 每个 self.samples[i] 包含 | 含义 |
| --- | --- |
| 'patient' | 属于哪个病人(如 TCGA-XX-XXXX) |
| 'paths' | 各模态(T1, T2, FLAIR, T1ce)到对应 NIfTI 文件路径的字典 |
| 'z_index' | 要从 3D 图像中提取的第 z 层切片 |

因此

len(self.samples)

等于所有病人的有效切片数之和(若每个病人的有效切片数相同,即 病人数 × 每个病人的有效切片数)

②新增slice_range参数

def __init__(..., slice_range=None):
    ...
    if slice_range is not None:
        start_z, end_z = slice_range
        start_z = max(0, start_z)
        end_z = min(depth, end_z)
    else:
        start_z, end_z = 0, depth
  • 可以排除头尾无信息的切片(如几乎全是背景、不含脑组织的切片)

  • 例如 slice_range=(10, 140),只保留中间 130 层

  • 提高训练效率和质量

③ __getitem__:从“取中间切片”改为“取指定 z 切片”

原始代码:

img = img[:,:,img.shape[2]//2]  # 固定中间切片

当前代码:

img = img_3d[:, :, sample['z_index']]  # 动态取 z

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐