import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)

# 检查GPU可用性
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 数据集加载
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
  root='/app7/dataset/',
  train=True,
  download=True,
  transform=transform
)

test_dataset = torchvision.datasets.MNIST(
  root='/app7/dataset/',
  train=False,
  download=True,
  transform=transform
)

# 小批量大小以适应显存限制
batch_size = 64
train_loader = torch.utils.data.DataLoader(
  train_dataset,
  batch_size=batch_size,
  shuffle=True,
  num_workers=2
)

test_loader = torch.utils.data.DataLoader(
  test_dataset,
  batch_size=batch_size,
  shuffle=False,
  num_workers=2
)

# 简化的S4层实现
class S4Layer(nn.Module):
  def __init__(self, input_dim, state_dim, seq_len):
    super(S4Layer, self).__init__()
    self.state_dim = state_dim
    self.seq_len = seq_len
    
    # 状态矩阵参数 (对角化简化)
    self.A = nn.Parameter(torch.randn(state_dim) * 0.1)
    # 输入转换矩阵
    self.B = nn.Parameter(torch.randn(input_dim, state_dim) * 0.1)
    # 输出转换矩阵
    self.C = nn.Parameter(torch.randn(input_dim, state_dim) * 0.1)
    # 跳跃连接
    self.D = nn.Parameter(torch.randn(input_dim) * 0.1)
    # 时间步参数
    self.dt = nn.Parameter(torch.randn(1) * 0.1)
    
  def forward(self, x):
    # x 尺寸: (batch_size, seq_len, input_dim)
    batch_size, seq_len, input_dim = x.shape
    
    # 离散化参数
    dt = torch.exp(self.dt)  # 确保为正数
    A_disc = torch.exp(dt * self.A)  # 对角矩阵离散化
    
    # 初始化状态
    state = torch.zeros(batch_size, self.state_dim, device=x.device)
    outputs = []
    
    # 循环处理序列
    for t in range(seq_len):
      # 当前输入 (batch_size, input_dim)
      u_t = x[:, t, :]
      
      # 状态更新 (对角矩阵乘法)
      state = A_disc * state + torch.einsum('bi,ij->bj', u_t, self.B)
      
      # 输出计算
      y_t = torch.einsum('bi,ij->bj', state, self.C.t()) + self.D * u_t
      outputs.append(y_t.unsqueeze(1))
    
    # 堆叠所有时间步输出
    output = torch.cat(outputs, dim=1)
    return output

# 完整模型定义
class S4Model(nn.Module):
  def __init__(self, input_dim=28, state_dim=32, seq_len=28, num_classes=10):
    super(S4Model, self).__init__()
    self.seq_len = seq_len
    
    # 输入嵌入层
    self.embedding = nn.Linear(input_dim, 32)
    # S4层
    self.s4 = S4Layer(input_dim=32, state_dim=state_dim, seq_len=seq_len)
    # 输出层 (修复这里缺少的右括号)
    self.fc = nn.Sequential(
      nn.Linear(32 * seq_len, 128),
      nn.ReLU(),
      nn.Linear(128, num_classes)
    )  # 添加了缺失的右括号
  
  def forward(self, x):
    # x 尺寸: (batch_size, 1, 28, 28)
    x = x.squeeze(1)  # 移除通道维度 -> (batch_size, 28, 28)
    
    # 处理每行作为时间步
    x = x.permute(0, 2, 1)  # (batch_size, 28, 28) -> 每行28像素作为特征
    
    # 嵌入层
    x = self.embedding(x)  # (batch_size, 28, 32)
    
    # S4层处理
    x = self.s4(x)  # (batch_size, 28, 32)
    
    # 展平输出
    x = x.reshape(x.size(0), -1)  # (batch_size, 28*32)
    
    # 分类层
    x = self.fc(x)
    return x

# 初始化模型
model = S4Model().to(device)

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练函数
def train(model, loader, optimizer, criterion, epoch):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  
  with tqdm(loader, desc=f'训练 Epoch {epoch+1}') as pbar:
    for inputs, targets in pbar:
      inputs, targets = inputs.to(device), targets.to(device)
      
      # 前向传播
      optimizer.zero_grad()
      outputs = model(inputs)
      loss = criterion(outputs, targets)
      
      # 反向传播
      loss.backward()
      optimizer.step()
      
      # 统计信息
      running_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()
      
      pbar.set_postfix(loss=running_loss/(pbar.n+1), acc=100.*correct/total)
  
  train_loss = running_loss / len(loader)
  train_acc = 100. * correct / total
  return train_loss, train_acc

# 测试函数
def test(model, loader, criterion):
  model.eval()
  running_loss = 0.0
  correct = 0
  total = 0
  
  with torch.no_grad():
    for inputs, targets in loader:
      inputs, targets = inputs.to(device), targets.to(device)
      
      # 前向传播
      outputs = model(inputs)
      loss = criterion(outputs, targets)
      
      # 统计信息
      running_loss += loss.item()
      _, predicted = outputs.max(1)
      total += targets.size(0)
      correct += predicted.eq(targets).sum().item()
  
  test_loss = running_loss / len(loader)
  test_acc = 100. * correct / total
  return test_loss, test_acc

# 训练和测试循环
def main():
  num_epochs = 10
  train_losses = []
  test_losses = []
  
  for epoch in range(num_epochs):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, epoch)
    test_loss, test_acc = test(model, test_loader, criterion)
    
    print(f"Epoch {epoch+1}/{num_epochs}:")
    print(f"  训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}%")
    print(f"  测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.2f}%")
    print("-" * 50)
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
  
  # 保存模型
  torch.save(model.state_dict(), 's4_mnist.pth')
  return train_losses, test_losses

# 绘制损失曲线
def pltLoss(train_losses, test_losses):
  plt.figure(figsize=(10, 5))
  plt.plot(train_losses, label='训练损失')
  plt.plot(test_losses, label='测试损失')
  plt.xlabel('Epoch')
  plt.ylabel('损失')
  plt.title('训练和测试损失曲线')
  plt.legend()
  plt.grid(True)
  plt.savefig('loss_curve.png')
  plt.show()

# 执行主程序
if __name__ == "__main__":
  train_losses, test_losses = main()
  pltLoss(train_losses, test_losses)
  
  # 最终测试
  final_test_loss, final_test_acc = test(model, test_loader, criterion)
  print(f"最终测试结果: 损失={final_test_loss:.4f}, 准确率={final_test_acc:.2f}%")
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐