PyTorch实现NMS算法
·
介绍
参考链接1:NMS 算法源码实现
参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。
示例代码
以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes(边界框),scores(每个边界框的得分),和 iou_threshold(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2],其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标。
import torch
def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):
"""
Perform Non-Maximum Suppression (NMS) on bounding boxes.
Args:
boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxes
of shape [x1, y1, x2, y2], where N is the number of boxes.
scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.
iou_threshold (float): The IoU threshold for suppressing boxes.
Returns:
torch.Tensor: A tensor of indices of the boxes to keep.
"""
# Get the areas of the boxes
x1 = boxes[:, 0]
y1 = boxes[:, 1]
x2 = boxes[:, 2]
y2 = boxes[:, 3]
areas = (x2 - x1) * (y2 - y1)
# Sort the scores in descending order and get the sorted indices
_, order = scores.sort(0, descending=True)
keep = []
while order.numel() > 0:
if order.numel() == 1:
i = order.item()
keep.append(i)
break
else:
i = order[0].item()
keep.append(i)
# Compute the IoU of the kept box with the rest
xx1 = torch.max(x1[i], x1[order[1:]])
yy1 = torch.max(y1[i], y1[order[1:]])
xx2 = torch.min(x2[i], x2[order[1:]])
yy2 = torch.min(y2[i], y2[order[1:]])
w = torch.clamp(xx2 - xx1, min=0)
h = torch.clamp(yy2 - yy1, min=0)
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter)
# Keep the boxes with IoU less than the threshold
inds = torch.where(iou <= iou_threshold)[0]
order = order[inds + 1]
return torch.tensor(keep, dtype=torch.long)
代码工作原理:
- 计算每个边界框的面积。
- 根据得分对边界框进行降序排序。
- 依次选择得分最高的边界框,并计算它与其他边界框的 IoU。
- 保留 IoU 小于阈值的边界框,并继续处理剩余的边界框。
- 返回保留的边界框的索引。
补充NMS前置处理
以下是yolov5前置处理代码:
def non_max_suppression(
prediction,
conf_thres=0.25,
iou_thres=0.45,
classes=None,
agnostic=False,
multi_label=False,
labels=(),
max_det=300,
nm=0, # number of masks
):
"""
Non-Maximum Suppression (NMS) on inference results to reject overlapping detections.
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
# Checks
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
device = prediction.device
mps = "mps" in device.type # Apple MPS
if mps: # MPS not fully supported yet, convert tensors to CPU before NMS
prediction = prediction.cpu()
bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - nm - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates
# Settings
# min_wh = 2 # (pixels) minimum box width and height
max_wh = 7680 # (pixels) maximum box width and height
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 0.5 + 0.05 * bs # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
t = time.time()
mi = 5 + nc # mask start index
output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
x = x[xc[xi]] # confidence
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
lb = labels[xi]
v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
v[:, :4] = lb[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(lb)), lb[:, 0].long() + 5] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
# Box/Mask
box = xywh2xyxy(x[:, :4]) # center_x, center_y, width, height) to (x1, y1, x2, y2)
mask = x[:, mi:] # zero columns if no masks
# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:mi] > conf_thres).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 5 + j, None], j[:, None].float(), mask[i]), 1)
else: # best class only
conf, j = x[:, 5:mi].max(1, keepdim=True)
x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]
# Filter by class
if classes is not None:
x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
# Apply finite constraint
# if not torch.isfinite(x).all():
# x = x[torch.isfinite(x).all(1)]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
i = i[:max_det] # limit detections
if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
weights = iou * scores[None] # box weights
x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
if redundant:
i = i[iou.sum(1) > 1] # require redundancy
output[xi] = x[i]
if mps:
output[xi] = output[xi].to(device)
if (time.time() - t) > time_limit:
LOGGER.warning(f"WARNING ⚠️ NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return output
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)