PyTorch 演示“结构化状态空间序列模型(Structured State-Space Sequence Model, S4)”,基于 MNIST 数据集
·
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
# Seed the torch and NumPy RNGs so repeated runs start from the same state
torch.manual_seed(42)
np.random.seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# Preprocessing: convert images to tensors, then normalize with
# mean 0.1307 and std 0.3081
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST training and test splits (downloaded into the root directory if missing)
train_dataset = torchvision.datasets.MNIST(
    root='/app7/dataset/', train=True, download=True, transform=transform,
)
test_dataset = torchvision.datasets.MNIST(
    root='/app7/dataset/', train=False, download=True, transform=transform,
)

# Small batch size to stay within GPU memory limits
batch_size = 64

# Only the training loader shuffles; evaluation order stays fixed
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2,
)
# Simplified S4 layer implementation
class S4Layer(nn.Module):
    """Simplified diagonal state-space layer evaluated as a recurrence.

    For every time step ``t`` the layer computes::

        state_t = exp(exp(dt) * A) * state_{t-1} + u_t @ B
        y_t     = state_t @ C^T + D * u_t

    ``A`` is a learnable vector (a diagonal state matrix), ``dt`` is a
    learnable scalar made positive through ``exp``, and the state starts at
    zero for every call.

    Args:
        input_dim: Feature size of each time step (also the output size).
        state_dim: Size of the hidden state.
        seq_len: Nominal sequence length. It is only stored; ``forward``
            reads the actual length from its input.
    """

    def __init__(self, input_dim, state_dim, seq_len):
        super().__init__()
        self.state_dim = state_dim
        self.seq_len = seq_len
        # NOTE: parameters are created in the order A, B, C, D, dt so that a
        # fixed RNG seed keeps producing the same initial values.
        # State matrix (diagonal simplification)
        self.A = nn.Parameter(torch.randn(state_dim) * 0.1)
        # Input projection
        self.B = nn.Parameter(torch.randn(input_dim, state_dim) * 0.1)
        # Output projection
        self.C = nn.Parameter(torch.randn(input_dim, state_dim) * 0.1)
        # Skip connection
        self.D = nn.Parameter(torch.randn(input_dim) * 0.1)
        # Time-step parameter (stored in log space, see forward)
        self.dt = nn.Parameter(torch.randn(1) * 0.1)

    def forward(self, x):
        """Run the recurrence over a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, input_dim)`` holding the
            output of every time step.
        """
        batch_size, seq_len, _ = x.shape

        # An empty sequence has no states; only the (empty) skip term remains.
        if seq_len == 0:
            return self.D * x

        # Discretization
        dt = torch.exp(self.dt)  # keeps the step size positive
        A_disc = torch.exp(dt * self.A)  # element-wise: A is diagonal

        # The input projection does not depend on the state, so compute it
        # once for all time steps instead of once per step.
        projected = torch.matmul(x, self.B)  # (batch, seq_len, state_dim)

        # Zero initial state, matching the projection's dtype and device
        state = projected.new_zeros(batch_size, self.state_dim)
        states = []
        # Only the state update is inherently sequential
        for t in range(seq_len):
            state = A_disc * state + projected[:, t, :]
            states.append(state)
        states = torch.stack(states, dim=1)  # (batch, seq_len, state_dim)

        # Output projection plus skip connection, applied to all steps at once
        return torch.matmul(states, self.C.t()) + self.D * x
# Full model definition
class S4Model(nn.Module):
    """Image classifier: linear embedding -> S4Layer -> MLP head.

    Args:
        input_dim: Number of features per time step fed to the embedding.
        state_dim: Hidden-state size of the S4 layer.
        seq_len: Number of time steps; fixes the input size of the head.
        num_classes: Number of output logits.
        embed_dim: Width of the embedding and of the S4 layer's
            input/output. Defaults to 32.
    """

    def __init__(self, input_dim=28, state_dim=32, seq_len=28, num_classes=10,
                 embed_dim=32):
        super().__init__()
        self.seq_len = seq_len
        # Input embedding
        self.embedding = nn.Linear(input_dim, embed_dim)
        # S4 layer
        self.s4 = S4Layer(input_dim=embed_dim, state_dim=state_dim, seq_len=seq_len)
        # Classification head over the flattened per-step outputs
        self.fc = nn.Sequential(
            nn.Linear(embed_dim * seq_len, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Classify a batch of single-channel images.

        Args:
            x: Tensor of shape ``(batch_size, 1, 28, 28)`` with the default
                constructor arguments.

        Returns:
            Logits of shape ``(batch_size, num_classes)``.
        """
        x = x.squeeze(1)  # drop the channel dimension -> (batch_size, 28, 28)
        # Swap the two spatial axes. Assuming the usual (batch, 1, H, W)
        # layout, each time step is then one image COLUMN, not a row.
        x = x.permute(0, 2, 1)
        # Embedding
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        # S4 layer
        x = self.s4(x)  # (batch_size, seq_len, embed_dim)
        # Flatten all time steps into one feature vector per sample
        x = x.reshape(x.size(0), -1)  # (batch_size, seq_len * embed_dim)
        # Classification head
        x = self.fc(x)
        return x
# Build the model with its default hyper-parameters and move it to the device
model = S4Model().to(device)

# Cross-entropy loss, optimized with Adam at a learning rate of 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Training function
def train(model, loader, optimizer, criterion, epoch):
    """Train ``model`` for one pass over ``loader``.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network to optimize; switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(train_loss, train_acc)``: the mean per-batch loss and the
        accuracy in percent. Both are ``0.0`` if the loader yields nothing.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with tqdm(loader, desc=f'训练 Epoch {epoch+1}') as pbar:
        # Count batches ourselves: tqdm only refreshes ``pbar.n`` when it
        # redraws the bar, so it lags behind the true batch count.
        for num_batches, (inputs, targets) in enumerate(pbar, start=1):
            inputs, targets = inputs.to(device), targets.to(device)
            # Forward pass
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Backward pass
            loss.backward()
            optimizer.step()
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            pbar.set_postfix(loss=running_loss / num_batches, acc=100. * correct / total)
    # Guard against an empty loader instead of dividing by zero
    train_loss = running_loss / num_batches if num_batches else 0.0
    train_acc = 100. * correct / total if total else 0.0
    return train_loss, train_acc
# Evaluation function
def test(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.

    Returns:
        Tuple ``(test_loss, test_acc)``: the mean per-batch loss and the
        accuracy in percent. Both are ``0.0`` if the loader yields nothing.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for num_batches, (inputs, targets) in enumerate(loader, start=1):
            inputs, targets = inputs.to(device), targets.to(device)
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader instead of dividing by zero
    test_loss = running_loss / num_batches if num_batches else 0.0
    test_acc = 100. * correct / total if total else 0.0
    return test_loss, test_acc
# Training and evaluation loop
def main():
    """Train and evaluate the module-level ``model`` for 10 epochs.

    Uses the module-level ``train_loader``, ``test_loader``, ``optimizer``
    and ``criterion``. Prints the metrics of every epoch and writes the
    final weights to ``s4_mnist.pth`` in the working directory.

    Returns:
        Tuple ``(train_losses, test_losses)`` with one entry per epoch.
    """
    num_epochs = 10
    train_losses = []
    test_losses = []
    for epoch in range(num_epochs):
        train_loss, train_acc = train(model, train_loader, optimizer, criterion, epoch)
        test_loss, test_acc = test(model, test_loader, criterion)
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f" 训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%")
        print(f" 测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.2f}%")
        print("-" * 50)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
    # Save the trained weights once training has finished
    torch.save(model.state_dict(), 's4_mnist.pth')
    return train_losses, test_losses
# Plot the loss curves
def pltLoss(train_losses, test_losses):
    """Plot training and test loss per epoch, save the figure and show it.

    The figure is written to ``loss_curve.png`` in the working directory
    before being displayed.

    Args:
        train_losses: Sequence of per-epoch training losses.
        test_losses: Sequence of per-epoch test losses.
    """
    plt.figure(figsize=(10, 5))
    # Use 1-based epoch numbers on the x-axis so the plot lines up with the
    # "Epoch k/N" numbering printed during training.
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='训练损失')
    plt.plot(range(1, len(test_losses) + 1), test_losses, label='测试损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.title('训练和测试损失曲线')
    plt.legend()
    plt.grid(True)
    # Save before show(): the original order, kept so the file is written
    # before the interactive window takes over.
    plt.savefig('loss_curve.png')
    plt.show()
# Run the main program only when executed as a script, not when imported
if __name__ == "__main__":
    train_losses, test_losses = main()
    pltLoss(train_losses, test_losses)
    # Final evaluation on the test set
    final_test_loss, final_test_acc = test(model, test_loader, criterion)
    print(f"最终测试结果: 损失={final_test_loss:.4f}, 准确率={final_test_acc:.2f}%")
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)