31、Cloattention模块

论文《Rethinking Local Perception in Lightweight Vision Transformer》

1、作用

CloFormer(Context-aware Local Enhancement Vision Transformer)是一种轻量级的视觉Transformer,用于在保持模型轻量化的同时,提高在各种视觉任务中的性能,包括图像分类、目标检测和语义分割。其主要目的是提升移动设备上的视觉模型性能,克服直接缩减标准ViT(Vision Transformer)模型尺寸导致的性能下降问题。

2、机制

CloFormer通过引入AttnConv(Attention Style Convolution Operator)来实现上下文感知的局部增强,从而有效捕获高频局部信息。该模型采用两分支结构:

1、局部分支

利用AttnConv融合共享权重和上下文感知权重来聚合高频局部信息。首先,使用深度卷积(Depthwise Convolution,DWconv)提取局部表示,然后部署上下文感知权重来增强局部特征。

2、全局分支

采用标准的注意力机制,通过对K和V进行下采样来降低FLOPs,帮助模型捕捉低频全局信息。

3、独特优势

1、上下文感知的局部增强

通过AttnConv,CloFormer有效地结合了共享权重和上下文感知权重的优势,实现了高质量的局部特征增强。

2、两分支结构

通过同时捕获高频和低频信息,模型能够在不同的视觉任务中达到更好的性能。

3、轻量化设计

CloFormer专为移动设备设计,通过精心的模型架构设计和权重共享机制,实现了在保持轻量化的同时提高模型性能。

4、代码

import torch
import torch.nn as nn
from efficientnet_pytorch.model import MemoryEfficientSwish


class AttnMap(nn.Module):
    """
    Attention-map generator used by the high-frequency (local) branch.

    A 1x1 conv -> MemoryEfficientSwish -> 1x1 conv stack that keeps the channel
    count at ``dim``. EfficientAttention feeds it the element-wise product q * k
    and uses the result as per-position, per-channel attention logits.

    Args:
        dim: number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),  # channel mixing
            MemoryEfficientSwish(),  # memory-saving Swish activation
            nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),  # project to attention logits
        ]
        self.act_block = nn.Sequential(*layers)

    def forward(self, x):
        """Map a ``dim``-channel feature map to logits of the same shape."""
        logits = self.act_block(x)
        return logits


class EfficientAttention(nn.Module):
    """
    CloFormer efficient attention: parallel high-frequency (local) and
    low-frequency (global) attention branches.

    Each branch outputs ``heads * dim_head`` channels. The branch outputs are
    concatenated along channels and fused by a 1x1 projection.

    Args:
        dim: input/output channel count; must be divisible by ``num_heads``.
        num_heads: total number of heads, distributed over branches by ``group_split``.
        group_split: heads per branch. The first ``len(kernel_sizes)`` entries are the
            high-frequency branches (one per kernel size) and the last entry is the
            low-frequency branch. A 0 entry disables that branch.
        kernel_sizes: depthwise-conv kernel size of each high-frequency branch.
        window_size: average-pooling window used to downsample K/V in the
            low-frequency branch (1 disables pooling).
        attn_drop: dropout applied to attention weights.
        proj_drop: dropout applied after the output projection.
        qkv_bias: whether the Q/K/V and output 1x1 convs use a bias.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 group_split=(4, 4),  # high-frequency heads + low-frequency heads = num_heads
                 kernel_sizes=(5,),   # depthwise kernel size per high-frequency branch
                 window_size=4,       # K/V pooling window of the low-frequency branch
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=True):
        super().__init__()
        # Copy into lists: avoids sharing a mutable default and accepts any sequence.
        group_split = list(group_split)
        kernel_sizes = list(kernel_sizes)
        assert sum(group_split) == num_heads, "分组头数之和必须等于总头数"
        assert len(kernel_sizes) + 1 == len(group_split), "核大小数量+1必须等于分组数量"
        # Otherwise the concatenated branches have num_heads * dim_head != dim
        # channels and the output projection fails with an opaque shape error.
        assert dim % num_heads == 0, "dim必须能被num_heads整除"

        self.dim = dim
        self.num_heads = num_heads
        self.dim_head = dim // num_heads
        self.scalor = self.dim_head ** -0.5  # attention scaling factor
        self.kernel_sizes = kernel_sizes
        self.window_size = window_size
        self.group_split = group_split

        # 1. High-frequency branches. Modules are appended only for branches with
        #    a non-zero head count, so these lists are compact. They must be
        #    iterated together, not indexed by the position in ``group_split``.
        self.convs = nn.ModuleList()
        self.act_blocks = nn.ModuleList()
        self.qkvs = nn.ModuleList()
        for kernel_size, group_head in zip(kernel_sizes, group_split):
            if group_head == 0:
                continue
            branch_dim = group_head * self.dim_head
            # Construction order (qkv, conv, attn-map) kept for reproducible init.
            self.qkvs.append(nn.Conv2d(
                dim, 3 * branch_dim,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            ))
            # Depthwise conv (groups == channels) over the packed q/k/v channels.
            self.convs.append(nn.Conv2d(
                3 * branch_dim,
                3 * branch_dim,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=3 * branch_dim
            ))
            self.act_blocks.append(AttnMap(branch_dim))

        # 2. Low-frequency (global) branch.
        self.low_freq_enabled = group_split[-1] != 0
        if self.low_freq_enabled:
            self.global_q = nn.Conv2d(
                dim, group_split[-1] * self.dim_head,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            )
            self.global_kv = nn.Conv2d(
                dim, group_split[-1] * self.dim_head * 2,
                kernel_size=1, stride=1, padding=0, bias=qkv_bias
            )
            # Average pooling reduces the number of K/V tokens.
            self.avgpool = nn.AvgPool2d(window_size, window_size) if window_size != 1 else nn.Identity()

        # 3. Output projection and regularisation.
        self.proj = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def high_fre_attntion(self, x, qkv_layer, conv_layer, attn_block):
        """
        Local (high-frequency) attention for one branch.

        q, k and v come from a 1x1 conv followed by a depthwise conv. The
        context-aware weights ``tanh(attn_block(q * k) * scale)`` are applied
        element-wise to v.

        Returns:
            (B, G*D, H, W) tensor, where G is the branch's head count.
        """
        b, c, h, w = x.shape
        qkv = qkv_layer(x)     # (B, 3*G*D, H, W)
        qkv = conv_layer(qkv)  # depthwise conv mixes each channel with its local neighbourhood
        qkv = qkv.reshape(b, 3, -1, h, w).transpose(0, 1)  # (3, B, G*D, H, W)
        q, k, v = qkv

        attn = attn_block(q * k) * self.scalor
        attn = self.attn_drop(torch.tanh(attn))  # tanh bounds the weights to (-1, 1)
        return attn * v

    def low_fre_attention(self, x):
        """
        Global (low-frequency) attention.

        Full-resolution queries attend to average-pooled keys and values.

        Returns:
            (B, G*D, H, W) tensor, where G is the low-frequency head count.
        """
        b, c, h, w = x.shape
        q = self.global_q(x).reshape(
            b, -1, self.dim_head, h * w
        ).transpose(-1, -2)  # (B, G, H*W, D)

        kv = self.avgpool(x)  # (B, C, H', W')
        # Use the pooled size actually produced: AvgPool2d floors each side
        # separately, so H*W // window_size**2 is wrong when H or W is not a
        # multiple of window_size.
        n_kv = kv.shape[-2] * kv.shape[-1]
        kv = self.global_kv(kv).view(
            b, 2, -1, self.dim_head, n_kv
        ).permute(1, 0, 2, 4, 3)  # (2, B, G, H'*W', D)
        k, v = kv

        attn = self.scalor * q @ k.transpose(-1, -2)  # (B, G, H*W, H'*W')
        attn = self.attn_drop(attn.softmax(dim=-1))
        res = attn @ v  # (B, G, H*W, D)

        return res.transpose(2, 3).reshape(b, -1, h, w)

    def forward(self, x):
        """
        Args:
            x: feature map of shape (B, dim, H, W).
        Returns:
            Tensor of shape (B, dim, H, W).
        """
        # Iterate the compact per-branch module lists directly; they contain
        # only the branches whose head count is non-zero.
        outputs = [
            self.high_fre_attntion(x, qkv_layer, conv_layer, attn_block)
            for qkv_layer, conv_layer, attn_block in zip(self.qkvs, self.convs, self.act_blocks)
        ]
        if self.low_freq_enabled:
            outputs.append(self.low_fre_attention(x))

        fused = torch.cat(outputs, dim=1)
        return self.proj_drop(self.proj(fused))


# Smoke test for EfficientAttention
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 8 heads: 4 high-frequency (5x5 depthwise conv) + 4 low-frequency
    attn = EfficientAttention(
        dim=64,
        num_heads=8,
        group_split=[4, 4],
        kernel_sizes=[5],
        window_size=4
    ).to(device)

    # Input feature map: [batch=1, dim=64, height=64, width=64]
    x = torch.randn(1, 64, 64, 64, device=device)

    out = attn(x)

    print(f"输入形状: {x.shape}")
    print(f"输出形状: {out.shape}")  # expected: torch.Size([1, 64, 64, 64])

32、BiFormer模块

论文《BiFormer: Vision Transformer with Bi-Level Routing Attention》

1、作用

BiFormer旨在解决视觉Transformer在处理图像时的计算和内存效率问题。它通过引入双层路由注意力(Bi-Level Routing Attention, BRA),实现了动态的、基于内容的稀疏注意力机制,以更灵活、高效地分配计算资源。

2、机制

BiFormer的核心是双层路由注意力(BRA),该机制包含两个主要步骤:区域到区域的路由和令牌到令牌的注意力。首先,通过构建一个区域级别的关联图并对其进行修剪,来确定哪些区域是相关的,并应该被进一步考虑。其次,在这些选定的区域中,应用细粒度的令牌到令牌注意力,以便每个查询仅与少数最相关的键-值对进行交互。这种方法允许BiFormer动态地关注图像中与特定查询最相关的部分,而不是在所有空间位置上计算成对的令牌交互,从而显著减少了计算复杂度和内存占用。

3、独特优势

1、计算效率

BiFormer通过其双层路由注意力机制,实现了与传统全局注意力相比的显著计算和内存效率改进,具体体现在能够动态地仅对最相关的令牌子集进行计算。

2、动态稀疏性

与其他稀疏注意力方法不同,BiFormer能够根据内容动态选择关注的区域和令牌,使其能够更有效地处理各种视觉任务。

3、高性能

实验结果表明,BiFormer在图像分类、目标检测和语义分割等多个视觉任务上均取得了优异的性能;尤其是在模型大小和计算复杂度相当的条件下,其性能超越了现有的最先进方法。

4、代码

from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, LongTensor


class TopkRouting(nn.Module):
    """
    Region-level top-k router for Bi-Level Routing Attention.

    Scores every query region against every key region with a scaled dot
    product. For each query region it keeps the ``topk`` best key regions.

    Args:
        qk_dim: feature size of the region-level query/key vectors.
        topk: number of key regions kept per query region.
        qk_scale: logit scale; falls back to ``qk_dim ** -0.5`` when falsy.
        param_routing: if True, query and key first pass through a learnable Linear.
        diff_routing: if False, query and key are detached, so no gradient flows
            through the routing scores.
    """

    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        if param_routing:
            self.emb = nn.Linear(qk_dim, qk_dim)
        else:
            self.emb = nn.Identity()
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Returns:
            (r_weight, topk_index): the softmax of the selected scores and the
            indices of the selected key regions. Both have last dimension ``topk``.
        """
        if self.diff_routing:
            q_in, k_in = query, key
        else:
            q_in, k_in = query.detach(), key.detach()

        scores = (self.emb(q_in) * self.scale) @ self.emb(k_in).transpose(-2, -1)
        best_scores, best_idx = torch.topk(scores, k=self.topk, dim=-1)
        return self.routing_act(best_scores), best_idx


class KVGather(nn.Module):
    """
    Gathers the key/value tokens of the routed regions.

    Args:
        mul_weight: 'none' returns the gathered tokens unchanged, 'soft' scales
            them by the routing weights, and 'hard' raises NotImplementedError
            at forward time.
    """

    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard'], "mul_weight必须为'none'/'soft'/'hard'"
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        Args:
            r_idx: indices of the selected regions, viewable as (n, p2, topk).
            r_weight: routing weights, viewable as (n, p2, topk); used only in 'soft' mode.
            kv: per-region tokens of shape (n, p2, w2, c_kv).
        Returns:
            Tensor of shape (n, p2, topk, w2, c_kv).
        """
        n, p2, w2, c_kv = kv.size()
        k = r_idx.size(-1)

        # Broadcast kv over the query-region axis without copying memory,
        # then pick the routed regions along dim 2.
        source = kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1)
        index = r_idx.view(n, p2, k, 1, 1).expand(-1, -1, -1, w2, c_kv)
        gathered = torch.gather(source, dim=2, index=index)

        if self.mul_weight == 'hard':
            raise NotImplementedError("硬路由的可微分实现待完善")
        if self.mul_weight == 'soft':
            gathered = r_weight.view(n, p2, k, 1, 1) * gathered
        return gathered


class QKVLinear(nn.Module):
    """
    One linear layer that produces Q and a packed K/V tensor.

    Layout of the projection's last dimension: Q (qk_dim) | K (qk_dim) | V (dim).
    ``forward`` returns ``(q, kv)``, with K and V still packed together in ``kv``.
    """

    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, 2 * qk_dim + dim, bias=bias)

    def forward(self, x):
        projected = self.qkv(x)
        q = projected[..., :self.qk_dim]
        kv = projected[..., self.qk_dim:]
        return q, kv


class BiLevelRoutingAttention(nn.Module):
    """
    Bi-Level Routing Attention (BRA).

    Paper: "BiFormer: Vision Transformer with Bi-Level Routing Attention".
    The feature map is split into ``n_win x n_win`` regions. Each query region is
    routed to its ``topk`` most related key regions (coarse, region level). Then
    multi-head attention runs between the query region's tokens and the tokens
    gathered from those regions (fine, token level). A depthwise conv on V (LEPE)
    is added to the attention output before the output projection.

    Input layout is NHWC ``(N, H, W, dim)``; the output is NCHW ``(N, dim, H, W)``.
    Setting ``diff_routing=True`` without ``soft_routing`` selects KVGather's
    'hard' mode, which raises NotImplementedError in forward.
    """

    def __init__(self,
                 dim,
                 n_win=7,                       # regions per side
                 num_heads=8,
                 qk_dim=None,
                 qk_scale=None,
                 kv_per_win=4,                  # output size for the adaptive pooling modes
                 kv_downsample_ratio=4,         # kernel size for the fixed pooling modes
                 kv_downsample_mode='identity',
                 topk=4,                        # key regions kept per query region
                 param_attention="qkvo",
                 param_routing=False,
                 diff_routing=False,
                 soft_routing=False,
                 side_dwconv=3,                 # LEPE kernel size (0 disables LEPE)
                 auto_pad=True):                # zero-pad H/W up to a multiple of n_win
        super().__init__()
        self.dim = dim
        self.n_win = n_win
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, \
            "qk_dim和dim必须能被num_heads整除!"

        self.scale = qk_scale or self.qk_dim ** -0.5
        self.side_dwconv = side_dwconv
        self.auto_pad = auto_pad

        # Local positional enhancement (LEPE): depthwise conv on V, or zeros when disabled.
        self.lepe = nn.Conv2d(
            dim, dim,
            kernel_size=side_dwconv,
            stride=1,
            padding=side_dwconv // 2,
            groups=dim
        ) if side_dwconv > 0 else lambda x: torch.zeros_like(x)

        # Routing configuration
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        assert not (self.param_routing and not self.diff_routing), \
            "参数化路由必须启用可微分模式"

        self.router = TopkRouting(
            qk_dim=self.qk_dim,
            qk_scale=self.scale,
            topk=topk,
            diff_routing=diff_routing,
            param_routing=param_routing
        )

        if soft_routing:
            mul_weight = 'soft'
        elif diff_routing:
            mul_weight = 'hard'
        else:
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # Q/K/V projection and optional output projection
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f"不支持的参数化模式: {param_attention}")

        # Per-region K/V downsampling
        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio

        if kv_downsample_mode == 'ada_avgpool':
            assert kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(kv_per_win)
        elif kv_downsample_mode == 'ada_maxpool':
            assert kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(kv_per_win)
        elif kv_downsample_mode == 'maxpool':
            assert kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(kv_downsample_ratio) if kv_downsample_ratio > 1 else nn.Identity()
        elif kv_downsample_mode == 'avgpool':
            assert kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(kv_downsample_ratio) if kv_downsample_ratio > 1 else nn.Identity()
        elif kv_downsample_mode == 'identity':
            self.kv_down = nn.Identity()
        else:
            raise ValueError(f"不支持的下采样模式: {kv_downsample_mode}")

    def forward(self, x: Tensor, ret_attn_mask=False):
        """
        Args:
            x: input feature map in NHWC layout, (N, H, W, dim).
            ret_attn_mask: if True, also return the routing weights, the routing
                indices and the token-level attention weights.
        Returns:
            (N, dim, H, W) tensor, or ``(out, r_weight, r_idx, attn)`` when
            ``ret_attn_mask`` is True.
        """
        N, H_in, W_in, C = x.size()
        pad_r = pad_b = 0
        if self.auto_pad:
            # Pad right/bottom so that H and W become multiples of n_win.
            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))  # (C, W, H) pads
        else:
            assert H_in % self.n_win == 0 and W_in % self.n_win == 0, \
                "输入尺寸必须能被n_win整除(当auto_pad=False时)"
        _, H, W, _ = x.size()
        win_h, win_w = H // self.n_win, W // self.n_win

        # 1. Region partition: (n, p2, win_h, win_w, c) with p2 = n_win ** 2
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)
        q, kv = self.qkv(x)  # q: (..., qk_dim); kv: (..., qk_dim + dim)

        # 2. Token-level q and (optionally downsampled) per-region kv
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        # 3. Region-level routing on region-averaged q and k
        q_win = q.mean([2, 3])
        k_win = kv[..., 0:self.qk_dim].mean([2, 3])
        r_weight, r_idx = self.router(q_win, k_win)  # both (n, p2, topk)

        # 4. Gather the kv tokens of the routed regions: (n, p2, topk, kv_tokens, c)
        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)

        # 5. Token-level multi-head attention (m = num_heads); every query
        #    region attends to all topk * kv_tokens gathered tokens.
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)
        # K is laid out already transposed for q @ k: (n*p2, m, c, topk*kv_tokens)
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)

        attn = (q_pix * self.scale) @ k_pix_sel
        attn = attn.softmax(dim=-1)
        out = attn @ v_pix_sel  # (n*p2, m, win_h*win_w, dim/m)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)',
                        j=self.n_win, i=self.n_win, h=win_h, w=win_w)

        # 6. LEPE on V, then output projection
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)',
                                   j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)
        out = self.wo(out + lepe)

        # Crop the padding added above
        if pad_r > 0 or pad_b > 0:
            out = out[:, :H_in, :W_in, :].contiguous()

        out = rearrange(out, "n h w c -> n c h w")
        if ret_attn_mask:
            return out, r_weight, r_idx, attn
        return out


# Smoke test for BiLevelRoutingAttention
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # dim=256, 7x7 regions, 8 heads, 4 routed regions per query region
    attn = BiLevelRoutingAttention(
        dim=256,
        n_win=7,
        num_heads=8,
        topk=4
    ).to(device)

    # Input is NHWC: [batch=1, height=64, width=64, channels=256]
    input_tensor = torch.randn(1, 64, 64, 256, device=device)

    output_tensor = attn(input_tensor)

    # The output is NCHW with the original 64x64 spatial size.
    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output_tensor.shape}")  # expected: torch.Size([1, 256, 64, 64])

33、STViT模块

论文《Vision Transformer with Super Token Sampling》

1、作用

STViT旨在解决视觉Transformer在浅层捕获局部特征时存在的高度冗余问题:它希望在网络的早期阶段就实现高效且有效的全局上下文建模,从而减少不必要的计算成本。

2、机制

STViT引入了一种类似于图像处理中“超像素”的概念,称为“超级令牌”(super tokens),以减少自注意力计算中元素的数量,同时保留对全局关系建模的能力。该过程涉及从视觉令牌中采样超级令牌,对这些超级令牌执行自注意力操作,并将它们映射回原始令牌空间。

3、独特优势

STViT在不同的视觉任务中展示了强大的性能,包括图像分类、目标检测和分割,同时拥有更少的参数和较低的计算成本。例如,STViT在没有额外训练数据的情况下,在ImageNet-1K分类任务上达到了86.4%的Top-1准确率,且参数少于100M。

4、代码

import torch
import torch.nn as nn
import torch.nn.functional as F


class Unfold(nn.Module):
    """
    Sliding-window unfold implemented as a fixed one-hot convolution.

    Every channel is convolved on its own with ``kernel_size**2`` one-hot kernels
    (stride 1, zero padding ``kernel_size // 2``). Output channel j therefore holds
    the j-th neighbour (row-major inside the window) of each position.

    Input (B, C, H, W) -> output (B, C * kernel_size**2, H * W).
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        taps = kernel_size ** 2
        one_hot = torch.eye(taps).reshape(taps, 1, kernel_size, kernel_size)
        self.register_buffer('weights', one_hot)  # fixed buffer, not a trainable parameter

    def forward(self, x):
        batch, chans, height, width = x.shape
        k = self.kernel_size
        per_channel = x.reshape(batch * chans, 1, height, width)
        patches = F.conv2d(per_channel, self.weights, stride=1, padding=k // 2)
        return patches.reshape(batch, chans * k ** 2, height * width)


class Fold(nn.Module):
    """
    Adjoint of ``Unfold``.

    A stride-1 transposed convolution with the same fixed one-hot kernels
    scatters each of the ``kernel_size**2`` input channels back to its
    neighbourhood offset and sums overlapping contributions.

    Input (B, kernel_size**2, H, W) -> output (B, 1, H', W'); for odd
    ``kernel_size``, H' == H and W' == W.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        taps = kernel_size ** 2
        one_hot = torch.eye(taps).reshape(taps, 1, kernel_size, kernel_size)
        self.register_buffer('weights', one_hot)  # fixed buffer, not a trainable parameter

    def forward(self, x):
        _b, _c, _h, _w = x.shape  # unpacking enforces a 4-D input
        pad = self.kernel_size // 2
        return F.conv_transpose2d(x, self.weights, stride=1, padding=pad)


class Attention(nn.Module):
    """
    Global multi-head self-attention over all H*W positions of a (B, C, H, W) map.

    A single 1x1 conv produces q, k and v. The 3*C channels are first split into
    ``num_heads`` groups, and each group is then chunked into its q, k and v
    parts. Scores are ``k^T @ q`` normalised over the key axis (dim -2), and the
    result goes through a 1x1 conv projection.

    Args:
        dim: channel count C.
        window_size: stored on the module but not used by it.
        num_heads: number of heads.
        qkv_bias: bias on the qkv conv.
        qk_scale: score scale; falls back to ``head_dim ** -0.5`` when falsy.
        attn_drop, proj_drop: dropout rates.
    """

    def __init__(self, dim, window_size=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.window_size = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, C, H, W = x.shape
        heads = self.num_heads
        tokens = H * W

        packed = self.qkv(x).reshape(B, heads, 3 * (C // heads), tokens)
        q, k, v = packed.chunk(3, dim=2)  # each (B, heads, C // heads, H*W)

        # weights[b, h, j, i]: score of key j for query i, normalised over keys
        weights = ((k.transpose(-1, -2) @ q) * self.scale).softmax(dim=-2)
        weights = self.attn_drop(weights)

        mixed = (v @ weights).reshape(B, C, H, W)
        return self.proj_drop(self.proj(mixed))


class StokenAttention(nn.Module):
    """
    Super Token Attention (STA).

    Paper: "Vision Transformer with Super Token Sampling".
    The (B, C, H, W) map is divided into grid cells of ``stoken_size``, and each
    cell starts with one super token (its average). Pixels are softly associated
    with the super tokens of the surrounding 3x3 cells. Super tokens are then
    re-estimated as association-weighted averages of the pixels, refined with
    global self-attention, and mapped back to pixels through the same association.

    Args:
        dim: channel count C of the input.
        stoken_size: (h, w) size of one grid cell, or an int for a square cell.
            With (1, 1) the module applies ``Attention`` directly.
        n_iter: number of association iterations; must be >= 1 when sampling.
        num_heads, qkv_bias, qk_scale, attn_drop, proj_drop: forwarded to the
            super-token refinement ``Attention``.
    """

    def __init__(self,
                 dim,
                 stoken_size,
                 n_iter=1,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        if isinstance(stoken_size, int):
            stoken_size = (stoken_size, stoken_size)
        self.n_iter = n_iter
        self.stoken_size = stoken_size
        self.scale = dim ** -0.5

        self.unfold = Unfold(3)
        self.fold = Fold(3)
        self.stoken_refine = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop
        )

    def _pixels_to_stokens(self, pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww):
        """Association-weighted average of pixels into every super token -> (B, C, hh, ww)."""
        B, _, _, C = pixel_features.shape
        stokens = pixel_features.transpose(-1, -2) @ affinity_matrix  # (B, hh*ww, C, 9)
        stokens = self.fold(stokens.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
        return stokens / (affinity_matrix_sum + 1e-12)

    def stoken_forward(self, x):
        """Super-token sampling, refinement and upsampling for a (B, C, H, W) input."""
        if self.n_iter < 1:
            raise ValueError("n_iter must be >= 1 when super-token sampling is used")
        B, C, H0, W0 = x.shape
        h, w = self.stoken_size

        # Zero-pad right/bottom to a multiple of the cell size.
        pad_l = pad_t = 0
        pad_r = (w - W0 % w) % w
        pad_b = (h - H0 % h) % h
        if pad_r > 0 or pad_b > 0:
            x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
        _, _, H, W = x.shape
        hh, ww = H // h, W // w

        # 1. Initial super tokens: per-cell average.
        stoken_features = F.adaptive_avg_pool2d(x, (hh, ww))  # (B, C, hh, ww)

        # 2. Pixels grouped by cell: (B, hh*ww, h*w, C)
        pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)

        # 3. Association iterations (no gradient through the affinities).
        with torch.no_grad():
            for idx in range(self.n_iter):
                stoken_unfolded = self.unfold(stoken_features)  # (B, C*9, hh*ww)
                stoken_unfolded = stoken_unfolded.transpose(1, 2).reshape(B, hh * ww, C, 9)

                # Pixel-to-neighbouring-super-token affinities: (B, hh*ww, h*w, 9)
                affinity_matrix = (pixel_features @ stoken_unfolded * self.scale).softmax(-1)

                # Total affinity received by each super token (normaliser).
                affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
                affinity_matrix_sum = self.fold(affinity_matrix_sum)

                if idx < self.n_iter - 1:
                    stoken_features = self._pixels_to_stokens(
                        pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww)

        # 4. Final super-token update from the last association. This runs
        #    outside no_grad so that gradients flow back to the pixels.
        stoken_features = self._pixels_to_stokens(
            pixel_features, affinity_matrix, affinity_matrix_sum, hh, ww)

        # 5. Global self-attention among super tokens.
        stoken_features = self.stoken_refine(stoken_features)

        # 6. Map refined super tokens back to pixels through the same association.
        stoken_unfolded = self.unfold(stoken_features).transpose(1, 2).reshape(B, hh * ww, C, 9)
        pixel_features = stoken_unfolded @ affinity_matrix.transpose(-1, -2)  # (B, hh*ww, C, h*w)
        pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

        if pad_r > 0 or pad_b > 0:
            pixel_features = pixel_features[:, :, :H0, :W0]

        return pixel_features

    def direct_forward(self, x):
        """Plain global attention, used when ``stoken_size`` is (1, 1)."""
        return self.stoken_refine(x)

    def forward(self, x):
        """Dispatch to super-token attention or plain attention based on ``stoken_size``."""
        if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
            return self.stoken_forward(x)
        return self.direct_forward(x)


# Smoke test for StokenAttention
if __name__ == '__main__':
    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 64 channels, 8x8 grid cells per super token
    stoken_attn = StokenAttention(
        dim=64,
        stoken_size=[8, 8],
        n_iter=1,
        num_heads=8
    ).to(device)

    # Input feature map: [batch=3, channels=64, height=64, width=64]
    input_tensor = torch.randn(3, 64, 64, 64, device=device)

    output_tensor = stoken_attn(input_tensor)

    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output_tensor.shape}")  # expected: torch.Size([3, 64, 64, 64])

34、iRMB模块

论文《Rethinking Mobile Block for Efficient Attention-based Models》

1、作用

本文提出了一种有效的轻量级模型设计方法,旨在开发现代高效的轻量级模型,用于密集预测任务,同时平衡参数、FLOPs和性能。作者通过重新思考高效的Inverted Residual Block(IRB)和Transformer的有效组件,从统一的视角出发,扩展了基于CNN的IRB到基于Meta attention的模型,并抽象出了一种一次残差的Meta Mobile Block(MMB),用于轻量级模型设计。

2、机制

本研究通过简单但有效的设计准则,提出了一种现代的Inverted Residual Mobile Block(iRMB),并仅使用iRMB构建了一个类似于ResNet的高效模型(EMO),用于下游任务。EMO通过将CNN的效率和Transformer的动态建模能力结合在iRMB中,有效地提高了模型性能。同时,EMO在不引入复杂结构的情况下,取得了与当前最先进的轻量级注意力模型相当的性能。

3、独特优势

EMO在ImageNet-1K、COCO2017和ADE20K基准上的广泛实验展示了其优越性,例如,EMO-1M/2M/5M分别达到了71.5%、75.1%和78.4%的Top-1准确率,超过了同等级别的CNN-/Attention-based模型。同时,在参数效率和准确性之间取得了良好的平衡:在iPhone14上运行速度比EdgeNeXt快2.8-4.0倍。此外,EMO不使用复杂的操作,但仍然在多个视觉任务中获得了非常竞争性的结果,这证明了其作为轻量级注意力模型的有效性和实用性。

4、代码

import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.efficientnet_blocks import SqueezeExcite as SE
from einops import rearrange, reduce
from timm.models.layers.activations import *
from timm.models.layers import DropPath

# Module-level default for the ``inplace`` argument of activation layers.
# iRMB's constructor has no ``inplace`` parameter of its own and passes this
# global to the ConvNormAct layers it builds.
inplace = True

# 2-D layer normalisation for NCHW convolution feature maps.
class LayerNorm2d(nn.Module):
    """
    LayerNorm over the channel axis of an NCHW tensor.

    Channels are moved last, ``nn.LayerNorm(normalized_shape)`` is applied, and
    channels are moved back. The intermediate and final tensors are made contiguous.
    """

    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
        super().__init__()
        self.norm = nn.LayerNorm(normalized_shape, eps, elementwise_affine)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1).contiguous()  # NCHW -> NHWC
        normed = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW

# Look up a normalisation-layer factory by short name.
def get_norm(norm_layer='in_1d'):
    """
    Return a normalisation-layer factory for ``norm_layer``.

    Every entry except 'none' is the layer class with ``eps=1e-6`` bound via
    ``functools.partial``; 'none' returns ``nn.Identity`` itself.

    Args:
        norm_layer: one of 'none', 'in_1d', 'in_2d', 'in_3d', 'bn_1d', 'bn_2d',
            'bn_3d', 'gn', 'ln_1d', 'ln_2d'.
    Raises:
        KeyError: for an unknown name.
    """
    eps = 1e-6
    base_classes = {
        'in_1d': nn.InstanceNorm1d,
        'in_2d': nn.InstanceNorm2d,
        'in_3d': nn.InstanceNorm3d,
        'bn_1d': nn.BatchNorm1d,
        'bn_2d': nn.BatchNorm2d,
        'bn_3d': nn.BatchNorm3d,
        'gn': nn.GroupNorm,
        'ln_1d': nn.LayerNorm,
        'ln_2d': LayerNorm2d,
    }
    if norm_layer == 'none':
        return nn.Identity
    return partial(base_classes[norm_layer], eps=eps)

# Look up an activation-layer class by short name.
def get_act(act_layer='relu'):
    """
    Return the activation-layer class registered under ``act_layer``.

    The class itself is returned, and callers instantiate it. The non-``nn.*``
    classes are presumably the ones pulled in by the
    ``timm.models.layers.activations`` star import — TODO confirm.

    Args:
        act_layer: one of 'none', 'sigmoid', 'swish', 'mish', 'hsigmoid',
            'hswish', 'hmish', 'tanh', 'relu', 'relu6', 'prelu', 'gelu', 'silu'.
    Raises:
        KeyError: for an unknown name.
    """
    choices = dict(
        none=nn.Identity,
        sigmoid=Sigmoid,
        swish=Swish,
        mish=Mish,
        hsigmoid=HardSigmoid,
        hswish=HardSwish,
        hmish=HardMish,
        tanh=Tanh,
        relu=nn.ReLU,
        relu6=nn.ReLU6,
        prelu=PReLU,
        gelu=GELU,
        silu=nn.SiLU,
    )
    return choices[act_layer]

# Learnable per-channel feature scaling.
class LayerScale(nn.Module):
    """
    Scale the last axis of a channels-last tensor by a learnable ``gamma``.

    ``gamma`` has shape (1, 1, dim) and starts at ``init_values``.

    Args:
        dim: size of the scaled (last) axis.
        init_values: initial value of every ``gamma`` entry.
        inplace: if True, ``x`` is multiplied in place and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, 1, dim))

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)

class LayerScale2D(nn.Module):
    """
    Scale the channel axis of an NCHW tensor by a learnable ``gamma``.

    ``gamma`` has shape (1, dim, 1, 1) and starts at ``init_values``.

    Args:
        dim: number of channels.
        init_values: initial value of every ``gamma`` entry.
        inplace: if True, ``x`` is multiplied in place and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=True):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(1, dim, 1, 1))

    def forward(self, x):
        if not self.inplace:
            return x * self.gamma
        return x.mul_(self.gamma)

# Conv2d -> normalisation -> activation, with an optional residual connection.
class ConvNormAct(nn.Module):
    """
    Conv2d followed by a normalisation layer and an activation.

    Args:
        dim_in, dim_out: input / output channels.
        kernel_size, stride, dilation, groups, bias: forwarded to ``nn.Conv2d``.
        skip: add the input back (through DropPath) when ``dim_in == dim_out``.
        norm_layer: name understood by ``get_norm``.
        act_layer: name understood by ``get_act``; instantiated with ``inplace``.
        inplace: forwarded to the activation constructor.
        drop_path_rate: DropPath rate on the residual branch (0 disables it).
    """

    def __init__(self, dim_in, dim_out, kernel_size, stride=1, dilation=1, groups=1, bias=False,
                 skip=False, norm_layer='bn_2d', act_layer='relu', inplace=True, drop_path_rate=0.):
        super().__init__()
        self.has_skip = skip and dim_in == dim_out
        # Padding for the *effective* (dilated) kernel extent, (k - 1) * d + 1.
        # This is the same formula MSPatchEmb uses. For dilation == 1 it reduces
        # to ceil((kernel_size - stride) / 2); previously the dilation was
        # ignored, so dilated convs shrank the feature map.
        padding = math.ceil(((kernel_size - 1) * dilation + 1 - stride) / 2)

        self.conv = nn.Conv2d(dim_in, dim_out, kernel_size, stride, padding, dilation, groups, bias)
        self.norm = get_norm(norm_layer)(dim_out)
        self.act = get_act(act_layer)(inplace=inplace)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()

    def forward(self, x):
        shortcut = x
        x = self.act(self.norm(self.conv(x)))
        if self.has_skip:
            x = self.drop_path(x) + shortcut
        return x

# Multi-scale patch embedding.
class MSPatchEmb(nn.Module):
    """
    Multi-scale patch embedding: parallel dilated convs whose outputs are averaged.

    Each branch is Conv2d (with one entry of ``dilations``) -> norm -> activation.
    With more than one branch, the outputs are averaged element-wise, so every
    branch must produce the same spatial size for the chosen
    ``kernel_size``/``stride``.

    Args:
        dim_in: input channels.
        emb_dim: output channels.
        kernel_size: conv kernel size.
        c_group: conv groups; -1 selects ``gcd(dim_in, emb_dim)``.
        stride: conv stride.
        dilations: one dilation rate per branch.
        norm_layer: name understood by ``get_norm``.
        act_layer: name understood by ``get_act``.
    Raises:
        ValueError: if the resolved ``c_group`` does not divide both
            ``dim_in`` and ``emb_dim``.
    """

    def __init__(self, dim_in, emb_dim, kernel_size=2, c_group=-1, stride=1, dilations=(1, 2, 3),
                 norm_layer='bn_2d', act_layer='silu'):
        super().__init__()
        self.dilation_num = len(dilations)

        # Resolve the group count before validating it. The previous check ran
        # on the raw -1 (always divisible) and never looked at emb_dim.
        c_group = math.gcd(dim_in, emb_dim) if c_group == -1 else c_group
        if dim_in % c_group != 0 or emb_dim % c_group != 0:
            raise ValueError(
                f"c_group={c_group} must divide both dim_in={dim_in} and emb_dim={emb_dim}")

        self.convs = nn.ModuleList()
        for dilation in dilations:
            padding = math.ceil(((kernel_size - 1) * dilation + 1 - stride) / 2)
            self.convs.append(nn.Sequential(
                nn.Conv2d(dim_in, emb_dim, kernel_size, stride, padding, dilation, groups=c_group),
                get_norm(norm_layer)(emb_dim),
                # NOTE(review): emb_dim is passed positionally. For activations whose
                # first parameter is `inplace` (e.g. nn.SiLU) it acts as a truthy
                # inplace flag — confirm this is intended.
                get_act(act_layer)(emb_dim)
            ))

    def forward(self, x):
        if self.dilation_num == 1:
            return self.convs[0](x)
        # Stack the branch outputs on a new last axis and average them.
        branches = torch.stack([conv(x) for conv in self.convs], dim=-1)
        return branches.mean(dim=-1).contiguous()

# 改进的残差模块(iRMB)
class iRMB(nn.Module):
    """
    iRMB residual block: spatial self-attention + depthwise convolution.

    Pipeline: input norm -> (multi-head spatial self-attention with value
    projection, or value projection only) -> depthwise local conv + optional
    Squeeze-and-Excitation (residual when ``has_skip``) -> dropout -> 1x1
    projection -> outer residual with stochastic depth when ``has_skip``.

    Normalization / activation layers are resolved by name through
    ``get_norm`` / ``get_act`` (defined elsewhere in this file).
    """

    def __init__(self, dim_in, dim_out, norm_in=True, has_skip=True, exp_ratio=1.0,
                 norm_layer='bn_2d', act_layer='relu', v_proj=True, dw_ks=3, stride=1,
                 dilation=1, se_ratio=0.0, dim_head=64, window_size=7, attn_s=True,
                 qkv_bias=False, attn_drop=0., drop=0., drop_path=0., v_group=False, attn_pre=False):
        """
        Args:
            dim_in, dim_out: input / output channels.
            norm_in: normalize the input with ``get_norm(norm_layer)`` first.
            has_skip: request residual connections; they are only used when
                ``dim_in == dim_out`` and ``stride == 1``.
            exp_ratio: hidden width is ``int(dim_in * exp_ratio)``.
            norm_layer, act_layer: layer names resolved by get_norm / get_act.
            v_proj: when ``attn_s`` is False, whether to apply the 1x1 value
                projection. If False the value path is an identity, which
                needs ``int(dim_in * exp_ratio) == dim_in`` for the channel
                counts of the following layers to match.
            dw_ks, stride, dilation: depthwise local convolution settings.
            se_ratio: SE reduction ratio; SE is disabled when it is 0.
            dim_head: channels per attention head (must divide ``dim_in``).
            window_size: attention group side length; ``<= 0`` attends over
                the whole feature map.
            attn_s: enable spatial self-attention.
            qkv_bias: bias on the q/k and v 1x1 projections.
            attn_drop: dropout on the attention weights.
            drop: dropout applied before the output projection.
            drop_path: stochastic-depth rate on the outer residual branch.
            v_group: use a per-head grouped value projection.
            attn_pre: apply attention to the raw input and project afterwards
                (instead of attending over the projected values).
        """
        super().__init__()

        # NOTE(review): ``inplace`` below is not a parameter; it is looked up in
        # module scope (presumably a module-level flag defined elsewhere in this
        # file) — confirm it exists before constructing this block.

        self.norm = get_norm(norm_layer)(dim_in) if norm_in else nn.Identity()
        dim_mid = int(dim_in * exp_ratio)
        self.has_skip = (dim_in == dim_out and stride == 1) and has_skip

        self.attn_s = attn_s
        if self.attn_s:
            assert dim_in % dim_head == 0, 'dim_in should be divisible by dim_head'
            self.dim_head = dim_head
            self.window_size = window_size
            self.num_head = dim_in // dim_head
            self.scale = self.dim_head ** -0.5
            self.attn_pre = attn_pre

            # Joint query/key 1x1 projection (2 * dim_in output channels).
            self.qk = ConvNormAct(dim_in, int(dim_in * 2), kernel_size=1, bias=qkv_bias,
                                  norm_layer='none', act_layer='none')
            # Value projection, optionally grouped per head.
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1,
                                 groups=self.num_head if v_group else 1,
                                 bias=qkv_bias, norm_layer='none',
                                 act_layer=act_layer, inplace=inplace)
            self.attn_drop = nn.Dropout(attn_drop)
        elif v_proj:
            # No attention: the value projection alone forms the first stage.
            self.v = ConvNormAct(dim_in, dim_mid, kernel_size=1, bias=qkv_bias,
                                 norm_layer='none', act_layer=act_layer, inplace=inplace)
        else:
            # v_proj=False was previously ignored; honour it with an identity.
            self.v = nn.Identity()

        # Depthwise (groups == channels) local convolution.
        self.conv_local = ConvNormAct(dim_mid, dim_mid, kernel_size=dw_ks, stride=stride,
                                      dilation=dilation, groups=dim_mid, norm_layer='bn_2d',
                                      act_layer='silu', inplace=inplace)

        self.se = SE(dim_mid, rd_ratio=se_ratio, act_layer=get_act(act_layer)) if se_ratio > 0.0 else nn.Identity()

        self.proj_drop = nn.Dropout(drop)
        self.proj = ConvNormAct(dim_mid, dim_out, kernel_size=1, norm_layer='none', act_layer='none', inplace=inplace)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def _spatial_attention(self, x):
        """Grouped multi-head self-attention over a (B, C, H, W) map, plus the value projection.

        The map is zero-padded on the right/bottom to a multiple of the group
        size, then split into ``n1 * n2`` groups. Each group holds
        ``window_size_H x window_size_W`` pixels taken with stride ``(n1, n2)``,
        per the ``(h1 n1) (w1 n2)`` rearrange. Attention is computed inside
        each group and the padding is cropped off again. Padded zero pixels are
        not masked out of the attention.
        """
        _, _, H, W = x.shape
        if self.window_size <= 0:
            window_size_W, window_size_H = W, H
        else:
            window_size_W, window_size_H = self.window_size, self.window_size

        pad_l = pad_t = 0
        pad_r = (window_size_W - W % window_size_W) % window_size_W
        pad_b = (window_size_H - H % window_size_H) % window_size_H
        x = F.pad(x, (pad_l, pad_r, pad_t, pad_b, 0, 0,))
        n1, n2 = (H + pad_b) // window_size_H, (W + pad_r) // window_size_W

        # Fold the groups into the batch dimension.
        x = rearrange(x, 'b c (h1 n1) (w1 n2) -> (b n1 n2) c h1 w1', n1=n1, n2=n2).contiguous()
        h, w = x.shape[-2:]

        qk = self.qk(x)
        qk = rearrange(qk, 'b (qk heads dim_head) h w -> qk b heads (h w) dim_head',
                       qk=2, heads=self.num_head, dim_head=self.dim_head).contiguous()
        q, k = qk[0], qk[1]

        attn_spa = (q @ k.transpose(-2, -1)) * self.scale
        attn_spa = attn_spa.softmax(dim=-1)
        attn_spa = self.attn_drop(attn_spa)

        if self.attn_pre:
            # Attend over the raw features, then project to dim_mid.
            x = rearrange(x, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
            x_spa = attn_spa @ x
            x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w',
                              heads=self.num_head, h=h, w=w).contiguous()
            x_spa = self.v(x_spa)
        else:
            # Project to values first, then attend over them.
            v = self.v(x)
            v = rearrange(v, 'b (heads dim_head) h w -> b heads (h w) dim_head', heads=self.num_head).contiguous()
            x_spa = attn_spa @ v
            x_spa = rearrange(x_spa, 'b heads (h w) dim_head -> b (heads dim_head) h w',
                              heads=self.num_head, h=h, w=w).contiguous()

        # Unfold the groups back into the spatial map and drop the padding.
        x = rearrange(x_spa, '(b n1 n2) c h1 w1 -> b c (h1 n1) (w1 n2)', n1=n1, n2=n2).contiguous()
        if pad_r > 0 or pad_b > 0:
            x = x[:, :, :H, :W].contiguous()
        return x

    def forward(self, x):
        """Apply the block to a (B, C, H, W) tensor."""
        shortcut = x
        x = self.norm(x)

        x = self._spatial_attention(x) if self.attn_s else self.v(x)

        # Depthwise local mixing + SE, with an inner residual when allowed.
        local = self.se(self.conv_local(x))
        x = x + local if self.has_skip else local

        x = self.proj_drop(x)
        x = self.proj(x)
        return (shortcut + self.drop_path(x)) if self.has_skip else x

# 测试代码
if __name__ == '__main__':
    # Smoke test: use the GPU when available, otherwise fall back to CPU so the
    # demo does not crash on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(3, 64, 64, 64, device=device)  # random (B, C, H, W) input
    model = iRMB(64, 64).to(device)  # instantiate the block
    output = model(x)
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")  # print the output shape

35、AFT模块

论文《An Attention Free Transformer》

1、作用

**无注意力Transformer(Attention Free Transformer, AFT)**旨在去除传统Transformer中的点积自注意力机制,提供一种更高效的Transformer模型。它特别适用于对计算效率和内存消耗要求较高的应用场景,如移动设备和边缘计算。

2、机制

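AFT不再计算查询与键之间的点积注意力矩阵。它把键(K)与可学习的成对位置偏置 $w$ 相加后取指数,作为对值(V)加权求和的权重,再用查询(Q)经sigmoid得到的门控对结果做逐元素调制:

$$Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}$$

因此,每个输出元素都是输入元素(值向量)的加权平均,权重同时由键的内容和元素之间的位置关系决定。这种方法避免了点积注意力矩阵的计算,显著降低了模型的复杂性和运行时内存需求。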

3、独特优势

1、高效性:AFT由于避免了昂贵的自注意力计算,因此在执行速度和计算效率上有明显优势。

2、简化模型结构:通过消除自注意力机制,AFT简化了模型结构,使得模型更加轻量化,易于实现和部署。

3、适应性强:AFT的结构使其更容易适应于不同的任务和数据集,具有良好的泛化能力。

4、资源占用低:对于资源受限的环境,如移动设备和边缘计算设备,AFT提供了一种实用的解决方案,能够在保持较高性能的同时,降低资源消耗。

4、代码

import torch
import torch.nn as nn
import torch.nn.init as init


class AFT_FULL(nn.Module):
    """
    AFT-full layer (Attention Free Transformer, full variant).

    Paper: "An Attention Free Transformer".
    Instead of dot-product attention, each output position t is a gated,
    normalised weighted sum of the values. The weights combine the keys with
    a learned pairwise position bias ``w``:

        Y_t = sigmoid(Q_t) * sum_t' exp(K_t' + w[t, t']) * V_t'
                             / sum_t' exp(K_t' + w[t, t'])

    Because exp(K_t' + w[t, t']) = exp(w[t, t']) * exp(K_t'), both sums are
    computed as an (n x n) @ (n x d) matrix product. This avoids
    materialising a (batch, n, n, d) tensor.
    """

    def __init__(self, d_model, n=49, simple=False):
        """
        Args:
            d_model: feature dimension (input and output channels).
            n: maximum sequence length; the position-bias matrix is n x n and
                inputs may use any length up to n.
            simple: if True, the position bias is a fixed all-zero buffer
                (no positional information); otherwise it is learnable.
        """
        super().__init__()
        # Q/K/V linear projections (no bias).
        self.fc_q = nn.Linear(d_model, d_model, bias=False)
        self.fc_k = nn.Linear(d_model, d_model, bias=False)
        self.fc_v = nn.Linear(d_model, d_model, bias=False)

        # Pairwise position bias w[t, t'] (row t = output position).
        if simple:
            self.register_buffer('position_biases', torch.zeros((n, n)))
        else:
            self.position_biases = nn.Parameter(torch.ones((n, n)))

        self.d_model = d_model
        self.n = n
        self.sigmoid = nn.Sigmoid()  # gate applied to the queries

        self.init_weights()

    def init_weights(self):
        """Initialise the projections and, if learnable, the position bias."""
        for m in (self.fc_q, self.fc_k, self.fc_v):
            # Normal init with std scaled by the feature dimension.
            init.normal_(m.weight, std=self.d_model ** -0.5)
        # Only the learnable (Parameter) bias requires grad; the simple-mode
        # buffer stays zero. A uniform initial bias cancels out in the
        # normalised weighting, so the layer starts position-agnostic.
        if self.position_biases.requires_grad:
            init.constant_(self.position_biases, 0.1)

    def forward(self, input):
        """
        Args:
            input: tensor of shape [batch_size, seq_len, d_model] with
                seq_len <= self.n.
        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        Raises:
            ValueError: if seq_len exceeds the configured maximum ``n``.
        """
        bs, n, dim = input.shape
        if n > self.n:
            raise ValueError(f"sequence length {n} exceeds the configured maximum n={self.n}")

        q = self.fc_q(input)  # [bs, n, dim]
        k = self.fc_k(input)  # [bs, n, dim]
        v = self.fc_v(input)  # [bs, n, dim]

        # Position biases for the first n positions (supports shorter inputs).
        w = self.position_biases[:n, :n]  # [n, n]

        # Subtract the per-row (w) and per-sequence/per-channel (k) maxima
        # before exponentiating so exp() cannot overflow. For a fixed output
        # position t and channel d the factors exp(-max) do not depend on t',
        # so they cancel exactly in numerator / denominator.
        exp_w = torch.exp(w - w.max(dim=-1, keepdim=True).values)  # [n, n]
        exp_k = torch.exp(k - k.max(dim=1, keepdim=True).values)  # [bs, n, dim]

        # sum_t' exp(w[t, t']) * exp(K_t') * V_t'  and the matching normaliser;
        # matmul broadcasts the [n, n] matrix over the batch dimension.
        numerator = torch.matmul(exp_w, exp_k * v)  # [bs, n, dim]
        denominator = torch.matmul(exp_w, exp_k)  # [bs, n, dim]

        # Gate the weighted average with sigmoid(Q).
        return self.sigmoid(q) * (numerator / denominator)


# 测试代码
if __name__ == '__main__':
    # Use the GPU when available so the demo also runs on CPU-only machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Instantiate AFT-full (feature dim 512, sequence length 64).
    model = AFT_FULL(d_model=512, n=64, simple=False).to(device)

    # Random input: [batch_size=64, seq_len=64, d_model=512]
    input_tensor = torch.randn(64, 64, 512, device=device)

    # Forward pass
    output_tensor = model(input_tensor)

    # Check the output shape
    print(f"输入形状: {input_tensor.shape}")
    print(f"输出形状: {output_tensor.shape}")  # expected: torch.Size([64, 64, 512])

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐