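Transfer Learning

Models that other people have already trained on other datasets are available through PyTorch, Git repositories, and similar channels. We can take such a pretrained model and fine-tune it for our own task.

This project transfers the ResNet18 model that ships with PyTorch (torchvision). The model was trained on the ImageNet dataset, which covers 1000 different image classes, and we try to use it to classify ants and bees. Because the pretrained output layer predicts those 1000 classes, it has to be replaced before the model can act as a two-class ant/bee classifier.

To train the new classifier, we fine-tune the model on roughly 120 new training images each of bees and ants. Compared with the millions of images in ImageNet, this dataset is tiny.

The project code is as follows:

1. Load the required libraries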

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR

import torchvision
from torchvision import datasets
from torchvision import models
from torchvision import transforms

import numpy as np

from io import BytesIO
from urllib.request import urlopen
from zipfile import ZipFile

import matplotlib.pyplot as plt

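2. Download the data

Call urllib.request.urlopen to fetch the archive from https://pytorch.tips/bee-zip; the response object is bound to zipresp.

Call read() on zipresp to get the raw bytes of the archive, wrap them in a BytesIO in-memory buffer (no text encoding is involved), and pass that buffer to zipfile.ZipFile to create the zfile object.

Call extractall on zfile to unpack the files into the target folder ./data.

Because everything goes through an in-memory stream, the zip file itself never has to be saved to the local disk.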
zipurl = 'https://pytorch.tips/bee-zip'
with urlopen(zipurl) as zipresp:
    with ZipFile(BytesIO(zipresp.read())) as zfile:
        zfile.extractall('./data')
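
To look at the in-memory pattern in isolation, without any network access, the sketch below builds a small zip archive in a BytesIO buffer and extracts it into a temporary directory. The archive contents and the file path inside it are made up for illustration; only the ZipFile(BytesIO(...)) step mirrors the download code above.

```python
import os
import tempfile
from io import BytesIO
from zipfile import ZipFile

# Build a tiny zip archive entirely in memory (stands in for zipresp.read()).
buffer = BytesIO()
with ZipFile(buffer, "w") as zf:
    zf.writestr("demo/train/ants/readme.txt", "hypothetical file")
zip_bytes = buffer.getvalue()

# Wrap the raw bytes in BytesIO so ZipFile can treat them like a file on disk.
with tempfile.TemporaryDirectory() as target_dir:
    with ZipFile(BytesIO(zip_bytes)) as zfile:
        zfile.extractall(target_dir)
    extracted = os.path.join(target_dir, "demo", "train", "ants", "readme.txt")
    with open(extracted) as f:
        content = f.read()

print(content)
```

The archive bytes only ever live in memory; the only thing written to disk is the extracted file tree.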

3. Define the data transforms

# Define the transforms
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])
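
As a rough illustration of what the Normalize step computes, the sketch below applies the same per-channel mean and standard deviation by hand and then inverts the operation. The 3x2x2 random tensor is a made-up stand-in for an image that has already passed through ToTensor; it is not taken from the dataset.

```python
import torch

# Per-channel statistics from the transforms above, shaped to broadcast over H and W.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# A made-up 3x2x2 "image" with values in [0, 1], as ToTensor would produce.
image = torch.rand(3, 2, 2)

normalized = (image - mean) / std      # per-channel (x - mean) / std
restored = normalized * std + mean     # the inverse, needed before displaying an image

print(torch.allclose(restored, image, atol=1e-6))
```

The inverse step matters later: a normalized tensor has to be mapped back to the [0, 1] range before it can be shown as an image.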

4. Load the data

# Datasets
train_dataset = datasets.ImageFolder(
    root = './data/hymenoptera_data/train',
    transform = train_transforms)

val_dataset = datasets.ImageFolder(
    root = './data/hymenoptera_data/val',
    transform = val_transforms)

# Data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4
)
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4
)
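
ImageFolder takes its labels from the directory layout: each subfolder of the root is one class, and the class indices follow the sorted folder names. The sketch below imitates that rule with the standard library only, so it is not torchvision's own code; the temporary folders are hypothetical stand-ins for hymenoptera_data/train.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as root:
    # Hypothetical layout: one subfolder per class, created in arbitrary order.
    for class_name in ["bees", "ants"]:
        os.makedirs(os.path.join(root, class_name))

    # Imitate the labeling rule: subfolder names, sorted, become the classes.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(classes)       # ['ants', 'bees']
print(class_to_idx)  # {'ants': 0, 'bees': 1}
```

On the real dataset, the same information should be available as train_dataset.classes and train_dataset.class_to_idx.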

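5. Create the ResNet18 model

# Model
# Note: newer torchvision releases deprecate `pretrained=True` in favor of the `weights` argument
model = models.resnet18(pretrained = True)
print(model.fc)  # original 1000-class output layer
model.fc = nn.Linear(model.fc.in_features, 2)  # new 2-class output layer (ants, bees)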
print(model.fc)

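6. Define the hyperparameters

model = model.to("cuda")  # requires a CUDA-capable GPU

Loss = nn.CrossEntropyLoss()

# Note: this assignment rebinds the name `optim` from the torch.optim module to the optimizer instance
optim = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)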

exp_lr_scheduler = StepLR(optim, step_size=7, gamma=0.1)
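
With step_size=7 and gamma=0.1, StepLR multiplies the learning rate by 0.1 every 7 epochs. The plain-Python sketch below reproduces that arithmetic with the closed-form expression lr = base_lr * gamma ** (epoch // step_size); it assumes the scheduler is stepped exactly once at the end of every epoch and uses a hypothetical 25-epoch run.

```python
base_lr = 0.001
step_size = 7
gamma = 0.1
num_epochs = 25  # assumed run length for this illustration

# Learning rate in effect during each (0-indexed) epoch.
schedule = [base_lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]

# Print the rate at the first epoch and around each decay boundary.
for epoch in (0, 6, 7, 14, 21):
    print(f"epoch {epoch + 1}: lr = {schedule[epoch]:.0e}")
```

Under these assumptions the rate stays at 1e-3 for epochs 1 to 7, then drops tenfold at epochs 8, 15, and 22.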

7. Fine-tune

# Training
num_epochs = 25
for epoch in range(num_epochs):
    # Train the model
    model.train()
    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in train_loader:
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")
        optim.zero_grad()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = Loss(outputs, labels)
        loss.backward()
        optim.step()

        # loss.item() is already the batch mean, so weight it by the batch size
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels).item()

    exp_lr_scheduler.step()
    # Divide by the number of samples to get per-sample averages
    train_epoch_loss = running_loss / len(train_dataset)
    train_epoch_acc = running_corrects / len(train_dataset)

    # Evaluate the model on the validation set
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in val_loader:
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = Loss(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

    epoch_loss = running_loss / len(val_dataset)
    epoch_acc = running_corrects / len(val_dataset)
    
    if((epoch+1)%5==0):
        print("epoch:{},  "
            "Train_loss:{:.4f} Train_acc:{:.4f},  "
            "Loss:{:.4f}, Acc:{:.4f}".format(epoch+1, train_epoch_loss, train_epoch_acc, epoch_loss, epoch_acc))

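nn.CrossEntropyLoss returns the mean loss over a batch by default, so how per-batch values are combined into an epoch-level figure matters whenever the last batch is smaller than the others. The sketch below compares a plain mean of batch means with a sample-weighted mean; the loss values and batch sizes are invented for illustration.

```python
# Hypothetical per-batch results: (mean loss of the batch, batch size).
batches = [(0.50, 4), (0.30, 4), (0.90, 1)]  # the last batch is smaller

# Mean of batch means: every batch counts equally, regardless of its size.
mean_of_means = sum(loss for loss, _ in batches) / len(batches)

# Sample-weighted mean: undo the per-batch averaging, then divide by the sample count.
total_samples = sum(size for _, size in batches)
weighted_mean = sum(loss * size for loss, size in batches) / total_samples

print(round(mean_of_means, 4), round(weighted_mean, 4))  # 0.5667 0.4556
```

The single-image batch pulls the unweighted figure upward, while the weighted version matches the average over all nine samples.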
8. Visualize the test results

# Visualize test predictions
def imshow(inp, title=None):
    # Switch from C-H-W back to the H-W-C image layout
    inp = inp.numpy().transpose((1, 2, 0))
    # Undo the normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    
    if title is not None:
        plt.title(title)

inputs , classes = next(iter(val_loader))
out = torchvision.utils.make_grid(inputs)
class_names = val_dataset.classes
with torch.no_grad():  # inference only, so skip gradient tracking
    outputs = model(inputs.to("cuda"))

_, preds = torch.max(outputs, 1)

imshow(out, title=[class_names[x] for x in preds])
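
The step from raw model outputs to the class names shown in the title can be traced with a few made-up logits. In the sketch below, the logits tensor is invented, and the class order is assumed to be the alphabetical one that ImageFolder derives from the folder names.

```python
import torch

class_names = ["ants", "bees"]  # assumed order, matching sorted folder names

# Made-up logits for a batch of 4 images (rows) over 2 classes (columns).
logits = torch.tensor([[2.0, -1.0],
                       [0.1, 0.3],
                       [-0.5, 1.5],
                       [3.0, 2.9]])

# torch.max along dim=1 returns (max values, indices); the indices are the predictions.
_, preds = torch.max(logits, 1)
titles = [class_names[i] for i in preds.tolist()]
print(titles)  # ['ants', 'bees', 'bees', 'ants']
```

Only the position of the largest logit in each row matters here; the values themselves are not probabilities unless a softmax is applied.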
