PyTorch 深度学习笔记(十一):Tanh 激活函数的 PyTorch 代码示例与训练验证
Tanh 适用于 RNN 和 LSTM 的隐藏状态转换,在 CNN 中效果通常不如 ReLU 系列激活函数。
·
PyTorch 深度学习笔记(十一):Tanh 激活函数
Tanh 函数解析
Tanh(双曲正切)激活函数将输入值映射到 $(-1, 1)$ 区间,其数学表达式为: $$\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$ 函数特性:
- 零中心化:输出均值为 0,有利于梯度下降优化
- 饱和性:当 $|x| > 2$ 时梯度接近 0(梯度消失问题)
- 导数计算:$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$
PyTorch 实现示例
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
# 1. 数据准备
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.5,), (0.5,))
])
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
# 2. 定义Tanh网络
class TanhNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.act = nn.Tanh() # Tanh激活层
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 784) # 展平图像
x = self.act(self.fc1(x))
return self.fc2(x)
# 3. 模型初始化
model = TanhNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 4. 训练函数
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item():.4f}')
# 5. 验证函数
def validate():
model.eval()
test_set = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(test_set, batch_size=1000)
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
accuracy = 100. * correct / len(test_loader.dataset)
print(f'\n验证准确率: {accuracy:.2f}%\n')
# 6. 执行训练验证
for epoch in range(1, 6): # 训练5个epoch
train(epoch)
validate()
关键特性验证
# Tanh梯度测试
x = torch.linspace(-5, 5, 100, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()
# 可视化梯度变化
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(x.detach(), y.detach(), label='Tanh')
plt.title("激活值")
plt.subplot(122)
plt.plot(x.detach(), x.grad, label='Gradient')
plt.title("梯度值")
plt.tight_layout()
plt.show()
训练注意事项
- 学习率调整:Tanh在饱和区梯度较小,建议使用自适应优化器(如Adam)或学习率衰减
- 权重初始化:采用 Xavier 初始化:
nn.init.xavier_uniform_(self.fc1.weight) - 批归一化:配合 BatchNorm 可缓解梯度消失:
self.bn = nn.BatchNorm1d(128) x = self.act(self.bn(self.fc1(x)))
输出特征分析
当输入 $x \in \mathbb{R}$ 时,Tanh 输出分布:
| 输入区间 | 输出区间 | 梯度强度 |
|---|---|---|
| $x < -2$ | $(-1, -0.96)$ | $<0.08$ |
| $-2 ≤ x ≤ 2$ | $[-0.96, 0.96]$ | $[0.08, 1]$ |
| $x > 2$ | $(0.96, 1)$ | $<0.08$ |
应用建议:Tanh 适用于 RNN 和 LSTM 的隐藏状态转换,在 CNN 中效果通常不如 ReLU 系列激活函数。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)