用 Faster R-CNN 玩转 COCO 数据集:目标检测与图像识别实操
安装必要的库和工具。确保已安装 PyTorch、Torchvision 和 COCO API。加载预训练的 Faster R-CNN 模型。使用 Torchvision 提供的 COCO 数据集接口加载数据。在验证集上评估模型性能。设置训练参数并启动训练。文件,需从 COCO 官网下载并解压到指定目录。下载 COCO 数据集。
·
准备工作
安装必要的库和工具。确保已安装 PyTorch、Torchvision 和 COCO API。使用以下命令安装依赖:
pip install torch torchvision
pip install pycocotools
下载 COCO 数据集。官方提供 train2017、val2017 和 annotations 文件,需从 COCO 官网下载并解压到指定目录。
加载数据集
使用 Torchvision 提供的 COCO 数据集接口加载数据。以下是加载训练集的示例代码:
from torchvision.datasets import CocoDetection
import torchvision.transforms as T
transform = T.Compose([
T.ToTensor(),
])
dataset = CocoDetection(
root='path/to/train2017',
annFile='path/to/annotations/instances_train2017.json',
transform=transform
)
模型配置
加载预训练的 Faster R-CNN 模型。Torchvision 提供了在 COCO 上预训练的模型:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 91 # COCO 默认类别数
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
训练模型
设置训练参数并启动训练。需定义数据加载器、优化器和损失函数:
import torch
from torch.utils.data import DataLoader
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
for epoch in range(5):
model.train()
for images, targets in data_loader:
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
loss_dict = model(images, targets)
losses = sum(loss for loss in loss_dict.values())
optimizer.zero_grad()
losses.backward()
optimizer.step()
模型评估
在验证集上评估模型性能。使用 COCO 提供的评估工具计算 mAP(平均精度):
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
coco_gt = COCO('path/to/annotations/instances_val2017.json')
model.eval()
results = []
for image_id in coco_gt.getImgIds()[:100]: # 示例:评估前 100 张图像
image_info = coco_gt.loadImgs(image_id)[0]
image_path = f'path/to/val2017/{image_info["file_name"]}'
image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
prediction = model(image_tensor)
results.extend(format_coco_predictions(prediction, image_id))
coco_dt = coco_gt.loadRes(results)
coco_eval = COCOeval(coco_gt, coco_dt, 'bbox')
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
预测与可视化
对单张图像进行预测并可视化结果:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def plot_prediction(image, prediction):
fig, ax = plt.subplots(1)
ax.imshow(image)
for box, label, score in zip(prediction['boxes'], prediction['labels'], prediction['scores']):
if score > 0.5: # 置信度阈值
box = box.cpu().numpy()
rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.text(box[0], box[1], f'{label}: {score:.2f}', color='white', bbox=dict(facecolor='red', alpha=0.5))
plt.show()
image = Image.open('test_image.jpg').convert("RGB")
image_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
prediction = model(image_tensor)[0]
plot_prediction(image, prediction)
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)