使用Pytorch来完成CNN的训练和验证过程,逻辑结构如下:

  • 构造训练集和验证集;
  • 每轮进行训练和验证,并根据最优验证集精度保存模型。
train_loader = torch.utils.data.DataLoader(
  train_dataset,
  batch_size=10, 
  shuffle=True, 
  num_workers=10, 
)

val_loader = torch.utils.data.DataLoader(
  val_dataset,
  batch_size=10, 
  shuffle=False, 
  num_workers=10, 
)

model = Model1()
criterion = nn.CrossEntropyLoss(size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):
print('Epoch: ', epoch)

train(train_loader, model, criterion, optimizer, epoch)
val_loss = validate(val_loader, model, criterion)

# 记录下验证集精度
if val_loss < best_loss:
    best_loss = val_loss
    torch.save(model.state_dict(), './model.pt')

在Pytorch中模型的保存和加载非常简单,比较常见的做法是保存和加载模型参数:

torch.save(model_object.state_dict(), 'model.pt')

model.load_state_dict(torch.load(' model.pt'))

模型调参流程:
在这里插入图片描述

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐