PyTorch 深度学习笔记(十一):Swish 激活函数的 PyTorch 实现与训练效果验证
Swish 激活函数是 Google 提出的一种新型激活函数,其定义为 $f(x) = x \cdot \sigma(\beta x)$,其中 $\sigma$ 为 Sigmoid 函数,$\beta$ 为可学习的参数。Swish 激活函数在多个任务上表现优于 ReLU 激活函数。Swish 激活函数的 PyTorch 实现有多种方式,可以直接使用内置函数或自定义实现。
·
PyTorch 实现与训练效果验证
Swish 激活函数是 Google 提出的一种新型激活函数,其定义为 $f(x) = x \cdot \sigma(\beta x)$,其中 $\sigma$ 为 Sigmoid 函数,$\beta$ 为可学习的参数。Swish 激活函数在多个任务上表现优于 ReLU 激活函数。
Swish 激活函数的 PyTorch 实现
Swish 激活函数的 PyTorch 实现有多种方式,可以直接使用内置函数或自定义实现。
方法一:使用 torch.nn.Swish(PyTorch 1.7+)
import torch
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = F.swish(self.fc1(x))
x = F.swish(self.fc2(x))
return x
方法二:手动实现 Swish
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x) # 或使用 beta 参数
训练效果验证
Swish 激活函数通常在实际任务中表现良好,以下是其训练效果的验证方法:
对比实验
-
任务验证
- 在 CIFAR-10 和 CIFAR-100 数据集上训练 ResNet 模型,使用 Swish 代替 ReLU 激活函数。
- 观察准确率和损失曲线的变化。
-
训练日志
- 记录训练过程中的损失、准确率以及学习率变化。
- 分析 Swish 是否有助于梯度传播和模型收敛。
结果分析
- 收敛速度
Swish 激活函数在某些任务上收敛速度更快,尤其是在深度网络中。 - 梯度消失
Swish 可以缓解梯度消失问题,因为其梯度始终为正,不像 ReLU 在负数区域完全关闭。
常见问题解答
问题 1:Swish 激活函数在残差网络中的表现
- 残差连接:Swish 在残差网络中的表现与 ReLU 类似,但通常更平滑。
- 梯度消失:Swish 对梯度消失问题有更好的缓解效果。
问题 2:Swish 在 RNN 中的表现
- RNN 任务:Swish 在 RNN 中的表现优于 tanh 激活函数,但略逊于 ReLU。
问题 3:Swish 在 Transformer 模型中的表现
- Transformer 任务:Swish 在 Transformer 模型中的表现与 ReLU 类似,但有时更稳定。
其他注意事项
- 内存占用:Swish 激活函数通常比 ReLU 占用更多的内存,因为需要计算 Sigmoid 函数。
- 计算效率:Swish 的计算效率通常低于 ReLU 激活函数,但可以通过优化实现来提升。
注意事项
- 内存占用:Swish 激活函数在 GPU 上的计算效率低于 ReLU 激活函数,但可以通过优化实现来提升。
- 计算效率:Swish 的计算效率通常低于 ReLU 激活函数,但可以通过优化实现来提升。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐

所有评论(0)