Preparation

Install the required libraries and tools: PyTorch, Torchvision, and the COCO API (pycocotools). Install the dependencies with the following commands:

pip install torch torchvision
pip install pycocotools
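To confirm the installs landed in the active environment, here is a small standard-library check. It is a sketch: the helper name `installed_versions` is our own, and it only reports installed distribution versions via `importlib.metadata`. It does not verify CUDA support.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions(["torch", "torchvision", "pycocotools"]).items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
```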

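Download the COCO dataset. The official release provides the train2017, val2017, and annotations archives; download them from the official COCO website and extract them into a directory of your choice.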
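A quick way to catch extraction mistakes is to check for the expected folders and annotation files before training. This sketch assumes the layout implied by the paths used later in this tutorial (`path/to/train2017`, `path/to/val2017`, `path/to/annotations/instances_*.json`); `missing_coco_files` is a hypothetical helper, so adjust the list if you extracted the archives differently.

```python
from pathlib import Path

# Assumed layout, matching the paths used in the code below; adjust as needed.
EXPECTED = [
    "train2017",
    "val2017",
    "annotations/instances_train2017.json",
    "annotations/instances_val2017.json",
]


def missing_coco_files(root):
    """Return the expected COCO paths that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_coco_files("path/to")
    print("layout OK" if not missing else f"missing: {missing}")
```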

Loading the Dataset

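Load the data with the COCO dataset interface provided by Torchvision. CocoDetection returns each target as the raw list of COCO annotation dicts, with boxes in [x, y, width, height] format, while Faster R-CNN expects a dict with 'boxes' in [x1, y1, x2, y2] format and integer 'labels'. The following example loads the training set and converts each sample accordingly:

import torch
from torchvision.datasets import CocoDetection
import torchvision.transforms as T

transform = T.Compose([
    T.ToTensor(),
])

def to_frcnn_sample(image, annotations):
    # Skip crowd annotations and zero-size boxes (the latter make Faster R-CNN raise an error)
    annotations = [a for a in annotations if not a.get('iscrowd', 0) and a['bbox'][2] > 0 and a['bbox'][3] > 0]
    # COCO bbox [x, y, w, h] -> [x1, y1, x2, y2]
    boxes = [[x, y, x + w, y + h] for x, y, w, h in (a['bbox'] for a in annotations)]
    labels = [a['category_id'] for a in annotations]
    target = {'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
              'labels': torch.tensor(labels, dtype=torch.int64)}
    return transform(image), target

dataset = CocoDetection(
    root='path/to/train2017',
    annFile='path/to/annotations/instances_train2017.json',
    transforms=to_frcnn_sample
)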
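To see what CocoDetection is working with, here is a pure-Python sketch over a tiny, made-up annotation dict in the instances_*.json layout (images / annotations / categories). It groups annotations by image_id, roughly what the COCO API does internally, and converts each [x, y, width, height] box to the [x1, y1, x2, y2] corners Faster R-CNN uses. The sample values and the helper name `targets_by_image` are invented for illustration.

```python
from collections import defaultdict

# A made-up, minimal dict in the layout of instances_*.json (values are illustrative only)
coco_json = {
    "images": [{"id": 1, "file_name": "000000000001.jpg"}, {"id": 2, "file_name": "000000000002.jpg"}],
    "annotations": [
        {"id": 10, "image_id": 1, "category_id": 18, "bbox": [10.0, 20.0, 30.0, 40.0], "iscrowd": 0},
        {"id": 11, "image_id": 1, "category_id": 1, "bbox": [0.0, 0.0, 5.0, 5.0], "iscrowd": 1},
    ],
    "categories": [{"id": 1, "name": "person"}, {"id": 18, "name": "dog"}],
}


def targets_by_image(data):
    """Map each image id to {'boxes': [[x1, y1, x2, y2], ...], 'labels': [...]}, skipping crowd regions."""
    grouped = defaultdict(list)
    for ann in data["annotations"]:
        grouped[ann["image_id"]].append(ann)
    targets = {}
    for img in data["images"]:  # images without annotations get empty targets
        anns = [a for a in grouped[img["id"]] if not a.get("iscrowd", 0)]
        targets[img["id"]] = {
            "boxes": [[x, y, x + w, y + h] for x, y, w, h in (a["bbox"] for a in anns)],
            "labels": [a["category_id"] for a in anns],
        }
    return targets


print(targets_by_image(coco_json))
# → {1: {'boxes': [[10.0, 20.0, 40.0, 60.0]], 'labels': [18]}, 2: {'boxes': [], 'labels': []}}
```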

Model Configuration

Load a pretrained Faster R-CNN model. Torchvision provides a Faster R-CNN model pretrained on COCO:

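import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
num_classes = 91  # COCO default: background (0) plus category ids 1..90
# Replacing the box predictor re-initializes the classification head. This is only needed when
# fine-tuning on a dataset with a different number of classes; for COCO itself you can keep the
# pretrained head and skip the next two lines.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)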
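Why 91 and not 80? COCO defines 80 object categories, but their ids run from 1 to 90 with gaps, and torchvision's COCO models use the category id directly as the label, with 0 reserved for background. The sketch below derives the class count under that assumption; the helper name and the three-category sample are illustrative.

```python
def num_classes_for_detector(categories):
    """Background (label 0) plus every id up to the largest category id, used directly as labels."""
    return max(c["id"] for c in categories) + 1


# Illustrative subset of the COCO categories: only the ids matter here
sample = [{"id": 1, "name": "person"}, {"id": 18, "name": "dog"}, {"id": 90, "name": "toothbrush"}]
print(num_classes_for_detector(sample))  # → 91
```

If you instead remap the 80 ids to contiguous labels 1..80, the head needs only 81 outputs, but predictions must then be mapped back to the original category ids before COCO evaluation.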

Training the Model

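Set the training parameters and start training. You need a data loader and an optimizer; no separate loss function is required, because in training mode the model computes its losses itself and returns them as a dict: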

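import torch
from torch.utils.data import DataLoader

def collate_fn(batch):
    # Detection images differ in size, so keep them as tuples instead of stacking them into one tensor
    return tuple(zip(*batch))

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)
params = [p for p in model.parameters() if p.requires_grad]  # skip frozen backbone layers
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

for epoch in range(5):
    model.train()
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # In training mode the model returns a dict of losses (classifier, box regression, RPN losses)
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
    print(f'epoch {epoch}: last batch loss {losses.item():.4f}')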
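The two less obvious pieces of the loop are the collate function and the loss. Because the default collate function stacks tensors, it fails on images of different sizes; `tuple(zip(*batch))` simply regroups the batch instead. The sketch uses strings and made-up loss values as stand-ins so it runs without PyTorch; the loss key names are the ones torchvision's Faster R-CNN reports in training mode.

```python
def collate_fn(batch):
    """Turn [(img0, tgt0), (img1, tgt1), ...] into ((img0, img1, ...), (tgt0, tgt1, ...))."""
    return tuple(zip(*batch))


# Stand-ins for (image, target) pairs of different sizes
batch = [("img_640x480", {"labels": [1]}), ("img_500x375", {"labels": [18, 1]})]
images, targets = collate_fn(batch)
print(images)   # → ('img_640x480', 'img_500x375')
print(targets)  # → ({'labels': [1]}, {'labels': [18, 1]})

# Made-up values: the training loop optimizes the plain sum of the loss dict
loss_dict = {"loss_classifier": 0.5, "loss_box_reg": 0.25, "loss_objectness": 0.125, "loss_rpn_box_reg": 0.125}
print(sum(loss_dict.values()))  # → 1.0
```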

Model Evaluation

Evaluate the model on the validation set, using the COCO evaluation tools (pycocotools) to compute mAP (mean Average Precision):

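import torch
from PIL import Image
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def format_coco_predictions(prediction, image_id):
    # Convert one torchvision output dict to COCO result entries: [x1, y1, x2, y2] -> [x, y, w, h]
    boxes = prediction['boxes'].cpu().tolist()
    scores = prediction['scores'].cpu().tolist()
    labels = prediction['labels'].cpu().tolist()
    return [{'image_id': image_id, 'category_id': label, 'bbox': [x1, y1, x2 - x1, y2 - y1], 'score': score}
            for (x1, y1, x2, y2), score, label in zip(boxes, scores, labels)]

coco_gt = COCO('path/to/annotations/instances_val2017.json')
model.eval()
results = []
img_ids = coco_gt.getImgIds()[:100]  # example: evaluate only the first 100 images

for image_id in img_ids:
    image_info = coco_gt.loadImgs(image_id)[0]
    image_path = f'path/to/val2017/{image_info["file_name"]}'
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).to(device)
    with torch.no_grad():
        prediction = model([image_tensor])[0]
    results.extend(format_coco_predictions(prediction, image_id))

coco_dt = coco_gt.loadRes(results)
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.params.imgIds = img_ids  # score only the evaluated subset, not the whole val set
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()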
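COCOeval matches detections to ground truth by intersection over union (IoU), and the headline AP it prints averages precision over ten IoU thresholds from 0.50 to 0.95. The pure-Python sketch below computes IoU for two boxes in COCO's [x, y, width, height] format and counts how many of those thresholds a single detection would clear. It illustrates the idea only and is not a reimplementation of COCOeval's matching, which also handles crowd regions, score ordering, and area ranges.

```python
def iou_xywh(a, b):
    """Intersection over union of two boxes in COCO [x, y, width, height] format."""
    ax2, ay2 = a[0] + a[2], a[1] + a[3]
    bx2, by2 = b[0] + b[2], b[1] + b[3]
    iw = max(0.0, min(ax2, bx2) - max(a[0], b[0]))  # overlap width
    ih = max(0.0, min(ay2, by2) - max(a[1], b[1]))  # overlap height
    inter = iw * ih
    union = a[2] * a[3] + b[2] * b[3] - inter
    return inter / union if union > 0 else 0.0


# The 10 IoU thresholds averaged by the headline COCO AP: 0.50, 0.55, ..., 0.95
IOU_THRESHOLDS = [round(0.5 + 0.05 * i, 2) for i in range(10)]

gt = [0, 0, 10, 10]
det = [1, 0, 10, 10]  # shifted right by 1 pixel
iou = iou_xywh(gt, det)
print(round(iou, 4))                          # → 0.8182
print(sum(iou >= t for t in IOU_THRESHOLDS))  # → 7 (matches at 0.50 through 0.80)
```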

Prediction and Visualization

Run prediction on a single image and visualize the results:

import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

def plot_prediction(image, prediction, threshold=0.5):
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    for box, label, score in zip(prediction['boxes'], prediction['labels'], prediction['scores']):
        if score > threshold:  # confidence threshold
            x1, y1, x2, y2 = box.cpu().tolist()
            rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.text(x1, y1, f'{label.item()}: {score.item():.2f}', color='white', bbox=dict(facecolor='red', alpha=0.5))
    plt.show()

model.eval()  # inference mode: the model returns detections instead of losses
image = Image.open('test_image.jpg').convert("RGB")
image_tensor = transform(image).to(device)
with torch.no_grad():
    prediction = model([image_tensor])[0]
plot_prediction(image, prediction)
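`plot_prediction` draws only the detections whose score exceeds the 0.5 threshold. The same filtering, plus a lookup from label id to class name, can be sketched in plain Python; the detection values below are made up, and `filter_detections` is a hypothetical helper.

```python
def filter_detections(boxes, labels, scores, names, threshold=0.5):
    """Keep detections whose score exceeds `threshold` and attach readable class names."""
    return [
        (names.get(label, str(label)), round(score, 2), box)
        for box, label, score in zip(boxes, labels, scores)
        if score > threshold
    ]


# Made-up outputs shaped like prediction['boxes'/'labels'/'scores'] after .tolist()
boxes = [[10, 20, 50, 80], [0, 0, 5, 5], [30, 30, 90, 120]]
labels = [1, 18, 18]
scores = [0.92, 0.31, 0.67]
names = {1: "person", 18: "dog"}  # tiny excerpt of the COCO id -> name mapping
print(filter_detections(boxes, labels, scores, names))
# → [('person', 0.92, [10, 20, 50, 80]), ('dog', 0.67, [30, 30, 90, 120])]
```

With torchvision's pretrained weights, the full name list should be available as `FasterRCNN_ResNet50_FPN_Weights.DEFAULT.meta["categories"]`, indexed by the same label ids.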
