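Deep Learning Module Practice Handbook (Issue 7)
This article introduces five improved vision Transformer modules:
CloFormer: a two-branch structure (a local convolution branch and a global attention branch) for efficient feature extraction, suited to lightweight mobile models. Its core is the AttnConv operator, which combines shared weights and context-aware weights to capture local information.
BiFormer: bi-level routing attention (BRA), which first routes at the region level to keep only the relevant regions and then applies fine-grained token-level attention, significantly reducing computational complexity.
STViT: introduces a superpixel-like …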
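31. CloAttention Module
Paper: "Rethinking Local Perception in Lightweight Vision Transformer"
1. Purpose
CloFormer is a lightweight vision Transformer with context-aware local enhancement. It aims to stay lightweight while improving performance on image classification, object detection, and semantic segmentation. Its main target is vision models on mobile devices, where directly shrinking a standard ViT (Vision Transformer) causes accuracy to drop.
2. Mechanism
CloFormer introduces AttnConv (an attention-style convolution operator) to perform context-aware local enhancement and so capture high-frequency local information. The model uses a two-branch structure:
1. Local branch:
AttnConv aggregates high-frequency local information by combining shared weights with context-aware weights. It first extracts local representations with a depthwise convolution (DWconv), whose kernel weights are shared across positions, and then applies context-aware weights to enhance the local features.
2. Global branch:
Standard attention with downsampled K and V, which reduces FLOPs and helps the model capture low-frequency global information.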
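The saving from downsampling K and V in the global branch can be estimated with a few lines of arithmetic. The sketch below is only an illustration: the helper name `attention_pairs` is hypothetical, and it counts query-key pairs for a single head rather than exact FLOPs.

```python
def attention_pairs(h, w, window=1):
    """Query-key pairs for one head when K/V are average-pooled by `window`.

    Assumes h and w are divisible by `window` (illustrative helper, not from the paper).
    """
    n_q = h * w                              # queries stay at full resolution
    n_kv = (h // window) * (w // window)     # keys/values after pooling
    return n_q * n_kv


full = attention_pairs(64, 64, window=1)     # no pooling
pooled = attention_pairs(64, 64, window=4)   # 4x4 average pooling on K/V
print(full, pooled, full // pooled)          # 16777216 1048576 16
```

With a 64×64 feature map and a 4×4 pooling window, the attention matrix shrinks by a factor of 16, which is the square of the window size.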
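3. Distinctive Advantages
1. Context-aware local enhancement:
With AttnConv, CloFormer combines the strengths of shared weights and context-aware weights to enhance local features.
2. Two-branch structure:
Capturing high-frequency and low-frequency information at the same time helps the model perform well across different vision tasks.
3. Lightweight design:
CloFormer is designed for mobile devices; its architecture and weight sharing improve accuracy while keeping the model small.
4. Code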
import torch
import torch.nn as nn
from efficientnet_pytorch.model import MemoryEfficientSwish


class AttnMap(nn.Module):
    """
    Attention-map generator.
    Produces attention weights with two 1x1 convolutions and a Swish activation in between.
    """

    def __init__(self, dim):
        super().__init__()
        self.act_block = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),  # 1x1 conv: mixes channels at each position
            MemoryEfficientSwish(),                                   # memory-efficient Swish activation
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)   # produces the attention weights
        )

    def forward(self, x):
        return self.act_block(x)
class EfficientAttention(nn.Module):
    """
    Efficient attention module.
    Core idea: split the heads into high-frequency (local) and low-frequency (global)
    groups to reduce computation while preserving accuracy.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 group_split=[4, 4],  # heads per branch; the last entry is the low-frequency branch
                 kernel_sizes=[5],    # depthwise kernel size of each high-frequency branch
                 window_size=4,       # K/V pooling window of the low-frequency branch
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=True):
        super().__init__()
        # Argument checks
        assert sum(group_split) == num_heads, "group_split must sum to num_heads"
        assert len(kernel_sizes) + 1 == len(group_split), "len(kernel_sizes) + 1 must equal len(group_split)"
        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads       # channels per head
        self.scalor = self.dim_head ** -0.5    # attention scaling factor
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split
        # 1. High-frequency (local) branches
        self.convs = nn.ModuleList()
        self.act_blocks = nn.ModuleList()
        self.qkvs = nn.ModuleList()
        for i in range(len(kernel_sizes)):
            kernel_size = kernel_sizes[i]
            group_head = group_split[i]
            if group_head == 0:
                continue
            # Per branch: 1x1 QKV projection + depthwise conv + attention-map generator
            self.qkvs.append(nn.Conv2d(
                dim, 3 * group_head * self.dim_head,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            ))
            # Depthwise conv: groups == channels, so each channel is filtered independently
            self.convs.append(nn.Conv2d(
                3 * self.dim_head * group_head,
                3 * self.dim_head * group_head,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=3 * self.dim_head * group_head
            ))
            self.act_blocks.append(AttnMap(self.dim_head * group_head))
        # 2. Low-frequency (global) branch
        self.low_freq_enabled = group_split[-1] != 0
        if self.low_freq_enabled:
            self.global_q = nn.Conv2d(
                dim, group_split[-1] * self.dim_head,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            )
            self.global_kv = nn.Conv2d(
                dim, group_split[-1] * self.dim_head * 2,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            )
            # Average pooling lowers the K/V resolution of the low-frequency branch
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()
        # 3. Output projection and dropout
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
    def high_fre_attntion(self, x, qkv_layer, conv_layer, attn_block):
        """High-frequency (local detail) attention."""
        b, c, h, w = x.shape
        # Generate QKV, then apply the depthwise convolution
        qkv = qkv_layer(x)     # [B, 3*G*D, H, W]
        qkv = conv_layer(qkv)  # depthwise conv extracts local features
        qkv = qkv.reshape(b, 3, -1, h, w).transpose(0, 1)  # [3, B, G*D, H, W]
        q, k, v = qkv
        # Context-aware weights: attention map over the elementwise product q * k
        attn = attn_block(q * k) * self.scalor
        attn = self.attn_drop(torch.tanh(attn))  # tanh bounds the weights to (-1, 1)
        return attn * v  # elementwise reweighting of V

    def low_fre_attention(self, x):
        """Low-frequency (global) attention. Assumes H and W are divisible by window_size."""
        b, c, h, w = x.shape
        # Q at the original resolution
        q = self.global_q(x).reshape(
            b, -1, self.dim_head, h * w
        ).transpose(-1, -2)  # [B, G, H*W, D]
        # K and V from the pooled feature map (fewer tokens, less computation)
        kv = self.avgpool(x)  # [B, C, H/win, W/win]
        kv = self.global_kv(kv).view(
            b, 2, -1, self.dim_head,
            (h * w) // (self.window_size ** 2)
        ).permute(1, 0, 2, 4, 3)  # [2, B, G, (H*W)/win^2, D]
        k, v = kv
        # Global attention
        attn = self.scalor * q @ k.transpose(-1, -2)  # [B, G, H*W, (H*W)/win^2]
        attn = self.attn_drop(attn.softmax(dim=-1))   # normalize over the keys
        res = attn @ v                                # [B, G, H*W, D]
        # Back to [B, G*D, H, W]
        return res.transpose(2, 3).reshape(b, -1, h, w)
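    def forward(self, x):
        """
        Args:
            x: input feature map of shape [batch_size, dim, height, width]
        Returns:
            Attention-enhanced feature map with the same shape
        """
        outputs = []
        # 1. High-frequency branches. Only non-empty groups are registered in __init__,
        #    so iterate over the registered layers instead of indexing group_split.
        for qkv_layer, conv_layer, attn_block in zip(self.qkvs, self.convs, self.act_blocks):
            outputs.append(self.high_fre_attntion(x, qkv_layer, conv_layer, attn_block))
        # 2. Low-frequency branch (if enabled)
        if self.low_freq_enabled:
            outputs.append(self.low_fre_attention(x))
        # 3. Concatenate all branches along channels and project
        fused = torch.cat(outputs, dim=1)
        out = self.proj(fused)
        out = self.proj_drop(out)
        return out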
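# Test code
if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Efficient attention with 8 heads: 4 high-frequency + 4 low-frequency
    attn = EfficientAttention(
        dim=64,
        num_heads=8,
        group_split=[4, 4],  # 4 high-frequency heads + 4 low-frequency heads
        kernel_sizes=[5],    # 5x5 depthwise conv in the high-frequency branch
        window_size=4
    ).to(device)
    # Input feature map: [batch=1, dim=64, height=64, width=64]
    x = torch.randn(1, 64, 64, 64, device=device)
    # Forward pass
    out = attn(x)
    # Check the output shape
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {out.shape}")  # torch.Size([1, 64, 64, 64])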
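32. BiFormer Module
Paper: "BiFormer: Vision Transformer with Bi-Level Routing Attention"
1. Purpose
BiFormer addresses the computational and memory cost of vision Transformers on images. It introduces Bi-Level Routing Attention (BRA), a dynamic, content-aware sparse attention mechanism that allocates computation more flexibly and efficiently.
2. Mechanism
The core of BiFormer is bi-level routing attention (BRA), which has two main steps: region-to-region routing and token-to-token attention. First, a region-level affinity graph is built and pruned to decide which regions are relevant to each query region. Second, fine-grained token-to-token attention is applied only inside the selected regions, so each query interacts with a small set of the most relevant key-value pairs. BiFormer therefore attends to the parts of the image most relevant to each query instead of computing pairwise token interactions over all spatial positions, which significantly reduces computation and memory use.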
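The region-level routing step can be illustrated with a small pure-Python sketch. It assumes each region is already summarized by one descriptor vector (for example, the mean of its token queries or keys); the function name `route_topk` and the toy descriptors are made up for this example and are not part of the BiFormer code.

```python
def route_topk(region_q, region_k, topk):
    """For every query region, return the indices of its top-k key regions.

    Affinity is a plain dot product between region descriptors; pruning keeps
    only the k largest entries per row of the region-to-region affinity graph.
    """
    routes = []
    for q in region_q:
        scores = [sum(a * b for a, b in zip(q, k)) for k in region_k]
        order = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        routes.append(order[:topk])
    return routes


# Four regions with hypothetical 2-D descriptors: two "horizontal" and two "vertical"
regions = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
routes = route_topk(regions, regions, topk=2)
print(routes)  # [[0, 1], [0, 1], [2, 3], [2, 3]]
```

Token-level attention would then run only between the tokens of a query region and the tokens gathered from its routed regions, here two regions out of four.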
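3. Distinctive Advantages
1. Computational efficiency:
Compared with vanilla global attention, bi-level routing attention saves computation and memory because each query attends only to a dynamically chosen subset of the most relevant tokens.
2. Dynamic sparsity:
Unlike sparse attention methods with fixed patterns, BiFormer selects the attended regions and tokens based on content, which helps it handle a range of vision tasks.
3. Strong performance:
Experiments show that BiFormer performs well on image classification, object detection, and semantic segmentation, and that it outperforms existing state-of-the-art methods at comparable model size and computational cost.
4. Code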
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor


class TopkRouting(nn.Module):
    """
    Top-k routing module.
    Computes region-to-region affinities and returns, for each query region,
    the indices and softmax-normalized weights of its top-k key regions.
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5  # scaling factor
        self.diff_routing = diff_routing         # whether routing is differentiable
        # Learnable embedding (only with parameterized routing)
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        self.routing_act = nn.Softmax(dim=-1)    # normalizes the routing weights

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        # Without differentiable routing, detach to cut the gradient
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        # Embed, then compute region-to-region affinity logits
        query_hat, key_hat = self.emb(query), self.emb(key)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)
        # Keep the top-k regions per query region
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)
        r_weight = self.routing_act(topk_attn_logit)  # softmax over the k kept logits
        return r_weight, topk_index
class KVGather(nn.Module):
    """
    Key-value gather module.
    Selects the key-value pairs of the routed regions and optionally scales them
    by the routing weights.
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard'], "mul_weight must be 'none', 'soft' or 'hard'"
        self.mul_weight = mul_weight  # how the routing weights are applied

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # Expand the indices to match the KV shape
        r_idx_expanded = r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)
        # Gather KV by region index
        topk_kv = torch.gather(
            kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
            dim=2,
            index=r_idx_expanded
        )
        # Apply the routing weights (soft mode only)
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv
        elif self.mul_weight == 'hard':
            raise NotImplementedError("differentiable hard routing is not implemented")
        return topk_kv  # shape: (n, p^2, topk, w2, c_kv)


class QKVLinear(nn.Module):
    """
    QKV linear projection.
    Projects the input features to query (Q), key (K) and value (V).
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        # One linear layer produces Q (qk_dim) and KV (qk_dim + dim)
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv
class BiLevelRoutingAttention(nn.Module):
    """
    Bi-Level Routing Attention (BRA).
    Paper: "BiFormer: Vision Transformer with Bi-Level Routing Attention"
    Core idea: region-level routing followed by token-level attention inside the
    routed regions, which captures long-range dependencies at reduced cost.
    """

    def __init__(self,
                 dim,
                 n_win=7,                        # number of regions along height and width
                 num_heads=8,
                 qk_dim=None,
                 qk_scale=None,
                 kv_per_win=4,                   # output size of adaptive KV pooling per region
                 kv_downsample_ratio=4,          # KV downsampling ratio
                 kv_downsample_mode='identity',  # KV downsampling mode
                 topk=4,                         # number of regions kept by routing
                 param_attention="qkvo",         # attention parameterization
                 param_routing=False,            # parameterized routing
                 diff_routing=False,             # differentiable routing
                 soft_routing=False,             # soft routing (scale KV by routing weights)
                 side_dwconv=3,                  # kernel size of the local-enhancement conv
                 auto_pad=True):                 # pad the input to a multiple of n_win
        super().__init__()
        self.dim = dim
        self.n_win = n_win
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        # Both dimensions must split evenly across heads
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, \
            "qk_dim and dim must be divisible by num_heads!"
        self.scale = qk_scale or self.qk_dim ** -0.5
        self.side_dwconv = side_dwconv
        self.auto_pad = auto_pad
        # Local-enhanced positional encoding (LEPE): depthwise conv on V
        self.lepe = nn.Conv2d(
            dim, dim,
            kernel_size=side_dwconv,
            stride=1,
            padding=side_dwconv // 2,
            groups=dim
        ) if side_dwconv > 0 else lambda x: torch.zeros_like(x)
        # Routing and gathering
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        assert not (self.param_routing and not self.diff_routing), \
            "param_routing requires diff_routing=True"
        # Routing module
        self.router = TopkRouting(
            qk_dim=self.qk_dim,
            qk_scale=self.scale,
            topk=topk,
            diff_routing=diff_routing,
            param_routing=param_routing
        )
        # KV gather module. Note: diff_routing without soft_routing selects 'hard',
        # for which KVGather.forward raises NotImplementedError.
        if soft_routing:
            mul_weight = 'soft'
        elif diff_routing:
            mul_weight = 'hard'
        else:
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)
        # Attention parameterization
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)  # output projection
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()        # no output projection
        else:
            raise ValueError(f"unsupported param_attention mode: {param_attention}")
        # KV downsampling configuration
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        # KV downsampling layer
        if kv_downsample_mode == 'ada_avgpool':
            assert kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(kv_per_win)
        elif kv_downsample_mode == 'ada_maxpool':
            assert kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(kv_per_win)
        elif kv_downsample_mode == 'maxpool':
            assert kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(kv_downsample_ratio) if kv_downsample_ratio > 1 else nn.Identity()
        elif kv_downsample_mode == 'avgpool':
            assert kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(kv_downsample_ratio) if kv_downsample_ratio > 1 else nn.Identity()
        elif kv_downsample_mode == 'identity':
            self.kv_down = nn.Identity()
        else:
            raise ValueError(f"unsupported kv_downsample_mode: {kv_downsample_mode}")
    def forward(self, x: Tensor, ret_attn_mask=False):
        """
        Forward pass.
        Args:
            x: input feature map of shape [batch_size, height, width, channels] (NHWC)
            ret_attn_mask: whether to also return the routing weights, indices and attention
        Returns:
            Attention-enhanced feature map of shape [batch_size, channels, height, width] (NCHW)
        """
        # Pad so that H and W are multiples of n_win
        if self.auto_pad:
            N, H_in, W_in, C = x.size()
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))  # pads width, then height
            _, H, W, _ = x.size()
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0, \
                "input size must be divisible by n_win when auto_pad=False"
        # 1. Partition the feature map into n_win x n_win regions
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)
        q, kv = self.qkv(x)  # Q: [n, p^2, h, w, qk_dim]; KV: [n, p^2, h, w, qk_dim+dim]
        # 2. Token-level Q and KV
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        # Downsample KV and flatten the tokens of each region
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)
        # 3. Region-level routing (coarse top-k region selection)
        q_win = q.mean([2, 3])                       # region-level Q: mean over tokens
        k_win = kv[..., 0:self.qk_dim].mean([2, 3])  # region-level K
        r_weight, r_idx = self.router(q_win, k_win)  # routing weights and indices
        # 4. Gather the KV tokens of the routed regions
        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # 5. Token-level multi-head attention (m = heads, k = topk, w2 = tokens per region).
        #    The axis names in the original patterns were duplicated, which einops rejects.
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)
        attn = (q_pix * self.scale) @ k_pix_sel  # [(n p2), m, w2, k*w2]
        attn = attn.softmax(dim=-1)
        out = attn @ v_pix_sel                   # [(n p2), m, w2, c]
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)',
                        j=self.n_win, i=self.n_win, h=H // self.n_win, w=W // self.n_win)
        # 6. Local enhancement and output projection
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)',
                                   j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)
        out = out + lepe    # local-enhanced positional encoding
        out = self.wo(out)  # output projection
        # Crop away the padding
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()
        # Move channels to the second dimension (NCHW)
        if ret_attn_mask:
            return rearrange(out, "n h w c -> n c h w"), r_weight, r_idx, attn
        else:
            return rearrange(out, "n h w c -> n c h w")
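# Test code
if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Bi-level routing attention with 256 input channels and a 7x7 region grid
    attn = BiLevelRoutingAttention(
        dim=256,
        n_win=7,
        num_heads=8,
        topk=4
    ).to(device)
    # Input feature map: [batch=1, height=64, width=64, channels=256] (note: NHWC)
    input_tensor = torch.randn(1, 64, 64, 256, device=device)
    # Forward pass
    output_tensor = attn(input_tensor)
    # Check the output shape (NCHW, spatial size still 64x64)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")  # torch.Size([1, 256, 64, 64])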
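The test input above is 64×64 with `n_win=7`, so the auto-padding branch is exercised. The padding expression `(n - size % n) % n` used in `forward` can be checked in isolation; the helper name `pad_to_multiple` below is only for illustration.

```python
def pad_to_multiple(size, n):
    """Amount of right/bottom padding needed to make `size` a multiple of `n`."""
    return (n - size % n) % n


pad = pad_to_multiple(64, 7)
print(pad, 64 + pad, (64 + pad) // 7)  # 6 70 10
print(pad_to_multiple(70, 7))          # 0
```

A 64×64 map is therefore padded to 70×70, giving 7×7 regions of 10×10 tokens each, and the output is cropped back to 64×64 afterwards. The outer `% n` keeps the padding at zero when the size is already a multiple.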
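33. STViT Module
Paper: "Vision Transformer with Super Token Sampling"
1. Purpose
STViT (Super Token Vision Transformer) targets the computational redundancy of vision Transformers. In shallow layers, attention mostly captures local features, so much of the global pairwise computation is redundant; STViT tries to remove this unnecessary cost.
2. Mechanism
STViT introduces "super tokens", a concept similar to superpixels in image processing, to reduce the number of elements that take part in self-attention while keeping the ability to model global relations. The process has three steps: sample super tokens from the visual tokens, run self-attention among the super tokens, and map the result back to the original token space.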
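The three steps can be traced on a 1-D toy example. This is a simplified sketch under several assumptions: features are scalars, every token is softly associated with all super tokens (the module below restricts the association to a 3×3 neighborhood), and the self-attention step among super tokens is left out. The function names are hypothetical.

```python
import math


def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def super_token_roundtrip(tokens, block):
    """Sample super tokens from scalar tokens and map them back to token space."""
    n_super = len(tokens) // block
    # Initial super tokens: the mean of each block
    supers = [sum(tokens[j * block:(j + 1) * block]) / block for j in range(n_super)]
    # Soft association of every token with every super token (rows sum to 1)
    assoc = [softmax([t * s for s in supers]) for t in tokens]
    # Update super tokens as association-weighted means of the tokens
    col_sum = [sum(row[j] for row in assoc) for j in range(n_super)]
    supers = [sum(row[j] * t for row, t in zip(assoc, tokens)) / col_sum[j]
              for j in range(n_super)]
    # (Self-attention among the super tokens would be applied here.)
    # Map back: each token is a convex combination of the super tokens
    recon = [sum(a * s for a, s in zip(row, supers)) for row in assoc]
    return supers, assoc, recon


tokens = [0.1, 0.2, 0.3, 0.4, 1.0, 1.1, 1.2, 1.3]
supers, assoc, recon = super_token_roundtrip(tokens, block=4)
print(len(tokens), "tokens ->", len(supers), "super tokens ->", len(recon), "tokens")
```

Attention among 2 super tokens replaces attention among 8 tokens here; the association matrix carries the information needed to return to the original resolution.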
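3. Distinctive Advantages
STViT performs strongly on image classification, object detection, and segmentation with fewer parameters and lower computational cost. For example, without extra training data it reaches 86.4% top-1 accuracy on ImageNet-1K with fewer than 100M parameters.
4. Code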
import torch
import torch.nn as nn
import torch.nn.functional as F


class Unfold(nn.Module):
    """
    Unfold module.
    Expands each position of a feature map into its flattened k x k neighborhood.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Fixed identity kernels: output channel i copies neighbor i of the window
        weights = torch.eye(kernel_size ** 2).reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        self.register_buffer('weights', weights)  # buffer, not a trainable parameter

    def forward(self, x):
        b, c, h, w = x.shape
        # Sliding-window unfold implemented as a convolution
        x = F.conv2d(x.reshape(b * c, 1, h, w), self.weights, stride=1, padding=self.kernel_size // 2)
        return x.reshape(b, c * self.kernel_size ** 2, h * w)


class Fold(nn.Module):
    """
    Fold module.
    Sums unfolded k x k neighborhood values back onto the grid (inverse of Unfold up to overlap).
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        # Fixed identity kernels, as in Unfold
        weights = torch.eye(kernel_size ** 2).reshape(kernel_size ** 2, 1, kernel_size, kernel_size)
        self.register_buffer('weights', weights)  # buffer, not a trainable parameter

    def forward(self, x):
        # Fold implemented as a transposed convolution: [B, k*k, H, W] -> [B, 1, H, W]
        x = F.conv_transpose2d(x, self.weights, stride=1, padding=self.kernel_size // 2)
        return x
class Attention(nn.Module):
    """
    Standard multi-head self-attention on NCHW feature maps.
    Computes relations between positions from query (Q), key (K) and value (V).
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.window_size = window_size
        # 1x1 convolution produces Q, K and V
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        # Q, K, V: each [B, num_heads, head_dim, N]
        q, k, v = self.qkv(x).reshape(B, self.num_heads, C // self.num_heads * 3, N).chunk(3, dim=2)
        # Scores are stored as keys x queries, so softmax runs over the key axis (dim=-2)
        attn = (k.transpose(-1, -2) @ q) * self.scale  # [B, num_heads, N, N]
        attn = attn.softmax(dim=-2)
        attn = self.attn_drop(attn)
        # Weighted sum of values, then back to [B, C, H, W]
        x = (v @ attn).reshape(B, C, H, W)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class StokenAttention(nn.Module):
    """
    Super token attention (StokenAttention).
    Paper: "Vision Transformer with Super Token Sampling"
    Core idea: sample super tokens from the visual tokens, run self-attention
    among them, and map the result back to the original token space.
    """

    def __init__(self,
                 dim,
                 stoken_size,
                 n_iter=1,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.n_iter = n_iter            # number of association iterations
        self.stoken_size = stoken_size  # (h, w) pixels covered by one super token
        self.scale = dim ** -0.5
        # Components
        self.unfold = Unfold(3)
        self.fold = Fold(3)
        self.stoken_refine = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop
        )
    def stoken_forward(self, x):
        """Super token sampling, refinement, and mapping back to pixel space."""
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size
        # Pad so that H and W are multiples of the super token size
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        _, _, H, W = x.shape
        hh, ww = H // h, W // w
        # 1. Initial super tokens: average over each h x w block
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # [B, C, hh, ww]
        # 2. Group the pixels of each block
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)
        # 3. Iteratively estimate the pixel-to-super-token association (no gradients)
        with torch.no_grad():
            for idx in range(self.n_iter):
                # Unfold each super token into its 3x3 neighborhood
                stoken_features_unfolded = self.unfold(stoken_features)  # [B, C*9, hh*ww]
                stoken_features_unfolded = stoken_features_unfolded.transpose(1, 2).reshape(B, hh * ww, C, 9)
                # Affinity between each pixel and the 9 neighboring super tokens
                affinity_matrix = pixel_features @ stoken_features_unfolded * self.scale  # [B, hh*ww, h*w, 9]
                affinity_matrix = affinity_matrix.softmax(-1)
                # Total affinity received by each super token (used for normalization)
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)
                # Update the super tokens as affinity-weighted sums of pixels
                if idx < self.n_iter - 1:
                    stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix  # [B, hh*ww, C, 9]
                    stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
                    stoken_features = stoken_features / (affinity_matrix_sum + 1e-12)  # normalize
        # 4. Self-attention among the super tokens
        stoken_features = self.stoken_refine(stoken_features)
        # 5. Unfold the refined super tokens and map them back to pixel space
        stoken_features = self.unfold(stoken_features)
        stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
        # The affinity matrix distributes super token features to the pixels
        pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2)  # [B, hh*ww, C, h*w]
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
        # Crop away the padding
        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]
        return pixel_features

    def direct_forward(self, x):
        """Plain self-attention, used when stoken_size is 1x1."""
        return self.stoken_refine(x)

    def forward(self, x):
        """Forward entry point."""
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        else:
            return self.direct_forward(x)
# Test code
if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # StokenAttention with 64 input channels and 8x8 super tokens
    stoken_attn = StokenAttention(
        dim=64,
        stoken_size=[8, 8],
        n_iter=1,
        num_heads=8
    ).to(device)
    # Input feature map: [batch=3, channels=64, height=64, width=64]
    input_tensor = torch.randn(3, 64, 64, 64, device=device)
    # Forward pass
    output_tensor = stoken_attn(input_tensor)
    # Check the output shape
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")  # torch.Size([3, 64, 64, 64])
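34. iRMB Module
Paper: "Rethinking Mobile Block for Efficient Attention-based Models"
1. Purpose
The paper proposes a design approach for modern, efficient lightweight models for dense prediction tasks that balances parameters, FLOPs, and performance. The authors rethink the efficient Inverted Residual Block (IRB) and the effective components of the Transformer from a unified perspective, extend the CNN-based IRB to attention-based models, and abstract a one-residual Meta Mobile Block (MMB) for lightweight model design.
2. Mechanism
Following simple but effective design criteria, the work derives a modern Inverted Residual Mobile Block (iRMB) and builds a ResNet-like efficient model (EMO) from iRMBs alone for downstream tasks. By combining the efficiency of CNNs with the dynamic modeling ability of the Transformer inside iRMB, EMO improves performance and stays competitive with state-of-the-art lightweight attention models without introducing complex structures.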
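To get a feel for how light the block is, the convolution weights of one iRMB can be counted by hand. The sketch below is a rough estimate that mirrors the layer layout of the implementation in this section (a 1×1 QK projection, a 1×1 V projection, a depthwise convolution, and a 1×1 output projection); it ignores biases, normalization parameters, and the optional SE module, and the helper name `irmb_conv_weights` is hypothetical.

```python
def irmb_conv_weights(dim_in, dim_out, exp_ratio=1.0, dw_ks=3, attn=True):
    """Count convolution weights of one iRMB-style block (biases and norms ignored)."""
    dim_mid = int(dim_in * exp_ratio)
    counts = {
        "qk_proj": dim_in * 2 * dim_in if attn else 0,  # 1x1 conv producing Q and K
        "v_proj": dim_in * dim_mid,                     # 1x1 conv producing V (expansion)
        "dw_conv": dim_mid * dw_ks * dw_ks,             # depthwise conv: one kernel per channel
        "out_proj": dim_mid * dim_out,                  # 1x1 output projection
    }
    counts["total"] = sum(counts.values())
    return counts


counts = irmb_conv_weights(64, 64)
print(counts)
# {'qk_proj': 8192, 'v_proj': 4096, 'dw_conv': 576, 'out_proj': 4096, 'total': 16960}
```

The depthwise convolution contributes only 576 weights here, which is why the spatial mixing in an inverted-residual design is cheap compared with the pointwise projections.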
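3. Distinctive Advantages
Extensive experiments on the ImageNet-1K, COCO2017, and ADE20K benchmarks demonstrate the advantages of EMO. For example, EMO-1M/2M/5M reach 71.5%, 75.1%, and 78.4% Top-1 accuracy, surpassing CNN-based and attention-based models of similar size. EMO also balances parameter count, efficiency, and accuracy well, running 2.8-4.0 times faster than EdgeNeXt on iPhone14. It does not rely on complex operations yet obtains very competitive results on several vision tasks, which supports its effectiveness and practicality as a lightweight attention model.
4. Code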
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.efficientnet_blocks import SqueezeExcite as SE
from einops import rearrange, reduce
from timm.models.layers.activations import *
from timm.models.layers import DropPath

# Module-level default for in-place activations
inplace = True


# 2-D layer normalization for convolutional (NCHW) feature maps
class LayerNorm2d(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        x = rearrange(x, 'b c h w -> b h w c').contiguous()
        x = self.norm(x)
        x = rearrange(x, 'b h w c -> b c h w').contiguous()
        return x


# Helper that returns a normalization layer constructor by name
def get_norm(norm_layer='in_1d'):
    eps = 1e-6
    norm_dict = {
        'none': nn.Identity,
        'in_1d': partial(nn.InstanceNorm1d, eps=eps),
        'in_2d': partial(nn.InstanceNorm2d, eps=eps),
        'in_3d': partial(nn.InstanceNorm3d, eps=eps),
        'bn_1d': partial(nn.BatchNorm1d, eps=eps),
        'bn_2d': partial(nn.BatchNorm2d, eps=eps),
        'bn_3d': partial(nn.BatchNorm3d, eps=eps),
        'gn': partial(nn.GroupNorm, eps=eps),
        'ln_1d': partial(nn.LayerNorm, eps=eps),
        'ln_2d': partial(LayerNorm2d, eps=eps),
    }
    return norm_dict[norm_layer]
# Helper that returns an activation constructor by name
def get_act(act_layer='relu'):
    act_dict = {
        'none': nn.Identity,
        'sigmoid': Sigmoid,
        'swish': Swish,
        'mish': Mish,
        'hsigmoid': HardSigmoid,
        'hswish': HardSwish,
        'hmish': HardMish,
        'tanh': Tanh,
        'relu': nn.ReLU,
        'relu6': nn.ReLU6,
        'prelu': PReLU,
        'gelu': GELU,
        'silu': nn.SiLU
    }
    return act_dict[act_layer]


# Learnable per-channel scaling for [B, N, C] features
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, 1, dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


# Learnable per-channel scaling for [B, C, H, W] features
class LayerScale2D(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
# Convolution + normalization + activation in one layer
class ConvNormAct(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
                 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
        super().__init__()
        self.has_skip = skip and dim_in == dim_out
        padding = math.ceil((kernel_size - stride) / 2)
        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x


# Multi-scale patch embedding
class MSPatchEmb(nn.Module):
    def __init__(self, dim_in, emb_dim, kernel_size=2, c_group=-1, stride=1, dilations=[1, 2, 3],
                 norm_layer='bn_2d', act_layer='silu'):
        """
        Multi-scale patch embedding that uses convolutions with different dilation
        rates to capture multi-scale features.
        Args:
            dim_in: number of input channels
            emb_dim: number of output channels
            kernel_size: convolution kernel size
            c_group: number of convolution groups (-1 selects gcd(dim_in, emb_dim))
            stride: stride
            dilations: list of dilation rates
            norm_layer: normalization layer name
            act_layer: activation name
        """
        super().__init__()
        self.dilation_num = len(dilations)
        # Resolve the group count first, then validate it
        # (the original asserted before resolving, so -1 always passed the check)
        c_group = math.gcd(dim_in, emb_dim) if c_group == -1 else c_group
        assert dim_in % c_group == 0
        # One convolution branch per dilation rate
        self.convs = nn.ModuleList()
        for i in range(len(dilations)):
            padding = math.ceil(((kernel_size - 1) * dilations[i] + 1 - stride) / 2)
            self.convs.append(nn.Sequential(
                nn.Conv2d(dim_in, emb_dim, kernel_size, stride, padding, dilations[i], groups=c_group),
                get_norm(norm_layer)(emb_dim),
                # The original passed emb_dim positionally, which landed in `inplace`
                get_act(act_layer)(inplace=True)
            ))

    def forward(self, x):
        if self.dilation_num == 1:
            x = self.convs[0](x)
        else:
            # Average the outputs of all dilation branches
            x = torch.cat([self.convs[i](x).unsqueeze(dim=-1) for i in range(self.dilation_num)], dim=-1)
            x = reduce(x, 'b c h w n -> b c h w', 'mean').contiguous()
        return x
# Inverted Residual Mobile Block (iRMB)
class iRMB(nn.Module):
    """
    Inverted Residual Mobile Block (iRMB): window self-attention combined with a
    depthwise convolution inside an inverted-residual layout.
    Main features:
    - optional spatial (window) self-attention
    - depthwise convolution for local mixing
    - optional Squeeze-and-Excitation module
    - configurable normalization and activation
    """

    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0,
                 norm_layer='bn_2d', act_layer='relu', v_proj=True, dw_ks=3, stride=1,
                 dilation=1, se_ratio=0.0, dim_head=64, window_size=7, attn_s=True,
                 qkv_bias=False, attn_drop=0., drop=0., drop_path=0., v_group=False, attn_pre=False):
        super().__init__()
        # Input normalization
        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        # Expanded (middle) width
        dim_mid = int(dim_in * exp_ratio)
        # The skip connection needs matching shapes
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip
        # Spatial attention configuration
        self.attn_s = attn_s
        if self.attn_s:
            assert dim_in % dim_head == 0, 'dim_in should be divisible by dim_head'
            self.dim_head = dim_head
            self.window_size = window_size
            self.num_head = dim_in // dim_head
            self.scale = self.dim_head ** -0.5
            self.attn_pre = attn_pre
            # Q and K projection
            self.qk = ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias,
                                  norm_layer='none', act_layer='none')
            # V projection
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1,
                                 groups=self.num_head if v_group else 1,
                                 bias=qkv_bias, norm_layer='none',
                                 act_layer=act_layer, inplace=inplace)
            self.attn_drop = nn.Dropout(attn_drop)
        else:
            # Without attention, only the V projection is used
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias,
                                 norm_layer='none', act_layer=act_layer, inplace=inplace)
        # Local mixing with a depthwise convolution
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride,
                                      dilation=dilation, groups=dim_mid, norm_layer='bn_2d',
                                      act_layer='silu', inplace=inplace)
        # Squeeze-and-Excitation module
        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()
        # Output projection and regularization
        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
    def forward(self, x):
        """Forward pass."""
        shortcut = x
        x = self.norm(x)
        B, C, H, W = x.shape
        if self.attn_s:
            # Window attention: choose the window size and pad to a multiple of it
            if self.window_size <= 0:
                window_size_W, window_size_H = W, H
            else:
                window_size_W, window_size_H = self.window_size, self.window_size
            pad_l = pad_t = 0
            pad_r = (window_size_W - W % window_size_W) % window_size_W
            pad_b = (window_size_H - H % window_size_H) % window_size_H
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
            n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W
            # Split into n1 * n2 groups of h1 x w1 tokens; n1/n2 are the inner factors,
            # so each group samples the map with stride n1 (rows) and n2 (columns)
            x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
            # Attention inside each group
            b, c, h, w = x.shape
            qk = self.qk(x)
            qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head',
                           qk=2, heads=self.num_head, dim_head=self.dim_head).contiguous()
            q, k = qk[0], qk[1]
            # Scaled dot-product scores and softmax
            attn_spa = (q @ k.transpose(-2, -1)) * self.scale
            attn_spa = attn_spa.softmax(dim=-1)
            attn_spa = self.attn_drop(attn_spa)
            # Apply the attention (two modes)
            if self.attn_pre:
                # Attend over the input first, then project to V
                x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
                x_spa = attn_spa @ x
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w',
                                  heads=self.num_head, h=h, w=w).contiguous()
                x_spa = self.v(x_spa)
            else:
                # Project to V first, then attend
                v = self.v(x)
                v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
                x_spa = attn_spa @ v
                x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w',
                                  heads=self.num_head, h=h, w=w).contiguous()
            # Restore the original layout and crop away the padding
            x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
            if pad_r > 0 or pad_b > 0:
                x = x[:, :, :H, :W].contiguous()
        else:
            # Without attention, apply the V projection directly
            x = self.v(x)
        # Depthwise convolution and SE, with an inner residual when shapes allow
        x = (x + self.se(self.conv_local(x))) if self.has_skip else self.se(self.conv_local(x))
        # Output projection and outer skip connection
        x = self.proj_drop(x)
        x = self.proj(x)
        x = (shortcut + self.drop_path(x)) if self.has_skip else x
        return x
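# Test code
if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_tensor = torch.randn(3, 64, 64, 64, device=device)  # random input
    model = iRMB(64, 64).to(device)                            # instantiate the block
    output_tensor = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")  # same shape as the input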
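35. AFT Module
Paper: "An Attention Free Transformer"
1. Purpose
The Attention Free Transformer (AFT) removes the dot-product self-attention of the standard Transformer to obtain a more efficient model. It is particularly suited to scenarios that need high computational efficiency and low memory use, such as mobile devices and edge computing.
2. Mechanism
AFT relates sequence positions without computing a dot-product attention matrix. The keys are combined with a set of learned pairwise position biases and used to form a weighted average of the values; the result is then multiplied elementwise by a sigmoid gate of the query. Each output element is therefore a weighted sum of the input values, with weights determined by the keys and the position biases. The design is intended to lower model complexity and runtime memory requirements.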
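For a single feature channel, the computation described above can be written out directly. The sketch below is a pure-Python illustration that follows the same formula as the `AFT_FULL` module in this section; the function name `aft_full_1d` and the toy inputs are made up for the example.

```python
import math


def aft_full_1d(q, k, v, w):
    """AFT-full for one feature channel.

    q, k, v: lists of length n; w: n x n list of position biases.
    y[t] = sigmoid(q[t]) * sum_s exp(k[s] + w[t][s]) * v[s] / sum_s exp(k[s] + w[t][s])
    """
    n = len(q)
    out = []
    for t in range(n):
        weights = [math.exp(k[s] + w[t][s]) for s in range(n)]
        numerator = sum(ws * v[s] for s, ws in enumerate(weights))
        denominator = sum(weights)
        gate = 1.0 / (1.0 + math.exp(-q[t]))  # sigmoid gate from the query
        out.append(gate * numerator / denominator)
    return out


n = 3
zeros = [0.0] * n
bias = [[0.0] * n for _ in range(n)]
# With zero keys and biases every value gets equal weight, and sigmoid(0) = 0.5,
# so each output is 0.5 * mean(v).
y = aft_full_1d(q=zeros, k=zeros, v=[1.0, 2.0, 3.0], w=bias)
print(y)
```

No query-key dot product appears anywhere: the query only gates the result, while the keys and biases set the averaging weights.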
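3. Distinctive Advantages
1. Efficiency: by avoiding expensive self-attention computation, AFT gains in execution speed and computational efficiency.
2. Simpler model structure: without the self-attention mechanism, the model is lighter and easier to implement and deploy.
3. Adaptability: the structure adapts easily to different tasks and datasets and generalizes well.
4. Low resource use: for resource-constrained environments such as mobile and edge devices, AFT offers a practical way to keep good accuracy while reducing resource consumption.
4. Code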
import torch
import torch.nn as nn
import torch.nn.init as init


class AFT_FULL(nn.Module):
    """
    AFT-FULL module (Attention Free Transformer, full version).
    Paper: "An Attention Free Transformer"
    Core idea: replace the matrix product of self-attention with elementwise
    operations that use position biases, lowering computational complexity.
    """

    def __init__(self, d_model, n=49, simple=False):
        """
        Args:
            d_model: feature dimension (input/output channels)
            n: sequence length (for example, the number of image patches)
            simple: if True, the position biases are a fixed zero matrix;
                    otherwise they are learnable parameters
        """
        super().__init__()
        # Q, K, V linear projections (no bias)
        self.fc_q = nn.Linear(d_model, d_model, bias=False)
        self.fc_k = nn.Linear(d_model, d_model, bias=False)
        self.fc_v = nn.Linear(d_model, d_model, bias=False)
        # Position biases (model pairwise position relations in the sequence)
        if simple:
            # Simple mode: fixed zeros, no position information is learned
            self.register_buffer('position_biases', torch.zeros((n, n)))
        else:
            # Standard mode: learnable n x n position bias matrix
            self.position_biases = nn.Parameter(torch.ones((n, n)))
        self.d_model = d_model
        self.n = n                   # sequence length (for example, H*W patches)
        self.sigmoid = nn.Sigmoid()  # gate activation
        # Initialize the weights
        self.init_weights()

    def init_weights(self):
        """Parameter initialization."""
        for m in [self.fc_q, self.fc_k, self.fc_v]:
            if isinstance(m, nn.Linear):
                # Normal initialization with std tied to the feature dimension
                init.normal_(m.weight, std=self.d_model ** -0.5)
        # Position biases (only when they are learnable)
        if hasattr(self.position_biases, 'requires_grad') and self.position_biases.requires_grad:
            init.constant_(self.position_biases, 0.1)  # small value keeps exp() in a moderate range
    def forward(self, input):
        """
        Forward pass: elementwise operations instead of the attention matrix product.
        Args:
            input: input features of shape [batch_size, seq_len, d_model];
                   seq_len must equal the n given at construction
        Returns:
            Output features of shape [batch_size, seq_len, d_model]
        """
        bs, n, dim = input.shape  # bs: batch size; n: sequence length; dim: feature dimension
        # 1. Q, K, V projections
        q = self.fc_q(input)                      # [bs, n, dim]
        k = self.fc_k(input).view(1, bs, n, dim)  # [1, bs, n, dim], extra axis for broadcasting
        v = self.fc_v(input).view(1, bs, n, dim)  # [1, bs, n, dim], extra axis for broadcasting
        # 2. Position biases reshaped to [n, 1, n, 1] so they broadcast against k
        pos_bias = self.position_biases.view(n, 1, -1, 1)
        # 3. Weighted average of the values (replaces softmax-weighted attention).
        #    The broadcast intermediate has shape [n, bs, n, dim].
        # Numerator: sum over source positions of exp(k + bias) * v
        numerator = torch.sum(torch.exp(k + pos_bias) * v, dim=2)  # [n, bs, dim]
        # Denominator: sum over source positions of exp(k + bias)
        denominator = torch.sum(torch.exp(k + pos_bias), dim=2)    # [n, bs, dim]
        # 4. Gate with the query
        weighted_v = (numerator / denominator).permute(1, 0, 2)  # [bs, n, dim]
        out = self.sigmoid(q) * weighted_v                       # [bs, n, dim]
        return out
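# Test code
if __name__ == '__main__':
    # Fall back to the CPU when no GPU is available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # AFT-FULL with feature dimension 512 and sequence length 64
    model = AFT_FULL(d_model=512, n=64, simple=False).to(device)
    # Random input: [batch_size=64, seq_len=64, d_model=512].
    # The [n, bs, n, dim] intermediate holds 64*64*64*512 (about 134M) elements.
    input_tensor = torch.randn(64, 64, 512, device=device)
    # Forward pass
    output_tensor = model(input_tensor)
    # Check the output shape
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output_tensor.shape}")  # expected: torch.Size([64, 64, 512])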