PyTorch Deep Learning Framework: 60-Day Advanced Study Plan - Day 36: Medical Image Diagnosis (Part 1)

Friends, I can hardly believe we've made it to Day 36! Today we step into a field that is as challenging as it is meaningful: medical image diagnosis. We will learn how to analyze lung CT scans with a 3D ResNet, explore data augmentation techniques suited to medical images, and tackle the class imbalance that is so common in medical data.

There is a running joke in medical AI: "When an ordinary AI model makes a mistake, it might just mistake a cat for a dog; when a medical AI makes a mistake, it might send a healthy person to the ICU." So let's begin today's lesson with the appropriate sense of responsibility!

1. Overview of Medical Image Diagnosis

Medical image diagnosis is one of the most promising applications of AI in healthcare. Unlike ordinary images, medical images typically have the following characteristics:

  1. Diverse dimensionality: CT, MRI, and similar modalities are 3D volumes, not simple 2D images
  2. Data scarcity: labeled medical data is far scarcer than ordinary image datasets
  3. Class imbalance: disease samples are usually far fewer than healthy samples
  4. High accuracy requirements: medical diagnosis demands extremely high accuracy and has a low tolerance for error

Today we will focus on lung CT analysis, which plays an important role in diagnosing diseases such as lung cancer, pneumonia, and COVID-19.

2. 3D ResNet Architecture Design

2.1 Why ResNet?

Medical imaging typically requires extracting complex features. ResNet's residual connections effectively mitigate the vanishing-gradient problem in deep networks, allowing us to build deeper models. At the same time, medically relevant features often have to be captured from subtle changes, and ResNet's strong feature extraction makes it an ideal choice.

2.2 Converting from 2D to 3D

Converting a 2D ResNet into a 3D version mainly involves the following changes (a quick shape check follows the table):

2D component       3D counterpart      Change
Conv2d             Conv3d              Convolution kernel goes from (k×k) to (k×k×k)
MaxPool2d          MaxPool3d           Pooling window goes from (k×k) to (k×k×k)
BatchNorm2d        BatchNorm3d         Normalization over one additional dimension
AdaptiveAvgPool2d  AdaptiveAvgPool3d   Adaptive pooling over one additional dimension
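
A minimal shape check of this difference (a sketch; the tensor sizes here are arbitrary assumptions): a 2D convolution consumes [N, C, H, W] tensors, while its 3D counterpart consumes [N, C, D, H, W].

import torch
import torch.nn as nn

conv2d = nn.Conv2d(1, 8, kernel_size=3, padding=1)
conv3d = nn.Conv3d(1, 8, kernel_size=3, padding=1)

x2d = torch.randn(2, 1, 64, 64)      # a batch of 2D slices: [N, C, H, W]
x3d = torch.randn(2, 1, 32, 64, 64)  # a batch of 3D volumes: [N, C, D, H, W]

print(conv2d(x2d).shape)  # torch.Size([2, 8, 64, 64])
print(conv3d(x3d).shape)  # torch.Size([2, 8, 32, 64, 64])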

2.3 Basic Structure of the 3D ResNet

Our 3D ResNet is composed of the following parts:

  1. Initial convolution layer: captures basic features
  2. Residual blocks: extract complex features while mitigating vanishing gradients
  3. Global pooling layer: reduces dimensionality while retaining the important features
  4. Fully connected layer: performs the final classification
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock3D(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, 
                              stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, 
                              stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck3D(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck3D, self).__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
                              padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=2, zero_init_residual=False):
        super(ResNet3D, self).__init__()
        self.in_planes = 64
        
        # Initial convolution layer
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        
        # Residual layers
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Optionally zero-initialize the last BN in each residual block
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck3D):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock3D):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_planes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.in_planes, planes * block.expansion, kernel_size=1, 
                         stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem: input processing and initial feature extraction
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # Residual feature extraction
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        # Classification head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

def resnet18_3d(num_classes=2, **kwargs):
    """18层3D ResNet"""
    return ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes=num_classes, **kwargs)

def resnet34_3d(num_classes=2, **kwargs):
    """34层3D ResNet"""
    return ResNet3D(BasicBlock3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def resnet50_3d(num_classes=2, **kwargs):
    """50层3D ResNet"""
    return ResNet3D(Bottleneck3D, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
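
As a quick sanity check (a minimal sketch; the 1×1×64×64×64 input size is an arbitrary assumption), we can push a dummy volume through the network and confirm the output shape:

model = resnet18_3d(num_classes=2)
dummy = torch.randn(1, 1, 64, 64, 64)  # [batch, channel, depth, height, width]
print(model(dummy).shape)              # expected: torch.Size([1, 2])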

3. Medical Image Data Processing and Augmentation

3.1 Medical Imaging Datasets

Medical imaging data is usually organized in a more complex way than ordinary image data. Let's first look at the common CT data formats:

  • DICOM: the standard format for medical imaging, containing both images and patient information
  • NIfTI: a format common in neuroimaging, mostly used in research
  • NRRD: well suited to storing multi-dimensional medical data

For lung CT, what we process is a series of cross-sectional slices; each patient may have anywhere from a few dozen to several hundred of them.
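
Single-file formats such as NIfTI and NRRD can be read through the same SimpleITK call (a minimal sketch; "scan.nii.gz" is a placeholder path):

import SimpleITK as sitk

image = sitk.ReadImage("scan.nii.gz")    # works for NIfTI and NRRD alike
volume = sitk.GetArrayFromImage(image)   # numpy array of shape [depth, height, width]
print(volume.shape, image.GetSpacing())  # voxel spacing in millimeters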

3.2 Data Preprocessing

Medical image preprocessing typically includes the following steps:

  1. Loading: parse DICOM or other medical imaging formats
  2. Windowing: adjust the CT value range to emphasize the tissue of interest (the lung window is typically -1000 to 400 HU)
  3. Resampling: bring CTs of different resolutions to a common voxel size
  4. Cropping: remove irrelevant regions, keeping only the lungs
  5. Normalization: scale the voxel values to a range suitable for neural networks
import numpy as np
import pydicom
import glob
import os
import SimpleITK as sitk
from skimage import measure
from scipy import ndimage

def load_dicom_series(directory):
    """
    加载DICOM系列文件并转换为3D体积
    
    参数:
        directory: 包含DICOM文件的目录
    
    返回:
        3D numpy数组,形状为 [深度, 高度, 宽度]
    """
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(directory)
    reader.SetFileNames(dicom_names)
    image = reader.Execute()
    
    # Convert to a numpy array
    array = sitk.GetArrayFromImage(image)
    return array, image

def apply_lung_window(ct_scan, min_bound=-1000, max_bound=400):
    """
    应用肺窗口值
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        min_bound: HU值下限
        max_bound: HU值上限
    
    返回:
        窗口化和归一化后的CT扫描
    """
    # Clip the HU values
    ct_scan = np.clip(ct_scan, min_bound, max_bound)
    
    # Normalize to [0, 1]
    ct_scan = (ct_scan - min_bound) / (max_bound - min_bound)
    
    return ct_scan

def resample_volume(img, spacing, new_spacing=[1.0, 1.0, 1.0]):
    """
    重采样CT体积到指定的体素间距
    
    参数:
        img: SimpleITK图像对象
        spacing: 原始体素间距
        new_spacing: 目标体素间距
        
    返回:
        重采样后的SimpleITK图像对象
    """
    # 计算新的尺寸
    spacing = np.array(spacing)
    new_spacing = np.array(new_spacing)
    orig_size = np.array(img.GetSize())
    resize_factor = spacing / new_spacing
    new_size = orig_size * resize_factor
    new_size = np.round(new_size).astype(int)
    
    # 执行重采样
    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(new_spacing)
    resample.SetSize(new_size.tolist())
    resample.SetOutputDirection(img.GetDirection())
    resample.SetOutputOrigin(img.GetOrigin())
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(img.GetPixelIDValue())
    resample.SetInterpolator(sitk.sitkLinear)
    
    return resample.Execute(img)

def segment_lungs(ct_scan, fill_lung_structures=True):
    """
    分割肺部区域
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        fill_lung_structures: 是否填充肺内结构
        
    返回:
        肺部掩码和应用掩码后的CT扫描
    """
    # 阈值化得到二值图像
    binary_image = np.array(ct_scan < -320, dtype=np.int8)
    
    # 标记所有区域
    labels = measure.label(binary_image)
    
    # 假设肺区域不是最大的连通区域
    # 背景通常是最大的连通区域
    background_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
    binary_image[labels == background_label] = 0
    
    # 用形态学闭操作填充肺内结构
    if fill_lung_structures:
        for i in range(ct_scan.shape[0]):
            slice = binary_image[i]
            binary_image[i] = ndimage.binary_fill_holes(slice)
    
    # 创建肺部掩码
    lung_mask = binary_image
    
    # 应用掩码到原始CT扫描
    masked_ct = ct_scan * lung_mask
    
    return lung_mask, masked_ct

def normalize_scan(ct_scan):
    """
    标准化CT扫描值到0-1范围
    
    参数:
        ct_scan: CT扫描的3D numpy数组
        
    返回:
        标准化后的CT扫描
    """
    ct_scan = ct_scan.astype(np.float32)
    ct_scan = (ct_scan - np.min(ct_scan)) / (np.max(ct_scan) - np.min(ct_scan))
    return ct_scan

def preprocess_ct_scan(dicom_dir, output_size=(128, 128, 128)):
    """
    完整的CT扫描预处理流程
    
    参数:
        dicom_dir: DICOM文件目录
        output_size: 输出体积的尺寸
        
    返回:
        预处理后的CT扫描,准备用于深度学习模型
    """
    # 加载DICOM文件
    ct_array, ct_image = load_dicom_series(dicom_dir)
    
    # 应用肺窗口值
    windowed_ct = apply_lung_window(ct_array)
    
    # 重采样到统一分辨率
    spacing = ct_image.GetSpacing()
    resampled_ct_image = resample_volume(ct_image, spacing)
    resampled_ct = sitk.GetArrayFromImage(resampled_ct_image)
    
    # 肺部分割
    lung_mask, masked_ct = segment_lungs(resampled_ct)
    
    # 标准化
    normalized_ct = normalize_scan(masked_ct)
    
    # 调整到目标大小
    # 这里使用简单的缩放,实际应用中可能需要更复杂的方法
    from scipy.ndimage import zoom
    resize_factor = np.array(output_size) / np.array(normalized_ct.shape)
    final_ct = zoom(normalized_ct, resize_factor, order=1)
    
    return final_ct

3.3 Medical Image Data Augmentation

Data augmentation for medical images must be applied with particular care so as not to introduce unrealistic changes. The following augmentation strategies are suitable for lung CT:

Method                 Description                              Suitability
Rotation               Small-angle rotation around each axis    High: lung diagnosis usually does not depend on orientation
Scaling                Slight volume scaling                    Medium: anatomical proportions must stay plausible
Brightness/contrast    Slight adjustments of the CT window      High: simulates different CT scanner parameters
Noise injection        Additive Gaussian noise                  Medium: the main features should remain clear
Elastic deformation    Local non-rigid deformation              Low: may introduce unrealistic lesion morphology
Random cropping        Crop sub-volumes from the full volume    High: well suited to large volumes
import torch
import numpy as np
from scipy.ndimage import rotate, zoom, shift
import elasticdeform

class CTAugmentation3D:
    """
    用于3D医学影像(特别是CT)的数据增强类
    """
    def __init__(self, 
                 rotation_range=(-10, 10),
                 scale_range=(0.9, 1.1),
                 shift_range=(-5, 5),
                 noise_factor=0.05,
                 brightness_range=(0.9, 1.1),
                 contrast_range=(0.9, 1.1),
                 p_rotation=0.5,
                 p_scale=0.5,
                 p_shift=0.5, 
                 p_noise=0.3,
                 p_brightness=0.3,
                 p_contrast=0.3,
                 p_elastic=0.2):
        """
        初始化3D增强器
        
        参数:
            rotation_range: 旋转角度范围(度)
            scale_range: 缩放因子范围
            shift_range: 平移像素范围
            noise_factor: 噪声强度系数
            brightness_range: 亮度调整范围
            contrast_range: 对比度调整范围
            p_*: 各增强方法的应用概率
        """
        self.rotation_range = rotation_range
        self.scale_range = scale_range
        self.shift_range = shift_range
        self.noise_factor = noise_factor
        self.brightness_range = brightness_range
        self.contrast_range = contrast_range
        
        self.p_rotation = p_rotation
        self.p_scale = p_scale
        self.p_shift = p_shift
        self.p_noise = p_noise
        self.p_brightness = p_brightness
        self.p_contrast = p_contrast
        self.p_elastic = p_elastic
    
    def apply_rotation(self, volume):
        """Apply a random rotation"""
        # Draw a random rotation angle for each axis
        angles = [np.random.uniform(self.rotation_range[0], 
                                   self.rotation_range[1]) for _ in range(3)]
        
        # Rotate around each axis in turn
        for i, angle in enumerate(angles):
            axes = tuple([j for j in range(3) if j != i])
            volume = rotate(volume, angle, axes=axes, reshape=False, order=1, mode='nearest')
        
        return volume
    
    def apply_scaling(self, volume):
        """Apply a random scaling"""
        # Draw a single random scaling factor
        scale_factor = np.random.uniform(self.scale_range[0], self.scale_range[1])
        
        # Apply the scaling
        volume = zoom(volume, scale_factor, order=1)
        
        # Restore the original size (scaling changes the shape)
        if volume.shape != self.original_shape:
            # Difference between the scaled and original shape per dimension
            diffs = np.array(volume.shape) - np.array(self.original_shape)
            
            # Crop or pad back to the original shape
            result = np.zeros(self.original_shape)
            
            # Determine the slice ranges for each dimension
            slices_src = []
            slices_dst = []
            
            for i in range(3):
                if diffs[i] > 0:  # need to crop
                    # Crop from the center
                    start_src = diffs[i] // 2
                    end_src = start_src + self.original_shape[i]
                    start_dst = 0
                    end_dst = self.original_shape[i]
                else:  # need to pad
                    # Pad around the center
                    start_src = 0
                    end_src = volume.shape[i]
                    start_dst = -diffs[i] // 2
                    end_dst = start_dst + volume.shape[i]
                
                slices_src.append(slice(start_src, end_src))
                slices_dst.append(slice(start_dst, end_dst))
            
            # Copy the scaled volume into the result
            result[tuple(slices_dst)] = volume[tuple(slices_src)]
            volume = result
        
        return volume
    
    def apply_shift(self, volume):
        """Apply a random translation"""
        shifts = [np.random.uniform(self.shift_range[0], 
                                   self.shift_range[1]) for _ in range(3)]
        return shift(volume, shifts, order=1, mode='nearest')
    
    def apply_noise(self, volume):
        """Add Gaussian noise"""
        noise = np.random.normal(0, self.noise_factor, volume.shape)
        volume = volume + noise
        volume = np.clip(volume, 0, 1)  # keep values in the valid range
        return volume
    
    def apply_brightness(self, volume):
        """Adjust brightness"""
        factor = np.random.uniform(self.brightness_range[0], self.brightness_range[1])
        volume = volume * factor
        volume = np.clip(volume, 0, 1)
        return volume
    
    def apply_contrast(self, volume):
        """Adjust contrast"""
        factor = np.random.uniform(self.contrast_range[0], self.contrast_range[1])
        mean = np.mean(volume)
        volume = (volume - mean) * factor + mean
        volume = np.clip(volume, 0, 1)
        return volume
    
    def apply_elastic_deformation(self, volume):
        """Apply an elastic deformation"""
        # Generate a random deformation grid for the 3D volume:
        # sigma controls the magnitude of the displacements,
        # points controls the coarseness of the deformation grid
        deformed_volume = elasticdeform.deform_random_grid(
            volume, 
            sigma=3, 
            points=3,
            order=1,
            mode='nearest'
        )
        return deformed_volume
    
    def __call__(self, volume):
        """
        Apply the augmentations to an input 3D volume
        
        Args:
            volume: numpy array of shape [D, H, W]
            
        Returns:
            Augmented volume
        """
        # Remember the original shape so scaling can restore it
        self.original_shape = volume.shape
        
        # Apply each augmentation with its configured probability
        if np.random.random() < self.p_rotation:
            volume = self.apply_rotation(volume)
            
        if np.random.random() < self.p_scale:
            volume = self.apply_scaling(volume)
            
        if np.random.random() < self.p_shift:
            volume = self.apply_shift(volume)
            
        if np.random.random() < self.p_noise:
            volume = self.apply_noise(volume)
            
        if np.random.random() < self.p_brightness:
            volume = self.apply_brightness(volume)
            
        if np.random.random() < self.p_contrast:
            volume = self.apply_contrast(volume)
            
        if np.random.random() < self.p_elastic:
            volume = self.apply_elastic_deformation(volume)
        
        return volume

# A PyTorch Dataset class for 3D CT volumes
class LungCTDataset(torch.utils.data.Dataset):
    def __init__(self, ct_paths, labels=None, transform=None, phase='train'):
        """
        肺部CT数据集
        
        参数:
            ct_paths: CT数据路径列表
            labels: 对应的标签列表
            transform: 数据增强和转换
            phase: 'train', 'val' 或 'test'
        """
        self.ct_paths = ct_paths
        self.labels = labels
        self.transform = transform
        self.phase = phase
        
    def __len__(self):
        return len(self.ct_paths)
    
    def __getitem__(self, idx):
        # Load a preprocessed CT volume
        # (each path is assumed to be a .npy file holding a preprocessed volume)
        ct_volume = np.load(self.ct_paths[idx])
        
        # Apply data augmentation (training phase only)
        if self.transform and self.phase == 'train':
            ct_volume = self.transform(ct_volume)
        
        # Ensure float dtype and the correct shape ([C, D, H, W])
        ct_volume = ct_volume.astype(np.float32)
        ct_volume = np.expand_dims(ct_volume, axis=0)  # add the channel dimension
        
        # Convert to a PyTorch tensor
        ct_tensor = torch.from_numpy(ct_volume)
        
        # Return data and label (if available)
        if self.labels is not None:
            label = self.labels[idx]
            return ct_tensor, label
        else:
            return ct_tensor
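
A quick way to sanity-check the augmentation pipeline is to run it on a random stand-in volume (a sketch; real inputs would be the preprocessed .npy volumes, and the 64³ size is an assumption):

aug = CTAugmentation3D()
volume = np.random.rand(64, 64, 64).astype(np.float32)  # stand-in for a preprocessed CT volume
augmented = aug(volume)
print(augmented.shape, float(augmented.min()), float(augmented.max()))  # shape should be preserved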

4. Loss Function Design for Class Imbalance

4.1 Common Approaches to Class Imbalance

Method                Description                                    Pros                                       Cons
Undersampling         Drop majority-class samples                    Shorter training time                      Discards information; model may underfit
Oversampling          Duplicate minority-class samples               Keeps all data                             May overfit the minority class
Synthetic samples     Generate minority samples, e.g. with SMOTE     Balances the data without discarding any   Synthetic samples may be unrealistic
Class weighting       Higher loss weights for minority classes       Simple, effective, keeps all data          Weights need tuning
Focal Loss            Focus on hard-to-classify samples              Re-weights samples automatically           Hyperparameters need tuning
Combined sampling     Mix under- and oversampling                    Balances the trade-offs of both            More complex to implement

In medical imaging, data is precious and expensive to acquire, so we usually avoid plain undersampling and instead lean on a combination of loss function adjustments and oversampling, as in the sampler sketch below.
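
On the oversampling side, PyTorch's WeightedRandomSampler can rebalance batches at the DataLoader level without duplicating any files (a minimal sketch; the 90/10 toy label vector is an assumption):

import numpy as np
import torch
from torch.utils.data import WeightedRandomSampler

labels = np.array([0] * 90 + [1] * 10)       # toy imbalanced label vector
class_counts = np.bincount(labels)
sample_weights = 1.0 / class_counts[labels]  # rarer class -> larger sampling weight
sampler = WeightedRandomSampler(
    weights=torch.DoubleTensor(sample_weights),
    num_samples=len(labels),
    replacement=True,                        # sampling with replacement lets minority samples repeat
)
# Usage: DataLoader(train_dataset, batch_size=8, sampler=sampler) -- a sampler replaces shuffle=True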

4.2 Designing the Loss Functions

For lung CT analysis, we will design several loss functions suited to class imbalance:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class WeightedCrossEntropyLoss(nn.Module):
    """
    带类别权重的交叉熵损失
    适合处理类别不平衡问题
    """
    def __init__(self, weight=None, reduction='mean'):
        """
        参数:
            weight: 各类别的权重,通常少数类权重更高
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(WeightedCrossEntropyLoss, self).__init__()
        self.weight = weight
        self.reduction = reduction
    
    def forward(self, input, target):
        return F.cross_entropy(
            input, target, 
            weight=self.weight, 
            reduction=self.reduction
        )

class FocalLoss(nn.Module):
    """
    Focal Loss(聚焦损失)
    通过降低易分类样本的权重,关注难以分类的样本
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        参数:
            alpha: 各类别的权重
            gamma: 聚焦参数,越大对易分类样本的惩罚越大
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    
    def forward(self, input, target):
        ce_loss = F.cross_entropy(input, target, reduction='none', weight=self.alpha)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

class DiceLoss(nn.Module):
    """
    Dice Loss
    Common in medical image segmentation; also usable for classification
    """
    def __init__(self, smooth=1.0, reduction='mean'):
        """
        Args:
            smooth: smoothing term that prevents division by zero
            reduction: one of 'none', 'mean', 'sum'
        """
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.reduction = reduction
        
    def forward(self, input, target):
        # Convert logits to probabilities
        prob = F.softmax(input, dim=1)
        
        # One-hot encode the targets: [N] -> [N, C]
        target = F.one_hot(target, num_classes=input.size(1)).float()
        
        # Per-class Dice coefficient, accumulated over the batch dimension
        numerator = 2 * torch.sum(prob * target, dim=0) + self.smooth
        denominator = torch.sum(prob + target, dim=0) + self.smooth
        dice_coeff = numerator / denominator
        dice_loss = 1 - dice_coeff
        
        if self.reduction == 'mean':
            return dice_loss.mean()
        elif self.reduction == 'sum':
            return dice_loss.sum()
        else:
            return dice_loss

class CombinedLoss(nn.Module):
    """
    结合Focal Loss和Dice Loss的复合损失
    综合利用两种损失函数的优点
    """
    def __init__(self, alpha=None, gamma=2.0, weight_focal=0.5, weight_dice=0.5, smooth=1.0):
        """
        参数:
            alpha: Focal Loss的类别权重
            gamma: Focal Loss的聚焦参数
            weight_focal: Focal Loss的权重
            weight_dice: Dice Loss的权重
            smooth: Dice Loss的平滑系数
        """
        super(CombinedLoss, self).__init__()
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
        self.dice_loss = DiceLoss(smooth=smooth)
        self.weight_focal = weight_focal
        self.weight_dice = weight_dice
        
    def forward(self, input, target):
        return (
            self.weight_focal * self.focal_loss(input, target) + 
            self.weight_dice * self.dice_loss(input, target)
        )

class AsymmetricLoss(nn.Module):
    """
    非对称损失
    对不同类别使用不同的gamma值,更加灵活地处理类别不平衡
    """
    def __init__(self, gamma_pos=0, gamma_neg=4, clip=0.05, reduction='mean'):
        """
        参数:
            gamma_pos: 正类的gamma值
            gamma_neg: 负类的gamma值
            clip: 截断阈值
            reduction: 'none', 'mean', 'sum'中的一个
        """
        super(AsymmetricLoss, self).__init__()
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.clip = clip
        self.reduction = reduction
    
    def forward(self, input, target):
        # One-hot encode the targets
        target = F.one_hot(target, num_classes=input.size(1)).float()
        
        # Sigmoid outputs
        prob = torch.sigmoid(input)
        
        # Clip the probabilities for numerical stability
        prob = torch.clamp(prob, self.clip, 1.0 - self.clip)
        
        # Focal terms for positive and negative samples
        pos_loss = target * torch.log(prob) * (1 - prob) ** self.gamma_pos
        neg_loss = (1 - target) * torch.log(1 - prob) * prob ** self.gamma_neg
        
        loss = -(pos_loss + neg_loss)
        
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

def calculate_class_weights(labels, method='inverse', beta=0.999):
    """
    计算类别权重
    
    参数:
        labels: 训练集标签列表
        method: 计算方法,'inverse'(反比例)或'effective'(有效样本数)
        beta: 有效样本数方法的平衡因子
        
    返回:
        各类别的权重
    """
    # 计算每个类别的样本数
    class_counts = np.bincount(labels)
    n_classes = len(class_counts)
    
    if method == 'inverse':
        # 权重与类别频率成反比
        weights = 1.0 / np.array(class_counts)
        # 归一化权重
        weights = weights / np.sum(weights) * n_classes
    
    elif method == 'effective':
        # 使用有效样本数计算权重
        effective_num = 1.0 - np.power(beta, class_counts)
        weights = (1.0 - beta) / effective_num
        # 归一化权重
        weights = weights / np.sum(weights) * n_classes
    
    return torch.FloatTensor(weights)
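
To check that these losses run and behave sensibly, we can feed them random logits and an imbalanced toy label vector (the numbers are assumptions for illustration):

logits = torch.randn(8, 2)                        # a batch of 8 samples, 2 classes
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1])  # 6 negatives, 2 positives
weights = calculate_class_weights(targets.numpy(), method='inverse')

print(WeightedCrossEntropyLoss(weight=weights)(logits, targets).item())
print(FocalLoss(alpha=weights, gamma=2.0)(logits, targets).item())
print(CombinedLoss(alpha=weights)(logits, targets).item())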

4.3 Choosing a Loss Function

Typical use cases for the different loss functions in lung CT diagnosis:

Loss function            Use case              Strength
Weighted cross-entropy   Moderate imbalance    Simple and effective; easy to understand and tune
Focal Loss               Severe imbalance      Adaptively focuses on hard examples; dampens easy ones
Dice Loss                Binary problems       Insensitive to class ratios; measures overlap well
Combined loss            Complex imbalance     Combines the strengths of several losses
Asymmetric loss          Extreme imbalance     Separate focusing parameters for positives and negatives

A rule of thumb: when the positive-sample ratio is below 10%, consider Focal Loss or a combined loss; between 10% and 30%, weighted cross-entropy is usually sufficient; and if recall matters most, Dice Loss is a solid choice. The helper below encodes this heuristic.
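
An illustrative version (a hypothetical helper, not from any library; the thresholds simply restate the rule of thumb above):

def pick_loss(pos_ratio, class_weights=None):
    """Choose a loss from the positive-sample ratio (heuristic sketch only)."""
    if pos_ratio < 0.10:
        return FocalLoss(alpha=class_weights, gamma=2.0)
    elif pos_ratio < 0.30:
        return WeightedCrossEntropyLoss(weight=class_weights)
    return nn.CrossEntropyLoss()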

5. The Complete Training Pipeline

Now let's put the previous components together into a complete training pipeline for lung CT analysis:

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import time
import random
from tensorboardX import SummaryWriter

# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class LungCTTrainer:
    def __init__(self, model, train_dataset, val_dataset, test_dataset=None, 
                 batch_size=8, lr=0.001, loss_fn=None, device=None, 
                 class_weights=None, experiment_name="lung_ct_analysis"):
        """
        肺部CT分析训练器
        
        参数:
            model: 3D ResNet模型
            train_dataset: 训练数据集
            val_dataset: 验证数据集
            test_dataset: 测试数据集
            batch_size: 批处理大小
            lr: 学习率
            loss_fn: 损失函数
            device: 训练设备
            class_weights: 类别权重
            experiment_name: 实验名称
        """
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size
        self.lr = lr
        
        # Select the device
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        
        # Move the model to the device
        self.model = self.model.to(self.device)
        
        # Set up the loss function
        if loss_fn is None:
            if class_weights is not None:
                self.loss_fn = WeightedCrossEntropyLoss(weight=class_weights.to(self.device))
            else:
                self.loss_fn = nn.CrossEntropyLoss()
        else:
            self.loss_fn = loss_fn
        
        # Set up the optimizer
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        
        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        
        # Create the data loaders
        self.train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, 
            num_workers=4, pin_memory=True
        )
        self.val_loader = DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, 
            num_workers=4, pin_memory=True
        )
        if test_dataset:
            self.test_loader = DataLoader(
                test_dataset, batch_size=batch_size, shuffle=False, 
                num_workers=4, pin_memory=True
            )
        else:
            self.test_loader = None
        
        # Set up TensorBoard
        self.writer = SummaryWriter(f"runs/{experiment_name}_{time.strftime('%Y%m%d_%H%M%S')}")
        
        # Training-state tracking
        self.best_val_loss = float('inf')
        self.best_model_path = f"models/{experiment_name}_best_model.pth"
        self.early_stop_patience = 15
        self.early_stop_counter = 0
        
        # Create the directory for saved models
        os.makedirs("models", exist_ok=True)
    
    def train_epoch(self, epoch):
        """训练一个epoch"""
        self.model.train()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        
        # Progress bar via tqdm
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1} [Train]")
        
        for inputs, targets in pbar:
            # Move the data to the device
            inputs = inputs.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)
            
            # Zero the gradients
            self.optimizer.zero_grad()
            
            # Forward pass
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, targets)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            
            # Update the parameters
            self.optimizer.step()
            
            # Accumulate statistics
            running_loss += loss.item() * inputs.size(0)
            
            # Collect predictions and targets
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            
            # Update the progress bar
            pbar.set_postfix({"loss": loss.item()})
        
        # Compute the average loss and the metrics
        epoch_loss = running_loss / len(self.train_dataset)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # Log to TensorBoard
        self.writer.add_scalar('Loss/train', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        self.writer.add_scalar('Precision/train', epoch_prec, epoch)
        self.writer.add_scalar('Recall/train', epoch_recall, epoch)
        self.writer.add_scalar('F1/train', epoch_f1, epoch)
        
        return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1
    
    def validate_epoch(self, epoch):
        """验证一个epoch"""
        self.model.eval()
        running_loss = 0.0
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc=f"Epoch {epoch+1} [Val]")
            for inputs, targets in pbar:
                # Move the data to the device
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)
                
                # Forward pass
                outputs = self.model(inputs)
                loss = self.loss_fn(outputs, targets)
                
                # Accumulate statistics
                running_loss += loss.item() * inputs.size(0)
                
                # Collect predictions, targets, and probabilities
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())
                
                # Update the progress bar
                pbar.set_postfix({"loss": loss.item()})
        
        # Compute the average loss and the metrics
        epoch_loss = running_loss / len(self.val_dataset)
        epoch_acc = accuracy_score(all_targets, all_preds)
        epoch_prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # For binary problems, also compute the AUC
        if len(np.unique(all_targets)) == 2:
            epoch_auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
            self.writer.add_scalar('AUC/val', epoch_auc, epoch)
        else:
            epoch_auc = None
        
        # Log to TensorBoard
        self.writer.add_scalar('Loss/val', epoch_loss, epoch)
        self.writer.add_scalar('Accuracy/val', epoch_acc, epoch)
        self.writer.add_scalar('Precision/val', epoch_prec, epoch)
        self.writer.add_scalar('Recall/val', epoch_recall, epoch)
        self.writer.add_scalar('F1/val', epoch_f1, epoch)
        
        # Step the learning-rate scheduler
        self.scheduler.step(epoch_loss)
        
        # Save the best model so far
        if epoch_loss < self.best_val_loss:
            self.best_val_loss = epoch_loss
            torch.save(self.model.state_dict(), self.best_model_path)
            print(f"Best model saved with val loss: {epoch_loss:.4f}")
            self.early_stop_counter = 0
        else:
            self.early_stop_counter += 1
        
        return epoch_loss, epoch_acc, epoch_prec, epoch_recall, epoch_f1, epoch_auc
    
    def train(self, epochs=100):
        """训练模型"""
        print(f"Starting training for {epochs} epochs...")
        
        # Training-history bookkeeping
        history = {
            'train_loss': [], 'train_acc': [], 'train_prec': [],
            'train_recall': [], 'train_f1': [],
            'val_loss': [], 'val_acc': [], 'val_prec': [],
            'val_recall': [], 'val_f1': [], 'val_auc': []
        }
        
        # Training loop
        for epoch in range(epochs):
            # Training phase
            train_loss, train_acc, train_prec, train_recall, train_f1 = self.train_epoch(epoch)
            
            # Validation phase
            val_loss, val_acc, val_prec, val_recall, val_f1, val_auc = self.validate_epoch(epoch)
            
            # Record the history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['train_prec'].append(train_prec)
            history['train_recall'].append(train_recall)
            history['train_f1'].append(train_f1)
            
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            history['val_prec'].append(val_prec)
            history['val_recall'].append(val_recall)
            history['val_f1'].append(val_f1)
            history['val_auc'].append(val_auc)
            
            # Print the current results
            print(f"Epoch {epoch+1}/{epochs}")
            print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
            print(f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
            if val_auc:
                print(f"Val AUC: {val_auc:.4f}")
            print("-" * 50)
            
            # Early-stopping check
            if self.early_stop_counter >= self.early_stop_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
        
        # Training finished; close the TensorBoard writer
        self.writer.close()
        
        # Plot the training history
        self.plot_training_history(history)
        
        return history
    
    def plot_training_history(self, history):
        """绘制训练历史"""
        # Create a 2x2 grid of subplots
        fig, axes = plt.subplots(2, 2, figsize=(18, 12))
        
        # Loss curves
        axes[0, 0].plot(history['train_loss'], label='Train Loss')
        axes[0, 0].plot(history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss')
        axes[0, 0].set_xlabel('Epochs')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # Accuracy curves
        axes[0, 1].plot(history['train_acc'], label='Train Accuracy')
        axes[0, 1].plot(history['val_acc'], label='Val Accuracy')
        axes[0, 1].set_title('Accuracy')
        axes[0, 1].set_xlabel('Epochs')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # F1 score curves
        axes[1, 0].plot(history['train_f1'], label='Train F1')
        axes[1, 0].plot(history['val_f1'], label='Val F1')
        axes[1, 0].set_title('F1 Score')
        axes[1, 0].set_xlabel('Epochs')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True)
        
        # AUC curve (if available)
        if None not in history['val_auc']:
            axes[1, 1].plot(history['val_auc'], label='Val AUC')
            axes[1, 1].set_title('AUC')
            axes[1, 1].set_xlabel('Epochs')
            axes[1, 1].set_ylabel('AUC')
            axes[1, 1].legend()
            axes[1, 1].grid(True)
        else:
            # Without AUC, plot precision instead
            axes[1, 1].plot(history['train_prec'], label='Train Precision')
            axes[1, 1].plot(history['val_prec'], label='Val Precision')
            axes[1, 1].set_title('Precision')
            axes[1, 1].set_xlabel('Epochs')
            axes[1, 1].set_ylabel('Precision')
            axes[1, 1].legend()
            axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.savefig('training_history.png')
        plt.show()
    
    def test(self, load_best_model=True):
        """测试模型"""
        if self.test_loader is None:
            print("No test dataset provided.")
            return None
        
        # Load the best model
        if load_best_model:
            self.model.load_state_dict(torch.load(self.best_model_path))
            print(f"Loaded best model from {self.best_model_path}")
        
        self.model.eval()
        all_preds = []
        all_targets = []
        all_probs = []
        
        with torch.no_grad():
            for inputs, targets in tqdm(self.test_loader, desc="Testing"):
                # Move the data to the device
                inputs = inputs.to(self.device, non_blocking=True)
                
                # Forward pass
                outputs = self.model(inputs)
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)
                
                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.numpy())
                all_probs.extend(probs.cpu().numpy())
        
        # Compute the metrics
        acc = accuracy_score(all_targets, all_preds)
        prec = precision_score(all_targets, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_targets, all_preds, average='weighted', zero_division=0)
        f1 = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
        
        # For binary problems, also compute the AUC
        if len(np.unique(all_targets)) == 2:
            auc = roc_auc_score(all_targets, np.array(all_probs)[:, 1])
        else:
            auc = None
        
        # Print the results
        print("\nTest Results:")
        print(f"Accuracy: {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1:.4f}")
        if auc:
            print(f"AUC: {auc:.4f}")
        
        return {
            'accuracy': acc,
            'precision': prec,
            'recall': recall,
            'f1': f1,
            'auc': auc,
            'predictions': all_preds,
            'targets': all_targets,
            'probabilities': all_probs
        }

# Usage example
def main():
    # Set the random seed
    set_seed(42)
    
    # Assume the data has already been preprocessed
    # (paths here are placeholders; replace them with real data paths)
    ct_paths = ["path/to/ct1.npy", "path/to/ct2.npy", "..."]
    labels = [0, 1, "..."]  # 0 = normal, 1 = diseased
    
    # Compute the class weights
    class_weights = calculate_class_weights(labels, method='effective')
    
    # Create the data augmenter
    augmentation = CTAugmentation3D(
        rotation_range=(-10, 10),
        scale_range=(0.9, 1.1),
        shift_range=(-5, 5),
        noise_factor=0.03,
        brightness_range=(0.9, 1.1),
        contrast_range=(0.9, 1.1)
    )
    
    # Build the dataset
    full_dataset = LungCTDataset(ct_paths, labels, transform=augmentation)
    
    # Split the dataset
    train_size = int(0.7 * len(full_dataset))
    val_size = int(0.15 * len(full_dataset))
    test_size = len(full_dataset) - train_size - val_size
    
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, [train_size, val_size, test_size]
    )
    
    # Create the model
    model = resnet18_3d(num_classes=2)
    
    # Create the loss function
    # Focal Loss works well under severe class imbalance
    loss_fn = FocalLoss(alpha=class_weights, gamma=2.0)
    
    # Create the trainer
    trainer = LungCTTrainer(
        model=model,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        batch_size=8,
        lr=0.001,
        loss_fn=loss_fn,
        class_weights=class_weights,
        experiment_name="lung_ct_3d_resnet"
    )
    
    # Train the model
    history = trainer.train(epochs=50)
    
    # Test the model
    test_results = trainer.test()
    
    # Print a summary of the test results
    print("\nTest Results Summary:")
    print(f"Accuracy: {test_results['accuracy']:.4f}")
    print(f"F1 Score: {test_results['f1']:.4f}")
    if test_results['auc']:
        print(f"AUC: {test_results['auc']:.4f}")

if __name__ == "__main__":
    main()

How did you like today's content? Thanks again for reading, friends!