目录

1. ResNet结构解析

ResNet核心思想

典型ResNet块结构

2. CBAM放置位置的思考

最佳实践位置

插入示例(以ResNet为例)

3. 针对预训练模型的训练策略

a. 差异化学习率

实施方法

b. 三阶段微调策略

阶段1:冻结特征提取器

阶段2:部分解冻

阶段3:全网络微调

代码实现示例

总结


1. ResNet结构解析

ResNet核心思想

  • 残差连接:解决深层网络梯度消失问题

  • 恒等映射:通过shortcut连接实现

  • 瓶颈结构:1×1卷积降维 → 3×3卷积 → 1×1卷积升维

典型ResNet块结构

输入

├─ 卷积层1 → BN → ReLU

├─ 卷积层2 → BN → ReLU

└─ (下采样shortcut,如果需要)

相加 → ReLU

输出

2. CBAM放置位置的思考

最佳实践位置

  1. 每个残差块之后:增强块内特征选择能力

  2. 下采样层之前:帮助网络选择重要特征进行传递

  3. 网络末端:强化最终输出特征

插入示例(以ResNet为例)

原始ResNet块:
[Conv→BN→ReLU→Conv→BN] + shortcut → ReLU

加入CBAM后:
[Conv→BN→ReLU→Conv→BN] → CBAM + shortcut → ReLU

3. 针对预训练模型的训练策略

a. 差异化学习率

实施方法
  1. 分层设置

    • 浅层:较小学习率(保持底层特征)

    • 新增层/CBAM层:较大学习率(快速适应)

    • 深层:中等学习率(微调高层语义)

  2. 参数组示例

optimizer = torch.optim.Adam([
    {'params': model.backbone.parameters(), 'lr': 1e-5},    # 预训练部分
    {'params': model.cbam.parameters(), 'lr': 1e-3},        # 新增CBAM
    {'params': model.fc.parameters(), 'lr': 1e-4}           # 分类头
])

b. 三阶段微调策略

阶段1:冻结特征提取器
  • 操作:冻结所有预训练层

  • 训练:仅训练新增CBAM模块和分类头

  • 目的:初步适应目标任务

阶段2:部分解冻
  • 操作:解冻最后1-2个stage的ResNet层

  • 训练:同时训练解冻层+CBAM+分类头

  • 学习率:比阶段1略低

阶段3:全网络微调
  • 操作:解冻全部网络层

  • 训练:整体微调,使用更小学习率

  • 技巧:添加学习率warmup

代码实现示例

from torchvision.models import resnet50
import torch.nn as nn

class ResNet_CBAM(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        # 加载预训练ResNet
        self.backbone = resnet50(pretrained=True)  
        
        # 在layer2-4后插入CBAM
        self.backbone.layer2 = nn.Sequential(
            self.backbone.layer2,
            CBAM(512)
        )
        self.backbone.layer3 = nn.Sequential(
            self.backbone.layer3,
            CBAM(1024)
        )
        self.backbone.layer4 = nn.Sequential(
            self.backbone.layer4,
            CBAM(2048)
        )
        
        # 替换分类头
        self.backbone.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        return self.backbone(x)

总结

  1. 架构设计:CBAM应插入残差块之后或网络关键位置

  2. 训练策略:采用分阶段、差异化学习率微调

  3. 性能平衡:保持预训练特征的同时有效集成CBAM

  4. 实践建议:从小学习率开始,逐步解冻网络层

@浙大疏锦行

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐