目录

一、KANs的内存困境与突破思路

1.1 KANs的核心创新与瓶颈

1.2 内存瓶颈根源分析

1.3 创新解决方案:元学习权重生成

二、MetaKANs核心技术解析

2.1 元学习器设计

2.2 提示符机制

2.3 深层网络优化

2.4 参数量分析

三、实验成果与性能突破

3.1 函数拟合任务(Feynman数据集)

3.2 图像分类任务(CIFAR系列)

3.3 高维函数与PDE求解

四、关键技术创新点

4.1 函数类压缩效应

4.2 动态权重生成

4.3 通用框架支持

五、代码实现与使用

5.1 安装与基础使用

5.2 高级配置

5.3 开源资源

六、应用场景与展望

6.1 当前应用优势

6.2 未来方向

附录:性能对比总览


一、KANs的内存困境与突破思路

1.1 KANs的核心创新与瓶颈

Kolmogorov-Arnold Networks(KANs)作为新一代神经网络:

  • 理论基础​:基于Kolmogorov-Arnold表示定理
  • 核心创新​:用可学习的单变量函数替代固定激活函数
    ϕ(t;w)=wb​SiLU(t)+∑i=1G+k​ci​Bi​(t)
  • 性能优势​:在函数拟合、符号回归等任务超越MLP

但面临严重内存瓶颈​:

# KANs参数量计算
def kan_params(layers, G, k):
    total = 0
    for i in range(len(layers)-1):
        total += layers[i] * layers[i+1] * (G + k + 1)
    return total  # 例如[4,5,5,1]结构,G=5时高达646参数

# MLP参数量计算
def mlp_params(layers):
    return sum(layers[i]*layers[i+1] for i in range(len(layers)-1))  # 同结构仅46参数
1.2 内存瓶颈根源分析
网络类型 参数量公式 示例([4,5,5,1])
MLP ∑(nl​×nl+1​) 4×5 + 5×5 + 5×1 = 50
KANs ∑(nl​×nl+1​)×(G+k+1) 50 × (5+3+1) = 450
放大因子 (G+k+1)× 通常9-20倍

当G=20时,KANs参数量可达MLP的24倍,导致:

  1. 训练内存溢出
  2. 计算成本飙升
  3. 无法扩展至大模型
1.3 创新解决方案:元学习权重生成

核心洞察​:KANs中所有激活函数共享相同的函数族F,其参数生成规则可被学习

MetaKANs架构​:

mermaid

graph LR
    A[输入x] --> B(可学习提示符z)
    B --> C[元学习器M_θ]
    C --> D[生成权重w]
    D --> E[KANs激活函数]
    E --> F[输出y]

二、MetaKANs核心技术解析

2.1 元学习器设计
class MetaLearner(nn.Module):
    def __init__(self, hidden_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden_dim),  # 输入为标量提示符z
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)  # 输出权重向量w
        )
    
    def forward(self, z):
        return self.net(z.unsqueeze(-1)).squeeze()

数学表示​:
w=Mθ​(z),ϕ(t;z,θ)=Mθ​(z)⊤B(t)

2.2 提示符机制
  • 每个激活函数分配唯一可学习提示符zα(l)​∈R
  • 提示符集合​:
    Z=⋃l=0L−1​{zα(l)​∣α∈[nl​]×[nl+1​]}
  • 物理意义​:作为函数标识符
2.3 深层网络优化

对深层KANs采用分层元学习器​:

mermaid

graph TD
    A[网络层] --> B{按通道数聚类}
    B --> C[组1:小通道层]
    B --> D[组2:中通道层]
    B --> E[组3:大通道层]
    C --> F[元学习器θ₁]
    D --> G[元学习器θ₂]
    E --> H[元学习器θ₃]

聚类算法​:

def cluster_layers(channels, n_clusters):
    kmeans = KMeans(n_clusters)
    labels = kmeans.fit_predict(channels.reshape(-1,1))
    clusters = []
    for i in range(n_clusters):
        layer_indices = np.where(labels == i)[0]
        start, end = min(layer_indices), max(layer_indices)
        clusters.append((start, end))
    return clusters
2.4 参数量分析

Params=提示符l=0∑L−1​(nl​×nl+1​)​​+元学习器C×(dhidden​+1)×(G+k+1)​​

关键优势​:当dhidden​≪∑(nl​×nl+1​)时,总参数量≈MLP!

模型 参数量公式 [4,5,5,1]示例
KANs ∑(nl​nl+1​)(G+k+1) 4×5×9 + 5×5×9 + 5×1×9 = 450
MetaKANs ∑(nl​nl+1​)+C(dh​+1)(G+k+1) 20 + 1×(64+1)×9 ≈ 605
实际优化 1/3 - 1/9 实验平均减少89%

三、实验成果与性能突破

3.1 函数拟合任务(Feynman数据集)

表2对比​:

函数 结构 KANs-MSE(G=5) MetaKANs-MSE(G=5) 参数量减少
I.12.5 [2,2,1] 1.32e-3 1.16e-4 57 → 30 (↓47%)
I.6.20 [2,2,1,1] 6.44e-3 2.67e-3 76 → 210 (↑但精度更高)
I.9.18 [6,5,5,5,1] 2.39e-3 1.70e-3 781 → 993

关键发现​:

  1. 多数任务精度优于原KANs
  2. 高维任务优势更显著
  3. 学习函数更简洁

3.2 图像分类任务(CIFAR系列)

对比​(4层ConvKAN):

模型 CIFAR-10准确率 CIFAR-100准确率 参数量
KANConv 41.92% 7.69% 3.49M
MetaKANConv 45.97%​ 9.71%​ 0.39M (↓89%)
FastKANConv 68.12% 34.64% 3.49M
MetaFastKANConv 66.69%​ 32.11%​ 0.39M (↓89%)

关键突破​:

  • 参数量减少89%​但精度持平或提升
  • 训练内存峰值降低3.2倍
  • 决策边界更清晰

3.3 高维函数与PDE求解

表4高维函数拟合​(n=1000):

函数 KANs-MSE MetaKANs-MSE 参数量减少
f1​(x)=exp(n1​∑sin2(2πx​)) 0.414 0.148 9,011 → 713 (↓92%)
f2​(x)=∑x2+x3 168.0 0.143 18,011 → 1,329 (↓93%)

表8 PDE求解​(100D Allen-Cahn方程):

方法 相对ℓ2​误差 参数量
KANs 1.91e-2 47,520
MetaKANs 2.60e-2 6,697 (↓86%)

内存消耗降低86%的同时保持求解精度


四、关键技术创新点

4.1 函数类压缩效应

对比揭示​:

  1. KANs学习冗余函数类​(左图分散)
  2. MetaKANs学习紧凑函数类​(右图聚类)
  3. 权值相似度更高(热力图深色集中)
4.2 动态权重生成
# MetaKANs前向传播
def forward(x, prompts, meta_learner):
    for l in range(len(layers)-1):
        z = prompts[l]  # 当前层提示符
        W = meta_learner(z)  # 动态生成权重
        x = kan_activate(x, W)  # KAN特色激活
    return x
  • 实时生成​:避免存储巨量参数
  • 跨层共享​:元学习器捕捉通用规则
4.3 通用框架支持

可扩展至各类KANs变体:

# 扩展至WavKAN
class MetaWavKAN(MetaKAN):
    def __init__(self, ...):
        super().__init__(out_dim=3)  # 输出(w, μ, σ)
    
    def activate(self, t, w):
        return w[0] * wavelet((t - w[1])/w[2])

# 扩展至FastKAN
class MetaFastKAN(MetaKAN):
    def __init__(self, ...):
        super().__init__(out_dim=c)  # c为RBF中心数

五、代码实现与使用

5.1 安装与基础使用
from metakan import MetaKAN

model = MetaKAN(
    layers=[4, 5, 5, 1],  # 网络结构
    grid=5,                # 网格点数
    k=3,                   # 样条阶数
    hidden_dim=64          # 元学习器隐藏层
)

# 训练循环
for x, y in dataloader:
    y_pred = model(x)
    loss = F.mse_loss(y_pred, y)
    loss.backward()
    optimizer.step()
5.2 高级配置
# 深层网络分组优化
model = DeepMetaKAN(
    layers=[32, 64, 128, 512],
    clusters=3,            # 元学习器分组数
    prompt_dim=2           # 提示符维度
)

# 卷积版本
from metakan.conv import MetaConvKAN
model = MetaConvKAN(in_ch=3, out_ch=10, hidden_dims=[32,64])
5.3 开源资源
  • 代码库​:https://github.com/Murphyzc/MetaKAN
  • 预训练模型:支持Feynman/PDE/ImageNet任务
  • 教程Notebook:包含可视化训练全流程

六、应用场景与展望

6.1 当前应用优势
  1. 科学计算​:高维PDE求解(内存降低86%)
  2. 符号回归​:复杂公式发现(参数量减少47%)
  3. 边缘计算​:移动端部署(峰值内存降3.2倍)
6.2 未来方向
  1. 硬件加速​:定制元学习器指令集
  2. 多模态扩展​:跨任务权重生成
  3. 理论深化​:函数类压缩的泛化界分析
  4. 3D视觉​:点云处理中的高效特征提取

"MetaKANs不仅解决了KANs的内存瓶颈,更开辟了'生成式参数'的新范式" —— 论文作者


附录:性能对比总览

任务类型 数据集 KANs精度 MetaKANs精度 参数量减少
函数拟合 Feynman I.12.5 1.32e-3 1.16e-4 ↓47%
图像分类 CIFAR-100 34.64% 32.11%​ ↓89%
高维逼近 f2​ (n=1000) 168.0 0.143 ↓93%
PDE求解 100D Allen-Cahn 1.91e-2 2.60e-2 ↓86%
训练效率 峰值内存 100% 31.25%​ ↓68.75%

通过元学习生成权重,MetaKANs成功将KANs参数量降至MLP水平,为新一代可解释AI模型的实际部署铺平道路。

论文地址:https://arxiv.org/pdf/2506.07549

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐