This article takes a close look at two landmark face recognition algorithms, Google's FaceNet and InsightFace's ArcFace, working from the math and loss-function design through a complete PyTorch implementation, so that you can thoroughly understand the core of modern face recognition.


1. Introduction: The Core Problem of Face Recognition

1.1 Face Recognition ≠ Image Classification

A common beginner's mistake is to treat face recognition as a classification problem.

❌ Wrong approach: classification
Face image → CNN → Softmax → output "this is person #1532"

Problems:
1. The number of classes is enormous (billions of identities)
2. Newly enrolled people cannot be handled (the model must be retrained)
3. Very few samples per person (hard to train a good classifier)

✅ Right approach: metric learning
Face image → CNN → feature vector (embedding) → compare against a database

Advantages:
1. Only "what counts as similar" is learned; no predefined classes are needed
2. Enrolling a new person only requires extracting features, no retraining
3. Train once, handle an unlimited number of identities

1.2 The Goal of Metric Learning

The ideal feature space:

┌────────────────────────────────────────────────────┐
│                                                    │
│      ●●●           Features of the same person     │
│     ● A ●          cluster together    ▲▲▲         │
│      ●●●                              ▲ B ▲        │
│                                        ▲▲▲         │
│                    Features of different people    │
│   ■■■              stay far apart                  │
│  ■ C ■                                             │
│   ■■■                          ◆◆◆                 │
│                               ◆ D ◆                │
│                                ◆◆◆                 │
│                                                    │
└────────────────────────────────────────────────────┘

Mathematical objectives:
- Minimize intra-class distance: d(A₁, A₂) → 0
- Maximize inter-class distance: d(A, B) → ∞
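These objectives can be measured with either Euclidean distance or cosine similarity. A minimal sketch (hand-made toy vectors, not real embeddings) shows how the two relate once vectors are L2-normalized:

```python
import numpy as np

# Two hand-made unit vectors standing in for embeddings
a = np.array([0.6, 0.8])
b = np.array([1.0, 0.0])

cos_ab = float(np.dot(a, b))           # cosine similarity: dot product of unit vectors
dist_sq = float(np.sum((a - b) ** 2))  # squared Euclidean distance

# On the unit hypersphere the two measures are tied: ||a - b||² = 2 - 2·cos(θ)
assert abs(dist_sq - (2 - 2 * cos_ab)) < 1e-12
```

So for normalized features, minimizing intra-class Euclidean distance and maximizing intra-class cosine similarity are the same objective.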

2. FaceNet: The Pioneering Triplet Loss

2.1 FaceNet Overview

FaceNet is Google's pioneering 2015 work, the first to push face verification accuracy to 99.63% on the LFW benchmark.

Core contributions:

  1. Learn an embedding directly in Euclidean space
  2. Design the Triplet Loss for end-to-end training
  3. Show that a 128-dimensional embedding is enough to represent a face

FaceNet architecture:

Input Image (160×160×3)
        │
        ▼
┌───────────────────┐
│  CNN Backbone     │  ← Inception / ResNet (feature extraction)
└───────────────────┘
        │
        ▼
┌───────────────────┐
│  L2 Normalization │  ← project onto the unit hypersphere
└───────────────────┘
        │
        ▼
   128-dim Embedding
   
   f(x) ∈ R^128, ||f(x)||₂ = 1

2.2 How Triplet Loss Works

The triplet
Each training sample is a "triplet":

┌─────────────────────────────────────────────────────┐
│                                                     │
│   Anchor (A)          Positive (P)      Negative (N)│
│   ┌─────┐             ┌─────┐           ┌─────┐     │
│   │     │             │     │           │     │     │
│   │ 😀  │             │ 😄  │           │ 😐  │     │
│   │     │             │     │           │     │     │
│   └─────┘             └─────┘           └─────┘     │
│   Person A            Person A          Person B    │
│   (anchor)            (positive)        (negative)  │
│                                                     │
│   A different photo   Same person       A different │
│   of the same person  as the Anchor     person      │
│                                                     │
└─────────────────────────────────────────────────────┘
The loss function
Triplet Loss:

L = Σ max(0, ||f(A) - f(P)||² - ||f(A) - f(N)||² + α)
    ─────────────────────────────────────────────────
    summed over all triplets

where:
- f(·): the CNN feature extractor
- ||·||²: squared Euclidean distance
- α: the margin, typically 0.2

Intuition:
We require d(A,P) + α < d(A,N),
i.e. positive distance + safety margin < negative distance.
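Plugging toy numbers into the formula makes the hinge behavior concrete (unnormalized vectors, chosen only to keep the arithmetic simple):

```python
import numpy as np

def triplet_loss(f_a, f_p, f_n, alpha=0.2):
    # L = max(0, ||f(A) - f(P)||² - ||f(A) - f(N)||² + α)
    d_ap = float(np.sum((f_a - f_p) ** 2))
    d_an = float(np.sum((f_a - f_n) ** 2))
    return max(0.0, d_ap - d_an + alpha)

f_a = np.array([0.0, 0.0])   # anchor
f_p = np.array([0.5, 0.0])   # positive: d² = 0.25

# Negative far away: the constraint already holds, loss is zero
assert triplet_loss(f_a, f_p, np.array([1.0, 0.0])) == 0.0   # d² = 1.0

# Negative too close: a positive loss keeps pushing it away
loss = triplet_loss(f_a, f_p, np.array([0.6, 0.0]))          # d² = 0.36
assert abs(loss - 0.09) < 1e-9
```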
Geometric picture:

Before training:                  After training:

    A ──────── N                     A ── P
    │                                     \
    │                                      \
    P                                       N (pushed away)

Goal: pull P closer, push N away, and keep a gap of α in between.

Why do we need a margin?
Without a margin:

If we only require d(A,P) < d(A,N),

we might end up with d(A,P) = 0.49, d(A,N) = 0.50.

The condition holds, but:
- the gap is tiny, so misjudgments are easy
- the result is not robust to noise

With a margin:
we require d(A,P) + 0.2 < d(A,N),
i.e. d(A,P) < d(A,N) - 0.2,

which guarantees a sufficient "safety distance".
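The numbers above can be checked directly (a tiny helper for illustration, not part of FaceNet itself):

```python
def hinge(d_pos, d_neg, margin):
    # the per-triplet loss: max(0, d_pos - d_neg + margin)
    return max(0.0, d_pos - d_neg + margin)

# Without a margin, the near-tie slips through with zero loss:
assert hinge(0.49, 0.50, 0.0) == 0.0

# With margin 0.2 the same pair is still penalized (loss ≈ 0.19),
# so training keeps pushing the negative further away
assert abs(hinge(0.49, 0.50, 0.2) - 0.19) < 1e-9
```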

2.3 Triplet Mining Strategies

The effectiveness of Triplet Loss depends heavily on how triplets are selected.

Easy / hard / semi-hard triplets
Classifying triplets by difficulty:

Let d_pos = d(A, P) and d_neg = d(A, N).

┌─────────────────────────────────────────────────────────┐
│                                                         │
│  Easy negative:                                         │
│  d_neg > d_pos + α                                      │
│  The negative is already far enough; loss = 0,          │
│  so there is no learning signal.                        │
│                                                         │
│  Hard negative:                                         │
│  d_neg < d_pos                                          │
│  The negative is closer than the positive!              │
│  Can make training unstable.                            │
│                                                         │
│  Semi-hard negative: ⭐ recommended                     │
│  d_pos < d_neg < d_pos + α                              │
│  The negative sits in the "danger zone" and provides    │
│  a useful learning signal.                              │
│                                                         │
└─────────────────────────────────────────────────────────┘

On the number line:

d_pos        d_pos + α
  │              │
  ▼              ▼
──┼──────────────┼──────────────────────→ d_neg
  │   Semi-hard  │        Easy
  │              │
  │←────────────→│
  │ useful learning zone
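The three regions can be encoded as a small helper (a sketch mirroring the definitions above):

```python
def triplet_difficulty(d_pos, d_neg, alpha=0.2):
    """Classify the negative of a triplet relative to its anchor-positive pair."""
    if d_neg < d_pos:
        return 'hard'        # negative closer than the positive
    if d_neg < d_pos + alpha:
        return 'semi-hard'   # inside the margin band
    return 'easy'            # beyond the margin, loss = 0

assert triplet_difficulty(0.4, 0.3) == 'hard'
assert triplet_difficulty(0.4, 0.5) == 'semi-hard'
assert triplet_difficulty(0.4, 0.9) == 'easy'
```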
Online Triplet Mining
# Online triplet mining: pick triplets dynamically within each batch
import torch

def online_triplet_mining(embeddings, labels, margin=0.2):
    """
    Batch-hard strategy:
    for each anchor, pick the hardest positive and the hardest negative
    """
    # Pairwise Euclidean distance matrix [B, B]
    pairwise_dist = torch.cdist(embeddings, embeddings)
    
    triplet_loss = 0
    num_valid_triplets = 0
    
    for i in range(len(embeddings)):
        anchor_label = labels[i]
        
        # Hardest positive: the farthest same-class sample
        positive_mask = labels == anchor_label
        positive_mask[i] = False  # exclude the anchor itself
        if positive_mask.sum() == 0:
            continue  # this anchor has no positive in the batch
        hardest_positive_dist = pairwise_dist[i][positive_mask].max()
        
        # Hardest negative: the closest different-class sample
        negative_mask = labels != anchor_label
        hardest_negative_dist = pairwise_dist[i][negative_mask].min()
        
        # Per-anchor triplet loss
        loss = max(0, hardest_positive_dist - hardest_negative_dist + margin)
        triplet_loss += loss
        
        if loss > 0:
            num_valid_triplets += 1
    
    return triplet_loss / max(num_valid_triplets, 1)

2.4 Limitations of FaceNet

Problems with Triplet Loss:

1. Combinatorial explosion of triplets
   N samples → O(N³) possible triplets,
   impossible to enumerate every useful combination

2. Slow convergence
   Each step optimizes only one triplet,
   so a huge number of iterations is needed

3. Sensitivity to the sampling strategy
   Bad triplets → training fails;
   the mining strategy must be designed with care

4. No explicit class centers
   The feature distribution may not be compact enough

3. ArcFace: A Revolutionary Angular-Margin Improvement

3.1 The Evolution from Softmax to ArcFace

Evolution path:

Softmax Loss
    │
    ▼ (introduce a margin)
L-Softmax (2016)
    │
    ▼ (simplify the margin form)
SphereFace / A-Softmax (2017)
    │
    ▼ (move to cosine space)
CosFace / AM-Softmax (2018)
    │
    ▼ (additive margin in angle space)
ArcFace (2019) ← state of the art at publication

3.2 Softmax Loss Recap

Standard softmax
The softmax loss used for ordinary classification:

L = -log(exp(W_y^T · x + b_y) / Σ_j exp(W_j^T · x + b_j))

where:
- x: the feature vector
- W_j: the weight vector of class j
- b_j: the bias term
- y: the ground-truth class

Problem: softmax only asks the correct class to score highest;
it does not explicitly enforce inter-class separation.

We can end up with:
Class 1: score = 0.35  ← true class
Class 2: score = 0.34
Class 3: score = 0.31

The classification is correct, but the gap is tiny and not robust.
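A quick numeric illustration (hypothetical logits, chosen to mimic the score table above):

```python
import math

def softmax(logits):
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Nearly identical logits: the argmax is "correct",
# but the probabilities are almost uniform
probs = softmax([2.0, 1.9, 1.8])
assert probs[0] == max(probs)          # classified correctly...
assert probs[0] - probs[1] < 0.05      # ...by a razor-thin margin
```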

3.3 Rethinking Softmax from an Angular Viewpoint

Key insight: inner product = norm × cosine
Decompose the inner product into angular form:

W_j^T · x = ||W_j|| · ||x|| · cos(θ_j)

where θ_j is the angle between feature x and the class-j weight vector W_j.

If both W and x are L2-normalized:
||W_j|| = 1, ||x|| = 1

then W_j^T · x = cos(θ_j).

Softmax becomes a classifier over angles!
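This identity is easy to verify in PyTorch; the example below (random vectors, for illustration only) shows that after L2 normalization the plain inner product already is the cosine of the angle:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(512)   # a raw feature vector
w = torch.randn(512)   # a class-weight vector

# cos(θ) from the general identity: inner product / product of norms
cos_raw = torch.dot(x, w) / (x.norm() * w.norm())

# After L2 normalization, the bare inner product gives the same value
cos_norm = torch.dot(F.normalize(x, dim=0), F.normalize(w, dim=0))

assert torch.allclose(cos_raw, cos_norm, atol=1e-6)
```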
Geometric picture:

                    W_1 (Class 1)
                   ↗
                  /  θ_1 (small angle = similar)
                 /
    ────────────●────────────→ W_2 (Class 2)
                x\
                  \  θ_2 (large angle = dissimilar)
                   ↘
                    W_3 (Class 3)

Classification = find the W_j with the smallest angle to x.

3.4 The ArcFace Loss

Definition
ArcFace Loss:

L = -log(exp(s · cos(θ_y + m)) / (exp(s · cos(θ_y + m)) + Σ_{j≠y} exp(s · cos(θ_j))))

where:
- θ_y: the angle between the feature and the true-class weight
- m: the angular margin, typically 0.5 (radians, about 28.6°)
- s: the scale factor, typically 64

The key change: add a penalty m to the angle of the true class.
Since cos(θ_y + m) < cos(θ_y), correct classification becomes harder.
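The penalty is easy to see numerically, since cosine is decreasing on (0, π):

```python
import math

m = 0.5  # the ArcFace margin in radians
for theta in (0.2, 0.8, 1.4, 2.0):          # angles safely inside (0, π - m)
    assert math.cos(theta + m) < math.cos(theta)

# e.g. at θ = 1.0 rad the target logit drops from ≈ 0.54 to ≈ 0.07,
# so the network must shrink θ_y itself to win back the score
```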
Intuition
The geometric meaning of ArcFace:

Original decision boundary:
─────────────────────────────
        Class A │ Class B
                │
              θ = 90°

ArcFace decision boundary (from Class A's perspective):
─────────────────────────────
   Class A  │   │  Class B
            │   │
          θ=90°-m  θ=90°+m
            
To be classified as Class A, x must satisfy:
θ_A + m < θ_B
i.e. θ_A < θ_B - m

x must be "even closer" to Class A; the margin m is the extra requirement.

Effect on training:

Softmax:                    ArcFace:
     W_A                         W_A
    ↗                           ↗
   / loose decision boundary   / compact intra-class spread
  /  ●                        /●●●
 /   ● ●                     / ●●
/     ●                     /
──────────→ W_B           ──────────→ W_B
      ▲ ▲                        ▲▲▲
     ▲   ▲                      ▲▲▲▲
                              larger inter-class gap

3.5 Why Is ArcFace Better?

Comparison with other margin methods
Different margin losses:

┌────────────────────────────────────────────────────────────┐
│ Method        │ Formula      │ Characteristics             │
├────────────────────────────────────────────────────────────┤
│ SphereFace    │ cos(m·θ)     │ multiplicative angular      │
│ (A-Softmax)   │              │ margin; hard to optimize    │
├────────────────────────────────────────────────────────────┤
│ CosFace       │ cos(θ) - m   │ additive cosine margin;     │
│ (AM-Softmax)  │              │ simple to implement         │
├────────────────────────────────────────────────────────────┤
│ ArcFace       │ cos(θ + m)   │ additive angular margin;    │
│               │              │ clear geometric meaning     │
└────────────────────────────────────────────────────────────┘

Geometric comparison of the decision boundaries:

        cos(θ)
           ↑
         1 ┼───────────────────────
           │    ╲
           │     ╲  Softmax (no margin)
           │      ╲
           │       ╲
      cos(m)┼────────╲─────────────  SphereFace
           │         ╲╲
           │          ╲ ╲ ArcFace (constant margin in angle space)
           │           ╲  ╲
           │            ╲   ╲ CosFace (constant margin in cosine space)
         0 ┼─────────────┼───┼─────→ θ
           0            π/2   π

ArcFace's advantage:
a constant margin in angle space, with the most direct geometric interpretation.
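Evaluating the three target logits at a single angle (θ = 60°, with the typical margins from the table) makes the comparison concrete:

```python
import math

theta = math.pi / 3  # 60°, cos(θ) = 0.5

plain      = math.cos(theta)          # plain softmax target logit
sphereface = math.cos(4 * theta)      # multiplicative angular margin, m = 4
cosface    = math.cos(theta) - 0.35   # additive cosine margin, m = 0.35
arcface    = math.cos(theta + 0.5)    # additive angular margin, m = 0.5

# Every margin variant lowers the target logit, i.e. makes the true
# class harder to score, which is what forces compact classes
for penalized in (sphereface, cosface, arcface):
    assert penalized < plain
```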

3.6 ArcFace Training Details

Numerical stability
# When θ + m > π, cos(θ + m) misbehaves
# and needs special handling

def arcface_loss(logits, labels, s=64.0, m=0.5):
    """
    A numerically stable ArcFace loss
    """
    # logits = cos(θ), in [-1, 1]
    # clamp to guard against floating-point drift
    cos_theta = torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7)
    
    # recover θ
    theta = torch.acos(cos_theta)
    
    # cos(θ + m); applied only to the true class below
    target_logits = torch.cos(theta + m)
    
    # Boundary case θ + m > π:
    # approximate with cos(θ) - m·sin(θ),
    # or clip with a threshold
    
    # combine into the final logits
    one_hot = F.one_hot(labels, num_classes=logits.size(1))
    output = logits * (1 - one_hot) + target_logits * one_hot
    output *= s
    
    return F.cross_entropy(output, labels)

4. Complete PyTorch Implementation

4.1 Triplet Loss Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """
    Triplet Loss with online triplet mining
    
    Supported mining strategies:
    - batch_all: use every valid triplet
    - batch_hard: for each anchor, pick the hardest positive and negative
    - batch_semi_hard: use only semi-hard triplets
    """
    
    def __init__(self, margin=0.2, mining='batch_hard'):
        super().__init__()
        self.margin = margin
        self.mining = mining
    
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: [B, D] L2-normalized feature vectors
            labels: [B] class labels
        Returns:
            loss: scalar loss value
        """
        # Pairwise squared-distance matrix
        # dist[i,j] = ||emb_i - emb_j||²
        dist_mat = self._pairwise_distances(embeddings)
        
        if self.mining == 'batch_all':
            return self._batch_all_triplet_loss(dist_mat, labels)
        elif self.mining == 'batch_hard':
            return self._batch_hard_triplet_loss(dist_mat, labels)
        elif self.mining == 'batch_semi_hard':
            return self._batch_semi_hard_triplet_loss(dist_mat, labels)
        else:
            raise ValueError(f"Unknown mining strategy: {self.mining}")
    
    def _pairwise_distances(self, embeddings):
        """Pairwise squared Euclidean distances"""
        # ||a - b||² = ||a||² + ||b||² - 2*a·b
        dot_product = torch.matmul(embeddings, embeddings.t())
        square_norm = torch.diag(dot_product)
        
        distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
        distances = F.relu(distances)  # clip tiny negatives from floating-point error
        
        return distances
    
    def _get_anchor_positive_mask(self, labels):
        """Mask of valid anchor-positive pairs"""
        # same class, but not the same sample
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        indices_not_equal = ~torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        return labels_equal & indices_not_equal
    
    def _get_anchor_negative_mask(self, labels):
        """Mask of valid anchor-negative pairs"""
        return labels.unsqueeze(0) != labels.unsqueeze(1)
    
    def _batch_all_triplet_loss(self, dist_mat, labels):
        """Use every valid triplet"""
        anchor_positive_mask = self._get_anchor_positive_mask(labels)
        anchor_negative_mask = self._get_anchor_negative_mask(labels)
        
        # Loss of every triplet:
        # triplet_loss[i,j,k] = d(i,j) - d(i,k) + margin
        anchor_positive_dist = dist_mat.unsqueeze(2)
        anchor_negative_dist = dist_mat.unsqueeze(1)
        
        triplet_loss = anchor_positive_dist - anchor_negative_dist + self.margin
        
        # Valid-triplet mask
        mask = anchor_positive_mask.unsqueeze(2) & anchor_negative_mask.unsqueeze(1)
        mask = mask.float()
        
        triplet_loss = triplet_loss * mask
        triplet_loss = F.relu(triplet_loss)
        
        # Average over the triplets that are still active
        num_positive_triplets = (triplet_loss > 1e-16).float().sum()
        loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
        
        return loss
    
    def _batch_hard_triplet_loss(self, dist_mat, labels):
        """
        Batch-hard strategy:
        for each anchor, pick the hardest positive and the hardest negative
        """
        anchor_positive_mask = self._get_anchor_positive_mask(labels)
        anchor_negative_mask = self._get_anchor_negative_mask(labels)
        
        # Hardest positive: the largest distance within the same class
        anchor_positive_dist = dist_mat * anchor_positive_mask.float()
        hardest_positive_dist, _ = anchor_positive_dist.max(dim=1, keepdim=True)
        
        # Hardest negative: the smallest distance across classes
        # (push same-class entries up by the max distance first)
        max_dist = dist_mat.max()
        anchor_negative_dist = dist_mat + max_dist * (~anchor_negative_mask).float()
        hardest_negative_dist, _ = anchor_negative_dist.min(dim=1, keepdim=True)
        
        # Triplet loss
        triplet_loss = F.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
        
        return triplet_loss.mean()
    
    def _batch_semi_hard_triplet_loss(self, dist_mat, labels):
        """
        Semi-hard strategy:
        pick negatives satisfying d(a,p) < d(a,n) < d(a,p) + margin
        """
        anchor_positive_mask = self._get_anchor_positive_mask(labels)
        anchor_negative_mask = self._get_anchor_negative_mask(labels)
        
        # For every anchor-positive pair
        anchor_positive_dist = dist_mat.unsqueeze(2)
        anchor_negative_dist = dist_mat.unsqueeze(1)
        
        # Semi-hard condition: d(a,p) < d(a,n) < d(a,p) + margin
        semi_hard_mask = (anchor_negative_dist > anchor_positive_dist) & \
                        (anchor_negative_dist < anchor_positive_dist + self.margin)
        
        # Combine with the validity masks
        mask = anchor_positive_mask.unsqueeze(2) & \
               anchor_negative_mask.unsqueeze(1) & \
               semi_hard_mask
        
        triplet_loss = anchor_positive_dist - anchor_negative_dist + self.margin
        triplet_loss = triplet_loss * mask.float()
        triplet_loss = F.relu(triplet_loss)
        
        num_positive_triplets = (triplet_loss > 1e-16).float().sum()
        loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
        
        return loss


# Usage example
def triplet_loss_example():
    # Build the loss
    criterion = TripletLoss(margin=0.2, mining='batch_hard')
    
    # Dummy data
    batch_size = 32
    embedding_dim = 128
    
    embeddings = torch.randn(batch_size, embedding_dim)
    embeddings = F.normalize(embeddings, p=2, dim=1)  # L2 normalization
    labels = torch.randint(0, 8, (batch_size,))  # 8 classes
    
    # Compute the loss
    loss = criterion(embeddings, labels)
    print(f"Triplet Loss: {loss.item():.4f}")
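The `||a - b||² = ||a||² + ||b||² - 2·a·b` expansion inside `_pairwise_distances` can be sanity-checked against `torch.cdist`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb = F.normalize(torch.randn(16, 8), p=2, dim=1)

# The expansion used by _pairwise_distances
dot = emb @ emb.t()
sq = dot.diag()
dist_sq = (sq.unsqueeze(0) - 2.0 * dot + sq.unsqueeze(1)).clamp(min=0)

# Reference: squared pairwise Euclidean distances
ref = torch.cdist(emb, emb) ** 2
assert torch.allclose(dist_sq, ref, atol=1e-5)
```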

4.2 ArcFace Loss Implementation

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ArcFaceLoss(nn.Module):
    """
    ArcFace Loss (Additive Angular Margin Loss)
    
    Paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition
    
    L = -log(exp(s*cos(θ_y + m)) / (exp(s*cos(θ_y + m)) + Σexp(s*cos(θ_j))))
    """
    
    def __init__(self, in_features, out_features, s=64.0, m=0.50, easy_margin=False):
        """
        Args:
            in_features: input feature (embedding) dimension
            out_features: number of classes
            s: scale factor
            m: angular margin in radians; 0.5 rad ≈ 28.6°
            easy_margin: whether to use the easy-margin variant
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.easy_margin = easy_margin
        
        # Learnable class-weight matrix
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        
        # Precomputed constants
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)  # threshold for the θ + m > π fallback
        self.mm = math.sin(math.pi - m) * m
    
    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: [B, in_features] L2-normalized feature vectors
            labels: [B] class labels
        Returns:
            loss: the ArcFace loss
        """
        # Normalize the class weights
        weight_norm = F.normalize(self.weight, p=2, dim=1)
        
        # Normalize the input (in case it is not already normalized)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        
        # cos(θ) = x · W^T; with both sides normalized,
        # the inner product is exactly the cosine
        cos_theta = F.linear(embeddings_norm, weight_norm)
        cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)  # numerical stability
        
        # sin(θ)
        sin_theta = torch.sqrt(1.0 - cos_theta.pow(2))
        
        # cos(θ + m) = cos(θ)cos(m) - sin(θ)sin(m)
        cos_theta_m = cos_theta * self.cos_m - sin_theta * self.sin_m
        
        if self.easy_margin:
            # easy margin: apply the margin only when cos(θ) > 0
            cos_theta_m = torch.where(cos_theta > 0, cos_theta_m, cos_theta)
        else:
            # standard ArcFace: apply the margin only when cos(θ) > cos(π - m),
            # otherwise fall back to a linear approximation
            cos_theta_m = torch.where(cos_theta > self.th, 
                                      cos_theta_m, 
                                      cos_theta - self.mm)
        
        # One-hot labels
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)
        
        # Apply the margin to the true class only
        output = one_hot * cos_theta_m + (1.0 - one_hot) * cos_theta
        
        # Scale
        output *= self.s
        
        # Cross-entropy loss
        loss = F.cross_entropy(output, labels)
        
        return loss
    
    def get_logits(self, embeddings):
        """Logits only, for evaluation at inference time"""
        weight_norm = F.normalize(self.weight, p=2, dim=1)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        cos_theta = F.linear(embeddings_norm, weight_norm)
        return cos_theta * self.s
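A quick property check (a standalone sketch, not using the class above): lowering only the true-class logit by the angular margin can only increase the cross-entropy, which is exactly the extra pressure ArcFace applies during training. Clamping θ + m at π here is a simplification of the threshold fallback used in `ArcFaceLoss`:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
s, m = 64.0, 0.5
emb = F.normalize(torch.randn(8, 16), dim=1)   # 8 samples, 16-dim embeddings
w = F.normalize(torch.randn(10, 16), dim=1)    # 10 classes
labels = torch.randint(0, 10, (8,))

cos_theta = (emb @ w.t()).clamp(-1 + 1e-7, 1 - 1e-7)
theta = torch.acos(cos_theta)

# cos(θ + m) on the true class only; the clamp keeps θ + m inside [0, π]
target = torch.cos(torch.clamp(theta + m, max=math.pi))
one_hot = F.one_hot(labels, 10).bool()
margined = s * torch.where(one_hot, target, cos_theta)
plain = s * cos_theta

loss_plain = F.cross_entropy(plain, labels)
loss_margin = F.cross_entropy(margined, labels)
assert loss_margin > loss_plain   # the margin strictly tightens the objective
```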


class CombinedMarginLoss(nn.Module):
    """
    A unified margin-loss framework covering
    ArcFace, CosFace, SphereFace, and their combinations:
    
    cos(m1 * θ + m2) - m3
    
    ArcFace:    m1=1, m2=0.5, m3=0
    CosFace:    m1=1, m2=0,   m3=0.35
    SphereFace: m1=4, m2=0,   m3=0
    """
    
    def __init__(self, in_features, out_features, s=64.0, m1=1.0, m2=0.5, m3=0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
    
    def forward(self, embeddings, labels):
        weight_norm = F.normalize(self.weight, p=2, dim=1)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        
        cos_theta = F.linear(embeddings_norm, weight_norm)
        cos_theta = cos_theta.clamp(-1.0 + 1e-7, 1.0 - 1e-7)
        
        theta = torch.acos(cos_theta)
        
        # cos(m1 * θ + m2) - m3
        target_logits = torch.cos(self.m1 * theta + self.m2) - self.m3
        
        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, labels.view(-1, 1), 1.0)
        
        output = one_hot * target_logits + (1.0 - one_hot) * cos_theta
        output *= self.s
        
        return F.cross_entropy(output, labels)

4.3 The Face Recognition Network

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


class FaceRecognitionNet(nn.Module):
    """
    Face recognition network:
    backbone + embedding layer + loss head
    """
    
    def __init__(self, 
                 backbone='resnet50',
                 embedding_dim=512,
                 num_classes=10000,
                 loss_type='arcface',
                 pretrained=True):
        """
        Args:
            backbone: backbone architecture
            embedding_dim: embedding dimension
            num_classes: number of classes during training
            loss_type: 'arcface', 'cosface', 'triplet'
            pretrained: whether to load pretrained weights
        """
        super().__init__()
        
        # Backbone
        self.backbone = self._build_backbone(backbone, pretrained)
        
        # Probe the backbone output dimension
        with torch.no_grad():
            dummy = torch.zeros(1, 3, 112, 112)
            backbone_out_dim = self.backbone(dummy).shape[1]
        
        # Embedding layer
        self.embedding = nn.Sequential(
            nn.Linear(backbone_out_dim, embedding_dim),
            nn.BatchNorm1d(embedding_dim)
        )
        
        # Loss head
        self.loss_type = loss_type
        if loss_type == 'arcface':
            self.loss_head = ArcFaceLoss(embedding_dim, num_classes, s=64.0, m=0.5)
        elif loss_type == 'cosface':
            self.loss_head = CombinedMarginLoss(embedding_dim, num_classes, 
                                                s=64.0, m1=1.0, m2=0.0, m3=0.35)
        elif loss_type == 'triplet':
            self.loss_head = TripletLoss(margin=0.2, mining='batch_hard')
        else:
            raise ValueError(f"Unknown loss type: {loss_type}")
        
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
    
    def _build_backbone(self, backbone_name, pretrained):
        """Build the backbone"""
        if backbone_name == 'resnet50':
            backbone = models.resnet50(pretrained=pretrained)
            # Drop the final FC layer, keep everything up to avgpool
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        
        elif backbone_name == 'resnet34':
            backbone = models.resnet34(pretrained=pretrained)
            backbone = nn.Sequential(*list(backbone.children())[:-1])
        
        elif backbone_name == 'mobilenet_v2':
            backbone = models.mobilenet_v2(pretrained=pretrained)
            backbone.classifier = nn.Identity()
        
        elif backbone_name == 'iresnet50':
            # InsightFace-style IResNet (defined below)
            backbone = IResNet50()
        
        else:
            raise ValueError(f"Unknown backbone: {backbone_name}")
        
        return backbone
    
    def extract_embedding(self, x):
        """
        Extract face features (for inference)
        
        Args:
            x: [B, 3, H, W] input images
        Returns:
            embedding: [B, embedding_dim] L2-normalized feature vectors
        """
        # Backbone
        features = self.backbone(x)
        features = features.flatten(1)
        
        # Embedding layer
        embedding = self.embedding(features)
        
        # L2 normalization
        embedding = F.normalize(embedding, p=2, dim=1)
        
        return embedding
    
    def forward(self, x, labels=None):
        """
        Forward pass
        
        Training (labels given): returns (loss, embedding)
        Inference: returns the embedding
        """
        embedding = self.extract_embedding(x)
        
        if labels is not None:
            # Training mode: every loss head shares the (embedding, labels) interface
            loss = self.loss_head(embedding, labels)
            return loss, embedding
        else:
            # Inference mode
            return embedding


class IResNet50(nn.Module):
    """
    A simplified stand-in for InsightFace's IResNet50,
    a ResNet variant tuned for face recognition
    """
    
    def __init__(self, num_features=512, dropout=0.0):
        super().__init__()
        
        # Standard ResNet50 as the base
        resnet = models.resnet50(pretrained=False)
        
        # Replace the first conv with a 3×3, stride-1 layer for 112×112 inputs
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = resnet.bn1
        self.prelu = nn.PReLU(64)
        # Keep the max-pool so the total stride is 16 and layer4 outputs 7×7,
        # matching the 2048·7·7 FC below (the real IResNet uses stride-2 blocks instead)
        self.maxpool = resnet.maxpool
        
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        
        self.bn2 = nn.BatchNorm2d(2048)
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(2048 * 7 * 7, num_features)
        self.bn3 = nn.BatchNorm1d(num_features)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.bn2(x)
        x = self.dropout(x)
        x = x.flatten(1)
        x = self.fc(x)
        x = self.bn3(x)
        
        return x

4.4 The Training Pipeline

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class FaceRecognitionTrainer:
    """Face recognition trainer"""
    
    def __init__(self, config):
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        # Build the model
        self.model = FaceRecognitionNet(
            backbone=config['backbone'],
            embedding_dim=config['embedding_dim'],
            num_classes=config['num_classes'],
            loss_type=config['loss_type'],
            pretrained=config['pretrained']
        ).to(self.device)
        
        # Optimizer
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config['lr'],
            momentum=0.9,
            weight_decay=config['weight_decay']
        )
        
        # Learning-rate schedule
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=config['lr_milestones'],
            gamma=0.1
        )
        
        # Mixed-precision training
        self.scaler = GradScaler()
        
        # Best metric so far
        self.best_acc = 0.0
    
    def train_epoch(self, train_loader, epoch):
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for batch_idx, (images, labels) in enumerate(pbar):
            images = images.to(self.device)
            labels = labels.to(self.device)
            
            # Mixed-precision forward pass
            with autocast():
                loss, embeddings = self.model(images, labels)
            
            # Backward pass
            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            total_loss += loss.item()
            
            # Training accuracy (for ArcFace-style heads)
            if hasattr(self.model.loss_head, 'get_logits'):
                with torch.no_grad():
                    logits = self.model.loss_head.get_logits(embeddings)
                    _, predicted = logits.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.*correct/max(total,1):.2f}%'
            })
        
        return total_loss / len(train_loader), 100. * correct / max(total, 1)
    
    @torch.no_grad()
    def validate(self, val_loader):
        """Validation: extract features for evaluation"""
        self.model.eval()
        
        all_embeddings = []
        all_labels = []
        
        for images, labels in tqdm(val_loader, desc='Validating'):
            images = images.to(self.device)
            embeddings = self.model.extract_embedding(images)
            
            all_embeddings.append(embeddings.cpu())
            all_labels.append(labels)
        
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)
        
        # Verification metric (e.g. LFW-style accuracy)
        acc = self.compute_verification_accuracy(all_embeddings, all_labels)
        
        return acc
    
    def compute_verification_accuracy(self, embeddings, labels):
        """
        Compute 1:1 verification accuracy
        """
        # Simplified: compare same-class and cross-class similarity distributions
        embeddings = F.normalize(embeddings, p=2, dim=1)
        similarity_matrix = torch.mm(embeddings, embeddings.t())
        
        # Same-class pairs
        same_mask = labels.unsqueeze(0) == labels.unsqueeze(1)
        same_mask.fill_diagonal_(False)  # exclude self-pairs
        
        if same_mask.sum() > 0:
            same_sim = similarity_matrix[same_mask].mean().item()
        else:
            same_sim = 0
        
        # Cross-class pairs
        diff_mask = ~same_mask
        diff_mask.fill_diagonal_(False)
        
        if diff_mask.sum() > 0:
            diff_sim = similarity_matrix[diff_mask].mean().item()
        else:
            diff_sim = 0
        
        # Crude threshold-based accuracy estimate
        threshold = (same_sim + diff_sim) / 2
        
        correct_same = (similarity_matrix[same_mask] > threshold).float().mean().item()
        correct_diff = (similarity_matrix[diff_mask] < threshold).float().mean().item()
        
        accuracy = (correct_same + correct_diff) / 2 * 100
        
        logger.info(f"Same similarity: {same_sim:.4f}, Diff similarity: {diff_sim:.4f}")
        logger.info(f"Threshold: {threshold:.4f}, Accuracy: {accuracy:.2f}%")
        
        return accuracy
    
    def train(self, train_loader, val_loader, num_epochs):
        """Full training loop"""
        for epoch in range(1, num_epochs + 1):
            # Train
            train_loss, train_acc = self.train_epoch(train_loader, epoch)
            logger.info(f"Epoch {epoch}: Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%")
            
            # Step the learning rate
            self.scheduler.step()
            logger.info(f"Learning rate: {self.scheduler.get_last_lr()[0]:.6f}")
            
            # Validate
            if epoch % self.config['val_interval'] == 0:
                val_acc = self.validate(val_loader)
                
                # Save the best model
                if val_acc > self.best_acc:
                    self.best_acc = val_acc
                    self.save_checkpoint('best_model.pth')
                    logger.info(f"New best model! Accuracy: {val_acc:.2f}%")
            
            # Periodic checkpoint
            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(f'checkpoint_epoch_{epoch}.pth')
    
    def save_checkpoint(self, filename):
        """Save a checkpoint"""
        os.makedirs(self.config['save_dir'], exist_ok=True)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_acc': self.best_acc
        }, os.path.join(self.config['save_dir'], filename))


def main():
    """Entry point"""
    config = {
        'backbone': 'resnet50',
        'embedding_dim': 512,
        'num_classes': 85742,  # number of identities in MS1MV2
        'loss_type': 'arcface',
        'pretrained': True,
        
        'lr': 0.1,
        'weight_decay': 5e-4,
        'lr_milestones': [10, 18, 22],
        
        'batch_size': 64,
        'num_epochs': 25,
        'val_interval': 1,
        'save_interval': 5,
        'save_dir': './checkpoints',
        
        'num_workers': 8
    }
    
    # Build the data loaders (dataset-specific, left to the reader)
    # train_loader = ...
    # val_loader = ...
    
    # Train
    trainer = FaceRecognitionTrainer(config)
    # trainer.train(train_loader, val_loader, config['num_epochs'])
    
    print("Training completed!")


if __name__ == '__main__':
    main()

4.5 Inference and Feature Matching

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


class FaceRecognizer:
    """人脸识别推理器"""
    
    def __init__(self, model_path, backbone='resnet50', embedding_dim=512, device='cuda'):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        
        # 加载模型
        self.model = FaceRecognitionNet(
            backbone=backbone,
            embedding_dim=embedding_dim,
            num_classes=1,  # 推理时不需要分类头
            loss_type='arcface',
            pretrained=False
        )
        
        # Load weights
        checkpoint = torch.load(model_path, map_location=self.device)
        # keep only backbone/embedding weights, drop the loss head
        state_dict = {k: v for k, v in checkpoint['model_state_dict'].items() 
                     if not k.startswith('loss_head')}
        self.model.load_state_dict(state_dict, strict=False)
        
        self.model.to(self.device)
        self.model.eval()
        
        # Preprocessing pipeline
        self.transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    
    def preprocess(self, image):
        """
        Preprocess an input image.
        Args:
            image: PIL Image or numpy array (BGR, as returned by cv2.imread)
        """
        if isinstance(image, np.ndarray):
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(image)
        
        return self.transform(image).unsqueeze(0)
    
    @torch.no_grad()
    def extract_feature(self, image):
        """
        Extract a face embedding.

        Args:
            image: a face image
        Returns:
            embedding: 512-dim L2-normalized feature vector
        """
        img_tensor = self.preprocess(image).to(self.device)
        embedding = self.model.extract_embedding(img_tensor)
        return embedding.cpu().numpy().flatten()
    
    @torch.no_grad()
    def extract_features_batch(self, images):
        """Extract embeddings for a batch of images."""
        tensors = torch.stack([self.preprocess(img).squeeze(0) for img in images])
        tensors = tensors.to(self.device)
        embeddings = self.model.extract_embedding(tensors)
        return embeddings.cpu().numpy()
    
    @staticmethod
    def cosine_similarity(feat1, feat2):
        """Cosine similarity (features are L2-normalized, so a dot product suffices)."""
        return np.dot(feat1, feat2)
    
    @staticmethod
    def euclidean_distance(feat1, feat2):
        """Euclidean distance."""
        return np.linalg.norm(feat1 - feat2)
    
    def verify(self, image1, image2, threshold=0.5):
        """
        1:1 face verification.

        Returns:
            is_same: whether the two images show the same person
            similarity: similarity score
        """
        feat1 = self.extract_feature(image1)
        feat2 = self.extract_feature(image2)
        
        similarity = self.cosine_similarity(feat1, feat2)
        is_same = similarity >= threshold
        
        return is_same, similarity
    
    def identify(self, query_image, gallery_features, gallery_labels, threshold=0.5):
        """
        1:N face identification.

        Args:
            query_image: query image
            gallery_features: gallery features [N, 512]
            gallery_labels: gallery labels [N]
            threshold: acceptance threshold
        Returns:
            identity: recognized identity, or None if no match passes the threshold
            similarity: similarity score
        """
        query_feat = self.extract_feature(query_image)
        
        # Similarities against every gallery feature
        similarities = np.dot(gallery_features, query_feat)
        
        # Pick the best match
        max_idx = np.argmax(similarities)
        max_similarity = similarities[max_idx]
        
        if max_similarity >= threshold:
            return gallery_labels[max_idx], max_similarity
        else:
            return None, max_similarity
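
Because `extract_feature` returns L2-normalized embeddings, the two metrics above are interchangeable: for unit vectors, ‖a − b‖² = 2 − 2·(a·b). A quick numerical check (standalone numpy, random vectors):

```python
import numpy as np

# Two random unit vectors standing in for face embeddings
rng = np.random.default_rng(0)
a = rng.normal(size=512); a /= np.linalg.norm(a)
b = rng.normal(size=512); b /= np.linalg.norm(b)

cos_sim = float(np.dot(a, b))
dist = float(np.linalg.norm(a - b))

# On the unit hypersphere: d^2 = 2 - 2*cos
assert abs(dist ** 2 - (2 - 2 * cos_sim)) < 1e-9
```

This is why any cosine-similarity threshold can be converted into an equivalent Euclidean-distance threshold, and vice versa.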


# Usage example
def demo():
    # Initialization
    recognizer = FaceRecognizer(
        model_path='checkpoints/best_model.pth',
        backbone='resnet50',
        embedding_dim=512
    )
    
    # 1:1 verification
    img1 = cv2.imread('person1_a.jpg')
    img2 = cv2.imread('person1_b.jpg')
    
    is_same, similarity = recognizer.verify(img1, img2)
    print(f"Same person: {is_same}, Similarity: {similarity:.4f}")
    
    # 1:N identification
    # Build the gallery
    gallery_images = [cv2.imread(f'gallery/{i}.jpg') for i in range(10)]
    gallery_labels = ['Alice', 'Bob', 'Charlie', ...]
    gallery_features = recognizer.extract_features_batch(gallery_images)
    
    # Query
    query_img = cv2.imread('query.jpg')
    identity, similarity = recognizer.identify(
        query_img, gallery_features, gallery_labels, threshold=0.5
    )
    print(f"Identity: {identity}, Similarity: {similarity:.4f}")

5. FaceNet vs ArcFace: An In-Depth Comparison

5.1 Core Differences

┌─────────────────────────────────────────────────────────────────────┐
│                    FaceNet vs ArcFace                               │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  Dimension        │  FaceNet (Triplet)    │  ArcFace (Ang. Margin)  │
│ ──────────────────┼───────────────────────┼─────────────────────────│
│  Loss function    │  Triplet Loss         │  Softmax + Ang. Margin  │
│  Objective        │  relative distances   │  absolute angular margin│
│  Training signal  │  one triplet at a time│  all classes participate│
│  Convergence      │  slow                 │  fast                   │
│  Complexity       │  needs triplet mining │  simple and direct      │
│  Hyperparameters  │  margin, mining policy│  s, m                   │
│  Class centers    │  implicit             │  explicit (weight matrix)│
│  Scalability      │  good (no class weights)│  large weight matrix  │
│  Accuracy         │  LFW ~99.6%           │  LFW ~99.8%             │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

5.2 A Mathematical View

Optimization target of Triplet Loss:

    For every triplet (A, P, N):
    ||f(A) - f(P)||² + α < ||f(A) - f(N)||²
    
    Only a relative constraint: the positive must be closer than the negative.
    Nothing constrains the absolute positions of the features.
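
This relative constraint corresponds to the standard hinge form of the loss; a minimal numpy sketch (function name hypothetical):

```python
import numpy as np

def triplet_loss(anchor, positive, negative, alpha=0.2):
    """Hinge on squared distances: max(0, d(A,P)^2 - d(A,N)^2 + alpha)."""
    d_ap = np.sum((anchor - positive) ** 2)
    d_an = np.sum((anchor - negative) ** 2)
    return max(0.0, d_ap - d_an + alpha)

# A triplet that already satisfies the margin incurs zero loss...
a = np.array([1.0, 0.0])
p = np.array([0.9, 0.1])   # close to the anchor
n = np.array([-1.0, 0.0])  # far from the anchor
print(triplet_loss(a, p, n))  # 0.0

# ...while a violating triplet produces a positive gradient signal.
print(triplet_loss(a, n, p))  # 4.18
```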


Optimization target of ArcFace:

    For every sample x with label y:
    cos(θ_y + m) > cos(θ_j), ∀j ≠ y
    
    Since cos is decreasing on [0, π], this is equivalent to θ_y + m < θ_j,
    i.e. the angle to the correct class must be smaller than to any other class
    by at least m.
    
    This is an absolute constraint that enforces intra-class compactness
    and inter-class separation.
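
The margin injection can be sketched almost verbatim from the constraint; a toy numpy example (names and cosine values hypothetical, s=64 and m=0.5 as in the ArcFace paper):

```python
import numpy as np

def arcface_logits(cos_theta, target, s=64.0, m=0.5):
    """Add the angular margin m to the target class angle only, then scale by s."""
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    logits = cos_theta.copy()
    logits[target] = np.cos(theta[target] + m)
    return s * logits

cos_theta = np.array([0.8, 0.4, 0.3])  # cosines to the 3 class centers
out = arcface_logits(cos_theta, target=0)
# The target logit drops from 64*0.8 = 51.2 to roughly 26.5, so the
# network must shrink theta_y further to win the softmax.
```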

5.3 Feature-Space Visualization

Feature distribution after Triplet Loss training:

                  ●
               ●     ●
            ●    A     ●        class A is fairly spread out
               ●   ●
                 ●
                              ▲
                           ▲    ▲
                        ▲    B    ▲    class B is fairly spread out too
                           ▲  ▲
                              ▲
    
    Large intra-class variance; classes are separated but not compact


Feature distribution after ArcFace training:

              ●●●
             ● A ●        class A is very compact
              ●●●
                               
                              ▲▲▲
                             ▲ B ▲    class B is very compact too
                              ▲▲▲
    
    Very compact within classes, with a clear angular margin between classes

5.4 Practical Recommendations

When to choose FaceNet / Triplet Loss:

✓ Extremely many classes (>1M)
  - ArcFace must store a [num_classes, embedding_dim] weight matrix
  - 1M classes × 512 dims × 4 bytes (fp32) ≈ 2 GB of GPU memory

✓ Open-set recognition
  - No need to predefine every class
  - Only the notion of "similarity" has to be learned

✓ Cross-domain transfer
  - The similarity learned from triplets is more generic
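
The memory estimate above, assuming fp32 weights, can be checked in a couple of lines:

```python
# Size of the ArcFace class-weight matrix for 1M identities, fp32
num_classes, dim, bytes_per_param = 1_000_000, 512, 4
gib = num_classes * dim * bytes_per_param / 2**30
print(f"{gib:.2f} GiB")  # 1.91 GiB, i.e. roughly 2 GB
```

Halving to fp16, or sharding the matrix across GPUs (partial-FC style training), are the usual ways around this when ArcFace must scale further.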


When to choose ArcFace:

✓ Highest accuracy required
  - ArcFace leads the mainstream benchmarks

✓ Manageable number of classes (<100k)
  - With enough GPU memory, ArcFace is the first choice

✓ Fast convergence
  - Converges much faster than Triplet Loss

✓ Industrial deployment
  - Stable training, few hyperparameters

6. Summary

6.1 Key Takeaways

FaceNet (Triplet Loss)

  • Directly optimizes relative distances in the embedding space
  • Requires a carefully designed triplet mining strategy
  • Pros: scales well, suits extremely large numbers of identities
  • Cons: slow convergence, sensitive to sampling

ArcFace (Angular Margin)

  • Adds an additive margin in angle space
  • Simple to train, converges quickly
  • Pros: best accuracy, stable training
  • Cons: must store the class weight matrix

6.2 Modern Best Practices

Face recognition best practices as of 2024:

1. Backbone: IResNet100 / EfficientNet
2. Loss: ArcFace (s=64, m=0.5) or AdaFace
3. Augmentation: random cropping, color jitter, MixUp
4. Training strategy:
   - large batches (≥512)
   - cosine learning-rate decay
   - mixed-precision training
5. Post-processing: feature normalization, PCA whitening (optional)
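
The cosine decay mentioned in step 4 follows lr(t) = lr_min + ½·(lr_max − lr_min)·(1 + cos(πt/T)); a minimal pure-Python sketch (function name hypothetical):

```python
import math

def cosine_lr(step, total_steps, lr_max=0.1, lr_min=0.0):
    """Cosine learning-rate decay from lr_max down to lr_min."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / total_steps))

total = 25  # e.g. one value per epoch, matching num_epochs above
schedule = [cosine_lr(t, total) for t in range(total + 1)]
# Starts at lr_max, decays monotonically, ends at lr_min.
```

PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` implements the same curve and would replace the `MultiStepLR` milestones used in the trainer's config.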

6.3 In One Sentence

FaceNet showed us what to learn (similarity); ArcFace showed us how to learn it better (angular margin constraints).

I hope this article helped you gain a deep understanding of the core algorithms of face recognition. Questions and comments are welcome!


References

  1. Schroff F, et al. “FaceNet: A Unified Embedding for Face Recognition and Clustering.” CVPR 2015.
  2. Deng J, et al. “ArcFace: Additive Angular Margin Loss for Deep Face Recognition.” CVPR 2019.
  3. Wang H, et al. “CosFace: Large Margin Cosine Loss for Deep Face Recognition.” CVPR 2018.
  4. Liu W, et al. “SphereFace: Deep Hypersphere Embedding for Face Recognition.” CVPR 2017.

Author: Jia

For more technical articles, follow my CSDN blog!
