PyTorch 深度学习笔记(十一):Tanh 激活函数

Tanh 函数解析

Tanh(双曲正切)激活函数将输入值映射到 $(-1, 1)$ 区间,其数学表达式为: $$\tanh(x) = \frac{\sinh(x)}{\cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}}$$ 函数特性:

  1. 零中心化:输出均值为 0,有利于梯度下降优化
  2. 饱和性:当 $|x| > 2$ 时梯度接近 0(梯度消失问题)
  3. 导数计算:$\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x)$
PyTorch 实现示例
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

# 1. 数据准备
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,))
])
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

# 2. 定义Tanh网络
class TanhNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.act = nn.Tanh()  # Tanh激活层
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)  # 展平图像
        x = self.act(self.fc1(x))
        return self.fc2(x)

# 3. 模型初始化
model = TanhNet()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. 训练函数
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item():.4f}')

# 5. 验证函数
def validate():
    model.eval()
    test_set = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
    test_loader = DataLoader(test_set, batch_size=1000)
    
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\n验证准确率: {accuracy:.2f}%\n')

# 6. 执行训练验证
for epoch in range(1, 6):  # 训练5个epoch
    train(epoch)
    validate()

关键特性验证
# Tanh梯度测试
x = torch.linspace(-5, 5, 100, requires_grad=True)
y = torch.tanh(x)
y.sum().backward()

# 可视化梯度变化
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(x.detach(), y.detach(), label='Tanh')
plt.title("激活值")
plt.subplot(122)
plt.plot(x.detach(), x.grad, label='Gradient')
plt.title("梯度值")
plt.tight_layout()
plt.show()

训练注意事项
  1. 学习率调整:Tanh在饱和区梯度较小,建议使用自适应优化器(如Adam)或学习率衰减
  2. 权重初始化:采用 Xavier 初始化:
    nn.init.xavier_uniform_(self.fc1.weight)
    

  3. 批归一化:配合 BatchNorm 可缓解梯度消失:
    self.bn = nn.BatchNorm1d(128)
    x = self.act(self.bn(self.fc1(x)))
    

输出特征分析

当输入 $x \in \mathbb{R}$ 时,Tanh 输出分布:

输入区间 输出区间 梯度强度
$x < -2$ $(-1, -0.96)$ $<0.08$
$-2 ≤ x ≤ 2$ $[-0.96, 0.96]$ $[0.08, 1]$
$x > 2$ $(0.96, 1)$ $<0.08$

应用建议:Tanh 适用于 RNN 和 LSTM 的隐藏状态转换,在 CNN 中效果通常不如 ReLU 系列激活函数。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐