KAN 2.0: 科尔莫哥洛夫-阿诺德网络与科学的结合

Github:https://github.com/KindXiaoming/pykan.git

论文:https://arxiv.org/abs/2408.10205

摘要

AI + 科学的一个主要挑战在于它们固有的不兼容性:当今的人工智能主要基于连接主义,而科学依赖于符号主义。为了弥合这两个世界,我们提出了一个框架,以无缝协同科尔莫哥洛夫-阿诺德网络 (KANs) 和科学。该框架强调 KAN 在科学发现的三个方面的应用:识别相关特征、揭示模块化结构和发现符号公式。协同是双向的:科学到 KAN(将科学知识纳入 KAN),以及 KAN 到科学(从 KAN 中提取科学见解)。我们强调 pykan 中的主要新功能:(1)MultKAN:具有乘法节点的 KAN。(2)kanpiler:一个将符号公式编译为 KAN 的 KAN 编译器。(3)树转换器:将 KAN(或任何神经网络)转换为树图。基于这些工具,我们展示了 KAN 发现各种物理定律的能力,包括守恒量、拉格朗日量、对称性和本构定律。

图 1:科学与科尔莫哥洛夫-阿诺德网络 (KAN) 的协同作用。

1 引言

最近几年,AI + 科学作为一个有前景的新领域出现,导致了显著的科学进展,包括蛋白质折叠预测 [37]、自动定理证明 [95, 83]、天气预报 [41] 等等。这些任务的共同点在于,它们都可以很好地被表述为具有明确目标的问题,可以通过黑箱 AI 系统进行优化。虽然这一范式在应用驱动的科学中表现得非常出色,但还有另一种科学存在:好奇驱动的科学。在好奇驱动的研究中,过程更具探索性,往往缺乏超越“获得更多理解”的明确目标。需要澄清的是,好奇驱动的科学远非无用;恰恰相反。通过好奇获得的科学知识和理解常常为未来的技术奠定坚实的基础,并促进广泛的应用。

尽管应用驱动科学和好奇驱动科学都是无价且不可替代的,但它们提出了不同的问题。当天文学家观察天体的运动时,应用驱动的研究者专注于预测它们的未来状态,而好奇驱动的研究者则探索运动背后的物理学。另一个例子是 AlphaFold,尽管它在预测蛋白质结构方面取得了巨大的成功,但仍然属于应用驱动科学的范畴,因为它没有在更基础的层面提供新的知识(例如,原子力)。假设 AlphaFold 必须揭示重要的未知物理学才能实现其高度准确的预测。然而,这些信息对我们来说仍然是隐藏的,使得 AlphaFold 在很大程度上仍然是一个黑箱。因此,我们倡导新的 AI 范式来支持好奇驱动的科学。这种 AI + 科学的新范式要求 AI 工具具有更高的可解释性和互动性,以便能够无缝地融入科学研究中。

最近,一种新的神经网络类型被称为科尔莫戈洛夫-阿诺德网络(KAN)[57],在科学相关任务中显示出潜力。与具有固定激活函数的多层感知器(MLPs)不同,KAN在边缘上具有可学习的激活函数。由于KAN能够将高维函数分解为一维函数,因此可以通过符号回归这些一维函数来获得可解释性。然而,他们对可解释性的定义有些狭隘,几乎将其等同于提取符号公式的能力。这一定义的局限性限制了他们的适用范围,因为在科学中,符号公式并不总是必要或可行的。例如,尽管符号方程在物理学中强大且普遍存在,但在化学和生物学中,系统往往过于复杂,无法用这样的方程表示。在这些领域,模块化结构和关键特征可能足以表征这些系统的有趣方面。另一个被忽视的方面是将知识嵌入KAN的逆任务:我们如何在物理信息学习的精神下,将先验知识融入KAN?

我们增强和扩展KAN,使其更易于用于好奇驱动的科学。本文的目标可以总结如下:

目标:协同科尔莫戈洛夫-阿诺德网络 ⇔ \Leftrightarrow 科学。 ⇐ \Leftarrow :将科学知识构建到KAN中(第3节)。 ⇒ \Rightarrow :从KAN中提取科学知识(第4节)。

更具体地说,科学解释可能具有不同的层次,从最粗糙/最简单/相关的到最细致/最困难/因果的:

  • 重要特征:例如,“ y y y 完全由 x 1 {x}_{1} x1 x 2 {x}_{2} x2 决定,而其他因素并不重要。”换句话说,存在一个函数 f f f 使得 y = f ( x 1 , x 2 ) y = f\left( {{x}_{1},{x}_{2}}\right) y=f(x1,x2)

  • 模块结构:例如,“ x 1 {x}_{1} x1 x 2 {x}_{2} x2 独立地以加法方式贡献于 y y y 。”这意味着存在函数 g g g h h h ,使得 y = g ( x 1 ) + h ( x 2 ) y = g\left( {x}_{1}\right) + h\left( {x}_{2}\right) y=g(x1)+h(x2)

  • 符号公式:例如,“ y y y 依赖于 x 1 {x}_{1} x1 作为正弦函数,并依赖于 x 2 {x}_{2} x2 作为指数函数。”换句话说, y = sin ⁡ ( x 1 ) + exp ⁡ ( x 2 ) y = \sin \left( {x}_{1}\right) + \exp \left( {x}_{2}\right) y=sin(x1)+exp(x2)

本文报告了如何从 KAN 中整合和提取这些属性。论文的结构如下(如图 1 所示):在第 2 节中,我们通过引入乘法节点来增强原始 KAN,提出了一种新的模型称为 MultKAN。在第 3 节中,我们探讨了将科学归纳偏见嵌入 KAN 的方法,重点关注重要特征(第 3.1 节)、模块结构(第 3.2 节)和符号公式(第 3.3 节)。在第 4 节中,我们提出了从 KAN 中提取科学知识的方法,再次涵盖重要特征(第 4.1 节)、模块结构(第 4.2 节)和符号公式(第 4.3 节)。在第 5 节中,我们使用前面章节中开发的工具将 KAN 应用于各种科学发现任务。这些任务包括发现守恒量、对称性、拉格朗日量和本构定律。代码可在 https://github.com/KindXiaoming/pykan 获取,也可以通过 pip install pykan 安装。尽管论文的标题为“KAN 2.0”,但 pykan 的发布版本为 0.2.x。

图 2:顶部:比较 KAN 和 MultKAN 图示。MultKAN 具有额外的乘法层 M。底部:在 f ( x , y ) = x y f\left( {x, y}\right) = {xy} f(x,y)=xy 上训练后,KAN 学习到一个需要两个加法节点的算法,而 MultKAN 只需要一个乘法节点。

2 MultKAN: 用乘法增强 KAN

科尔莫戈洛夫-阿诺德表示定理 (KART) 说明任何连续的高维函数都可以分解为有限个单变量连续函数和加法的组合: f ( x ) = f ( x 1 , ⋯   , x n ) = ∑ q = 1 2 n + 1 Φ q ( ∑ p = 1 n ϕ q , p ( x p ) ) . (1) f\left( \mathbf{x}\right) = f\left( {{x}_{1},\cdots ,{x}_{n}}\right) = \mathop{\sum }\limits_{{q = 1}}^{{{2n} + 1}}{\Phi }_{q}\left( {\mathop{\sum }\limits_{{p = 1}}^{n}{\phi }_{q, p}\left( {x}_{p}\right) }\right) . \tag{1} f(x)=f(x1,,xn)=q=12n+1Φq(p=1nϕq,p(xp)).(1) 这意味着加法是唯一真正的多变量运算,而其他多变量运算(包括乘法)可以表示为与单变量函数结合的加法。例如,要乘以两个正数 x x x y y y ,我们可以将其表示为 x y = exp ⁡ ( log ⁡ x + log ⁡ y ) 2 {xy} = \exp {\left( \log x + \log y\right) }^{2} xy=exp(logx+logy)2 ,其右侧仅由加法和单变量函数(对数和指数)组成。然而,考虑到乘法在科学和日常生活中的普遍性,显式地在 KAN 中包含乘法是可取的,这可能增强可解释性和能力。

2 {}^{2} 2 如果 x x x y y y 可以为负数,可以选择一个大的 c > 0 c > 0 c>0 并表示 x y = exp ⁡ ( log ⁡ ( x + c ) + log ⁡ ( y + {xy} = \exp (\log \left( {x + c}\right) + \log (y + xy=exp(log(x+c)+log(y+ c ) ) − c ( x + y ) − c 2 c)) - c\left( {x + y}\right) - {c}^{2} c))c(x+y)c2 。其他构造包括二次函数,如 x y = ( ( x + y ) 2 − ( x − y ) 2 ) / 4 {xy} = \left( {{\left( x + y\right) }^{2} - {\left( x - y\right) }^{2}}\right) /4 xy=((x+y)2(xy)2)/4 x y = ( ( x + y ) 2 − x 2 − y 2 ) / 2 {xy} = \left( {{\left( x + y\right) }^{2} - {x}^{2} - {y}^{2}}\right) /2 xy=((x+y)2x2y2)/2


科尔莫戈洛夫-阿诺德网络 (KAN) 虽然 KART 方程 (1) 对应于一个两层网络,刘等人 [57] 设法通过认识到看似不同的外部函数 Φ q {\Phi }_{q} Φq 和内部函数 ϕ q , p {\phi }_{q, p} ϕq,p 可以通过他们提出的 KAN 层统一,扩展到任意深度。深度 L L L KAN 可以通过堆叠 L L L KAN 层简单构建。深度 L L L KAN 的形状由整数数组 [ n 0 , n 1 , ⋯   , n L ] \left\lbrack {{n}_{0},{n}_{1},\cdots ,{n}_{L}}\right\rbrack [n0,n1,,nL] 表示,其中 n l {n}_{l} nl 表示 l th  {l}^{\text{th }} lth  神经层中的神经元数量。具有 n l {n}_{l} nl 输入维度和 n l + 1 {n}_{l + 1} nl+1 输出维度的 l th  {l}^{\text{th }} lth  KAN 层,将输入向量 x l ∈ R n l {\mathbf{x}}_{l} \in {\mathbb{R}}^{{n}_{l}} xlRnl 转换为 x l + 1 ∈ R n l + 1 {\mathbf{x}}_{l + 1} \in {\mathbb{R}}^{{n}_{l + 1}} xl+1Rnl+1 x l + 1 = ( ϕ l , 1 , 1 ( ⋅ ) ϕ l , 2 , 1 ( ⋅ ) ⋯ ϕ l , n l , 1 ( ⋅ ) ϕ l , 1 , 2 ( ⋅ ) ϕ l , 2 , 2 ( ⋅ ) ⋯ ϕ l , n l , 2 ( ⋅ ) ⋮ ⋮ ⋮ ϕ l , 1 , n l + 1 ( ⋅ ) ϕ l , 2 , n l + 1 ( ⋅ ) ⋯ ϕ l , n l , n l + 1 ( ⋅ ) ) ⏟ Φ l x l , (2) {\mathbf{x}}_{l + 1} = \underset{{\mathbf{\Phi }}_{l}}{\underbrace{\left( \begin{matrix} {\phi }_{l,1,1}\left( \cdot \right) & {\phi }_{l,2,1}\left( \cdot \right) & \cdots & {\phi }_{l,{n}_{l},1}\left( \cdot \right) \\ {\phi }_{l,1,2}\left( \cdot \right) & {\phi }_{l,2,2}\left( \cdot \right) & \cdots & {\phi }_{l,{n}_{l},2}\left( \cdot \right) \\ \vdots & \vdots & & \vdots \\ {\phi }_{l,1,{n}_{l + 1}}\left( \cdot \right) & {\phi }_{l,2,{n}_{l + 1}}\left( \cdot \right) & \cdots & {\phi }_{l,{n}_{l},{n}_{l + 1}}\left( \cdot \right) \end{matrix}\right) }}{\mathbf{x}}_{l}, \tag{2} xl+1=Φl ϕl,1,1()ϕl,1,2()ϕl,1,nl+1()ϕl,2,1()ϕl,2,2()ϕl,2,nl+1()ϕl,nl,1()ϕl,nl,2()ϕl,nl,nl+1() xl,(2) 整个网络是 L L L KAN 层的组合,即, KAN ⁡ ( x ) = ( Φ L − 1 ∘ ⋯ ∘ Φ 1 ∘ Φ 0 ) x . (3) \operatorname{KAN}\left( \mathbf{x}\right) = \left( {{\mathbf{\Phi }}_{L - 1} \circ \cdots \circ {\mathbf{\Phi }}_{1} \circ {\mathbf{\Phi }}_{0}}\right) \mathbf{x}. \tag{3} KAN(x)=(ΦL1Φ1Φ0)x.(3) 在图示中,KAN 可以直观地可视化为由节点(求和)和边(可学习的激活)组成的网络,如图 2 左上角所示。当在从 f ( x , y ) = x y f\left( {x, y}\right) = {xy} f(x,y)=xy 生成的数据集上进行训练时,KAN(图 2 左下角)使用了两个加法节点,这使得网络的功能不够明确。然而,经过一些考虑,我们意识到它利用了等式 x y = ( ( x + y ) 2 − ( x − y ) 2 ) / 4 {xy} = \left( {{\left( x + y\right) }^{2} - {\left( x - y\right) }^{2}}\right) /4 xy=((x+y)2(xy)2)/4 ,但这并不明显。

乘法科尔莫哥洛夫-阿诺德网络(MultKAN) 为了明确引入乘法操作,我们提出了 MultKAN,它可以更清晰地揭示数据中的乘法结构。MultKAN(如图 2 右上角所示)类似于 KAN,两者都具有标准的 KAN 层。我们将 KAN 层的输入节点称为节点,将 KAN 层的输出节点称为子节点。KAN 和 MultKAN 之间的区别在于从当前层的子节点到下一层的节点的转换。在 KAN 中,节点直接从前一层的子节点复制。在 MultKAN 中,一些节点(加法节点)从相应的子节点复制,而其他节点(乘法节点)对来自前一层的 k k k 子节点进行乘法运算。为了简化,我们将 k = 2 k = 2 k=2 设置为小于 3。

基于 MultKAN 图(图 2 右上角),可以直观地理解 MultKAN 是一个正常的 KAN,其中插入了可选的乘法。为了数学上的精确性,我们定义以下符号:层 l l l 中的加法(乘法)操作数量分别表示为 n l a ( n l m ) {n}_{l}^{a}\left( {n}_{l}^{m}\right) nla(nlm) 。这些被收集成数组:加法宽度 n a ≡ [ n 0 a , n 1 a , ⋯   , n L a ] {\mathbf{n}}^{a} \equiv \left\lbrack {{n}_{0}^{a},{n}_{1}^{a},\cdots ,{n}_{L}^{a}}\right\rbrack na[n0a,n1a,,nLa] 和乘法宽度 n m ≡ [ n 0 m , n 1 m , ⋯   , n L m ] {\mathbf{n}}^{m} \equiv \left\lbrack {{n}_{0}^{m},{n}_{1}^{m},\cdots ,{n}_{L}^{m}}\right\rbrack nm[n0m,n1m,,nLm] 。当 n 0 m = n 1 m = {n}_{0}^{m} = {n}_{1}^{m} = n0m=n1m= ⋯ = n L m = 0 \cdots = {n}_{L}^{m} = 0 =nLm=0 时,MultKAN 简化为 KAN。例如,图 2(右上角)显示了一个具有 n a = [ 2 , 2 , 1 ] {\mathbf{n}}^{a} = \left\lbrack {2,2,1}\right\rbrack na=[2,2,1] n m = [ 0 , 2 , 0 ] {\mathbf{n}}^{m} = \left\lbrack {0,2,0}\right\rbrack nm=[0,2,0] 的 MultKAN。

MultKAN 层由标准的 KANLayer Φ l {\mathbf{\Phi }}_{l} Φl 和一个乘法层组成,该乘法层接收输入向量 x l ∈ R n l a + n l m {\mathbf{x}}_{l} \in {\mathbb{R}}^{{n}_{l}^{a} + {n}_{l}^{m}} xlRnla+nlm 并输出 z l = Φ l ( x ) ∈ R n l + 1 a + 2 n l + 1 m {\mathbf{z}}_{l} = {\mathbf{\Phi }}_{l}\left( \mathbf{x}\right) \in {\mathbb{R}}^{{n}_{l + 1}^{a} + 2{n}_{l + 1}^{m}} zl=Φl(x)Rnl+1a+2nl+1m 。乘法层由两部分组成:乘法部分对子节点对进行乘法运算,而另一部分执行恒等变换。用 Python 编写, M l {\mathbf{M}}_{l} Ml z l {\mathbf{z}}_{l} zl 转换如下: M l ( z l ) = concatenate ⁡ ( z l [ : n l + 1 a ] , z l [ n l + 1 a : : 2 ] ⊙ z l [ n l + 1 a + 1 : : 2 ] ) ∈ R n l + 1 a + n l + 1 m , (4) {\mathbf{M}}_{l}\left( {\mathbf{z}}_{l}\right) = \operatorname{concatenate}\left( {{\mathbf{z}}_{l}\left\lbrack { : {n}_{l + 1}^{a}}\right\rbrack ,{\mathbf{z}}_{l}\left\lbrack {{n}_{l + 1}^{a} : : 2}\right\rbrack \odot {\mathbf{z}}_{l}\left\lbrack {{n}_{l + 1}^{a} + 1 : : 2}\right\rbrack }\right) \in {\mathbb{R}}^{{n}_{l + 1}^{a} + {n}_{l + 1}^{m}}, \tag{4} Ml(zl)=concatenate(zl[:nl+1a],zl[nl+1a::2]zl[nl+1a+1::2])Rnl+1a+nl+1m,(4) 其中 ⊙ \odot 是逐元素乘法。MultKANLayer 可以简洁地表示为 Ψ l ≡ {\mathbf{\Psi }}_{l} \equiv Ψl M l ∘ Φ l {\mathbf{M}}_{l} \circ {\mathbf{\Phi }}_{l} MlΦl 。整个 MultKAN 因此为: MultKAN ⁡ ( x ) = ( Ψ L ∘ Ψ L − 1 ∘ ⋯ ∘ Ψ 1 ∘ Ψ 0 ) x . (5) \operatorname{MultKAN}\left( \mathbf{x}\right) = \left( {{\mathbf{\Psi }}_{L} \circ {\mathbf{\Psi }}_{L - 1} \circ \cdots \circ {\mathbf{\Psi }}_{1} \circ {\mathbf{\Psi }}_{0}}\right) \mathbf{x}. \tag{5} MultKAN(x)=(ΨLΨL1Ψ1Ψ0)x.(5) 由于乘法层中没有可训练的参数,因此所有稀疏正则化技术(例如, ℓ 1 {\ell }_{1} 1 和熵正则化)可以直接应用于 MultKANs。

3 {}^{3} 3 为了简化,我们设置 k = 2 k = 2 k=2 ,但 pykan 包允许 k k k 为任何整数 k ≥ 2 k \geq 2 k2 。用户甚至可以为不同的乘法节点设置不同的 k k k 值。然而,如果在同一层内使用不同的 k    s k\mathrm{\;s} ks 值,则可能会很难并行化这些乘法。


图 3:向输入中添加辅助变量增强了解释性。对于相对论质量方程 m = m 0 / 1 − v 2 / c 2 m = {m}_{0}/\sqrt{1 - {v}^{2}/{c}^{2}} m=m0/1v2/c2 ,如果仅使用 ( m 0 , v , c ) \left( {{m}_{0}, v, c}\right) (m0,v,c) 作为输入,则需要一个两层 KAN。(b) 如果我们将 β ≡ v / c \beta \equiv v/c βv/c γ ≡ 1 / 1 − β 2 \gamma \equiv 1/\sqrt{1 - {\beta }^{2}} γ1/1β2 作为辅助变量添加到 KAN 中,则一个单层 KAN 足够(种子 0)。© 种子 1 找到了一个不同的解决方案,该方案是次优的,可以通过假设检验避免(第 4.3 节)。

在乘法任务 f ( x , y ) = x y f\left( {x, y}\right) = {xy} f(x,y)=xy 中,MultKAN 确实学会使用一个乘法节点,使其能够执行简单的乘法,因为所有学习到的激活函数都是线性的(图 2 右下角)。

尽管 KAN 之前被视为 MultKAN 的特例,但我们扩展了定义,将“KAN”和“MultKAN”视为同义词。默认情况下,当我们提到 KAN 时,允许进行乘法。如果我们特别提到没有乘法的 KAN,我们将明确说明。

3 KAN 的科学

在科学中,领域知识至关重要,使我们即使在数据很少或没有的情况下也能有效工作。因此,采用物理启发的方法对 KAN 是有益的:我们应该在 KAN 中融入可用的归纳偏见,同时保持其从数据中发现新物理的灵活性。

我们探讨可以整合到 KAN 中的三种类型的归纳偏见。从最粗糙/最简单/相关的到最精细/最困难/因果的,它们是重要特征(第 3.1 节)、模块化结构(第 3.2 节)和符号公式(第 3.3 节)。

3.1 向 KAN 添加重要特征

在回归问题中,目标是找到一个函数 f f f 使得 y = f ( x 1 , x 2 , ⋯   , x n ) y = f\left( {{x}_{1},{x}_{2},\cdots ,{x}_{n}}\right) y=f(x1,x2,,xn) 。假设我们想引入一个辅助输入变量 a = a ( x 1 , x 2 , … , x n ) a = a\left( {{x}_{1},{x}_{2},\ldots ,{x}_{n}}\right) a=a(x1,x2,,xn) ,将函数转换为 y = f ( x 1 , ⋯   , x n , a ) y = f\left( {{x}_{1},\cdots ,{x}_{n}, a}\right) y=f(x1,,xn,a) 。尽管辅助变量 a a a 并没有增加新的信息,但它可以增强神经网络的表达能力。这是因为网络不需要消耗资源来计算辅助变量。此外,计算可能变得更简单,从而提高可解释性。用户可以使用 augment_input 方法将辅助特征添加到输入中:

model.augment_input(original_variables, auxiliary_variables, dataset)

作为一个例子,考虑相对论质量的公式 m ( m 0 , v , c ) = m 0 / 1 − ( v / c ) 2 m\left( {{m}_{0}, v, c}\right) = {m}_{0}/\sqrt{1 - {\left( v/c\right) }^{2}} m(m0,v,c)=m0/1(v/c)2 ,其中 m 0 {m}_{0} m0 是静止质量, v v v 是点质量的速度,而 c c c 是光速。由于物理学家经常处理无量纲数 β ≡ v / c \beta \equiv v/c βv/c γ ≡ 1 / 1 − β 2 ≡ 1 / 1 − ( v / c ) 2 \gamma \equiv 1/\sqrt{1 - {\beta }^{2}} \equiv 1/\sqrt{1 - {\left( v/c\right) }^{2}} γ1/1β2 1/1(v/c)2 ,他们可能会引入 β \beta β γ \gamma γ 作为输入,与 v v v c c c 一起使用。图3 显示了带有和不带有这些辅助变量的 KAN:(a)展示了从符号公式编译的 KAN(有关 KAN 编译器,请参见第 3.3 节),需要 5 条边;(b)©显示了带有辅助变量的 KAN,仅需 2 或 3 条边,并分别实现了 10 − 6 {10}^{-6} 106 10 − 4 {10}^{-4} 104 的损失。注意(b)和©仅在随机种子上有所不同。种子 1 代表一个次优解,因为它也将 β = v / c \beta = v/c β=v/c 识别为一个关键特征。这并不令人惊讶,因为在经典极限下 v ≪ c , γ ≡ 1 / 1 − ( v / c ) 2 ≈ v \ll c,\gamma \equiv 1/\sqrt{1 - {\left( v/c\right) }^{2}} \approx vc,γ1/1(v/c)2 1 + ( v / c ) 2 / 2 = 1 + β 2 / 2 1 + {\left( v/c\right) }^{2}/2 = 1 + {\beta }^{2}/2 1+(v/c)2/2=1+β2/2 。由于不同种子引起的变化可以被视为特征或缺陷:作为特征,这种多样性可以帮助找到次优解,这些解可能仍然提供有趣的见解;作为缺陷,可以使用第 4.3 节中提出的假设检验方法加以消除。

图 4:构建 KAN 的模块化结构:(a)乘法可分离性;(b)对称性。

3.2 构建 KAN 的模块化结构

模块化在自然界中普遍存在:例如,人类大脑皮层被划分为几个功能上不同的模块,每个模块负责特定的任务,如感知或决策。这种模块化简化了对神经网络的理解,因为它使我们能够集体解释神经元的簇,而不是逐个分析每个神经元。结构模块化的特点是连接簇,其中簇内连接远强于簇间连接。为了强化模块化,我们引入了模块方法,该方法保留了簇内连接,同时移除簇间连接。模块由用户指定。其语法为

model.module(start_layer_id, '[nodes_id] ->[subnodes_id] ->[nodes_id] … ')(7)

例如,如果用户想将特定的节点/子节点分配给一个模块——比如,第1层的 0 th  {0}^{\text{th }} 0th  节点,第1层的 1 st  {1}^{\text{st }} 1st  3 rd  {3}^{\text{rd }} 3rd  子节点,第2层的 1 st  {1}^{\text{st }} 1st  3 rd  {3}^{\text{rd }} 3rd  节点——他们可能会使用模块 ( 1 , 4 [ 0 ] − > [ 1 , 3 ] − > [ 1 , 3 ] 7 ) \left( {1,{}^{4}\left\lbrack 0\right\rbrack - > \left\lbrack {1,3}\right\rbrack - > \left\lbrack {1,3}\right\rbrack {}^{7}}\right) (1,4[0]>[1,3]>[1,3]7) 。具体来说,有两种类型的模块化:可分离性和对称性。

可分离性 我们说一个函数被认为是可分离的,如果它可以表示为非重叠变量组的函数之和或乘积。例如,一个四变量函数 f ( x 1 , x 2 , x 3 , x 4 ) f\left( {{x}_{1},{x}_{2},{x}_{3},{x}_{4}}\right) f(x1,x2,x3,x4) 是最大乘法可分离的,如果它具有形式 f 1 ( x 1 ) f 2 ( x 2 ) f 3 ( x 3 ) f 4 ( x 4 ) {f}_{1}\left( {x}_{1}\right) {f}_{2}\left( {x}_{2}\right) {f}_{3}\left( {x}_{3}\right) {f}_{4}\left( {x}_{4}\right) f1(x1)f2(x2)f3(x3)f4(x4) ,形成四个不同的组 ( 1 ) , ( 2 ) , ( 3 ) , ( 4 ) \left( 1\right) ,\left( 2\right) ,\left( 3\right) ,\left( 4\right) (1),(2),(3),(4) 。用户可以通过调用模块方法四次来创建这些模块:模块 ( 0 , 4 [ i ] − > [ i ] 3 ) , i = 0 , 1 , 2 , 3 \left( {0,{}^{4}\left\lbrack i\right\rbrack - > \left\lbrack i\right\rbrack {}^{3}}\right) , i = 0,1,2,3 (0,4[i]>[i]3),i=0,1,2,3 ,如图4 (a)所示。最后一次调用可以省略,因为前三次足以定义组。乘法可分离性的较弱形式可能是 f 1 ( x 1 , x 2 ) f 2 ( x 3 , x 4 ) {f}_{1}\left( {{x}_{1},{x}_{2}}\right) {f}_{2}\left( {{x}_{3},{x}_{4}}\right) f1(x1,x2)f2(x3,x4) (调用模块 ( 0 , t [ 0 , 1 ] − > [ 0 , 1 ] r ) \left( {0,{}^{t}\left\lbrack {0,1}\right\rbrack - > \left\lbrack {0,1}\right\rbrack {}^{r}}\right) (0,t[0,1]>[0,1]r) )或 f 1 ( x 1 ) f 2 ( x 2 , x 3 , x 4 ) {f}_{1}\left( {x}_{1}\right) {f}_{2}\left( {{x}_{2},{x}_{3},{x}_{4}}\right) f1(x1)f2(x2,x3,x4) (调用模块 ( 0 , t [ 0 ] − > [ 0 ] r ) \left( {0,{}^{t}\left\lbrack 0\right\rbrack - > \left\lbrack 0\right\rbrack {}^{r}}\right) (0,t[0]>[0]r) )。

广义对称性 我们说一个函数在变量 ( x 1 , x 2 ) \left( {{x}_{1},{x}_{2}}\right) (x1,x2) 中是对称的,如果 f ( x 1 , x 2 , x 3 , ⋯ ) = f\left( {{x}_{1},{x}_{2},{x}_{3},\cdots }\right) = f(x1,x2,x3,)= g ( h ( x 1 , x 2 ) , x 3 , ⋯ ) g\left( {h\left( {{x}_{1},{x}_{2}}\right) ,{x}_{3},\cdots }\right) g(h(x1,x2),x3,) 。这个属性被称为对称性,因为只要 h ( x 1 , x 2 ) h\left( {{x}_{1},{x}_{2}}\right) h(x1,x2) 保持不变, f f f 的值就不会改变,即使 x 1 {x}_{1} x1 x 2 {x}_{2} x2 发生变化。例如,一个函数 f f f 2 D 2\mathrm{D} 2D 中是旋转不变的,如果 f ( x 1 , x 2 ) = g ( r ) f\left( {{x}_{1},{x}_{2}}\right) = g\left( r\right) f(x1,x2)=g(r) ,其中 r ≡ x 1 2 + x 2 2 r \equiv \sqrt{{x}_{1}^{2} + {x}_{2}^{2}} rx12+x22 。当对称性仅涉及一部分变量时,它可以被视为分层的,因为 x 1 {x}_{1} x1 x 2 {x}_{2} x2 首先通过 h h h 进行交互(2层 KAN),然后 h h h 通过 g g g 与其他变量交互(2层 KAN)。假设一个四变量函数具有分层形式 f ( x 1 , x 2 , x 3 , x 4 ) = h ( f ( x 1 , x 2 ) , g ( x 3 , x 4 ) ) f\left( {{x}_{1},{x}_{2},{x}_{3},{x}_{4}}\right) = h\left( {f\left( {{x}_{1},{x}_{2}}\right) , g\left( {{x}_{3},{x}_{4}}\right) }\right) f(x1,x2,x3,x4)=h(f(x1,x2),g(x3,x4)) ,如图 4 (b) 所示。我们可以使用模块方法通过调用模块 ( 0 , 1 [ 0 , 1 ] − > [ 0 , 1 ] − > [ 0 , 1 ] − > [ 0 ] ′ ) \left( {0,{}^{1}\left\lbrack {0,1}\right\rbrack - > \left\lbrack {0,1}\right\rbrack - > \left\lbrack {0,1}\right\rbrack - > \left\lbrack 0\right\rbrack {}^{\prime }}\right) (0,1[0,1]>[0,1]>[0,1]>[0]) 来创建这个结构,确保变量组 ( x 1 , x 2 ) \left( {{x}_{1},{x}_{2}}\right) (x1,x2) ( x 3 , x 4 ) \left( {{x}_{3},{x}_{4}}\right) (x3,x4) 在前两层中不发生交互。

图 5:KAN 编译器(kanpiler)将符号表达式转换为 KAN。 (a) kanpiler 的工作原理:符号公式首先被解析为表达式树,然后转换为 KAN。 (b) 将 KAN 应用于 10 个方程(从费曼数据集中选取)。 © 扩展已编译的 KAN 以增强其表达能力。

3.3 将符号公式编译为 KAN

科学家们常常在通过符号方程表示复杂现象中找到满足感。然而,尽管这些方程简洁,但由于其特定的功能形式,它们可能缺乏捕捉所有细微差别所需的表达能力。相比之下,神经网络具有高度的表达能力,但可能低效地花费训练时间和数据来学习科学家们已经知道的领域知识。为了利用这两种方法的优势,我们提出了一种两步程序:(1)将符号方程编译为 KAN,并(2)使用数据对这些 KAN 进行微调。第一步旨在将已知的领域知识嵌入到 KAN 中,而第二步则侧重于从数据中学习新的“物理”知识。

kanpiler(KAN 编译器)kanpiler 的目标是将符号公式转换为 KAN。该过程如图 5 (a) 所示,涉及三个主要步骤:(1)将符号公式解析为树结构,其中节点表示表达式,边表示操作/函数。(2)然后修改该树以与 KAN 图的结构对齐。修改包括通过虚拟边将所有叶节点移动到输入层,并添加虚拟子节点/节点以匹配 KAN 架构。这些虚拟边/节点/子节点仅执行身份变换。(3)在第一层中组合变量,有效地将树转换为图。为了视觉清晰,在边上放置一维曲线以表示函数。我们在费曼数据集上对 kanpiler 进行了基准测试,成功处理了所有 120 个方程。示例见图 5 (b)。kanpiler 接收输入变量(作为 sympy 符号)和输出表达式(作为 sympy 表达式),并返回一个 KAN 模型。 m o d e l = k a n p i l e r ( i n p u t v a r i a b l e s , o u t p u t e x p r e s s i o n ) (8) {model = kanpiler(input_variables, output_expression)} \tag{8} model=kanpiler(inputvariables,outputexpression)(8) 请注意,返回的 KAN 模型处于符号模式,即符号函数被精确编码。如果我们改为使用三次样条来近似这些符号函数,我们将得到均方误差损失 ℓ ∝ N − 8 \ell \propto {N}^{-8} N8 [57],其中 N N N 是网格区间的数量(与模型参数的数量成正比)。

宽度/深度扩展以增加表达能力 由 kan-piler 生成的 KAN 网络是紧凑的,没有冗余边缘,这可能限制其表达能力并妨碍进一步的微调。为了解决这个问题,我们提出了 expand_width 和 expand_depth 方法,以扩展网络,使其变得更宽和更深,如图 5 © 所示。扩展方法最初添加零激活函数,这在训练过程中会遭受零梯度。因此,应使用扰动方法将这些零函数扰动为非零值,使其能够以非零梯度进行训练。

4 KANs 与科学

当今的黑箱深度神经网络功能强大,但解释这些模型仍然具有挑战性。科学家不仅寻求高性能模型,还希望能够从模型中提取有意义的知识。在本节中,我们专注于增强 KANs 的科学目的的可解释性。我们将探讨从 KANs 中提取知识的三个层次,从最基本到最复杂:重要特征(第 4.1 节)、模块化结构(第 4.2 节)和符号公式(第 4.3 节)。

4.1 从 KANs 中识别重要特征

识别重要变量对许多任务至关重要。给定一个回归模型 f f f 其中 y ≈ y \approx y f ( x 1 , x 2 , … , x n ) f\left( {{x}_{1},{x}_{2},\ldots ,{x}_{n}}\right) f(x1,x2,,xn) ,我们旨在为输入变量分配分数以评估其重要性。Liu 等人 [57] 使用 L1 范数函数来指示边缘的重要性,但该指标可能存在问题,因为它仅考虑局部信息。

为了解决这个问题,我们引入了一种更有效的归因分数,它比 L1 范数更好地反映变量的重要性。为简单起见,假设存在乘法节点,因此我们不需要区分节点和子节点 4 {}^{4} 4 | 假设我们有一个宽度为 [ n 0 , n 1 , ⋯   , n L ] \left\lbrack {{n}_{0},{n}_{1},\cdots ,{n}_{L}}\right\rbrack [n0,n1,,nL] L L L 层 KAN。我们将 E l , i , j {E}_{l, i, j} El,i,j 定义为 (l, i, j) 边上激活值的标准差,将 N l , i {N}_{l, i} Nl,i 定义为 (l, i) 节点上激活值的标准差。然后我们定义节点(归因)分数 A l , i {A}_{l, i} Al,i 和边(归因)分数 B l , i , j {B}_{l, i, j} Bl,i,j 。在 [57] 中,我们简单地定义了 B l , i , j = E l , i , j {B}_{l, i, j} = {E}_{l, i, j} Bl,i,j=El,i,j A l , i = N l , i {A}_{l, i} = {N}_{l, i} Al,i=Nl,i 。然而,这一定义未能考虑网络的后续部分;即使一个节点或边本身具有较大的范数,如果网络的其余部分有效地是零函数,它可能不会对输出产生贡献。因此,我们现在计算节点和边的分数。

(a) 比较 L1 函数范数和归因分数

图 6:在 KAN 中识别重要特征。 (a) 将归因分数与 Liu 等人 [57] 使用的 L1 范数进行比较。在两个合成任务中,归因分数提供的见解超过了 L1 范数。 (b) 可以为输入计算归因分数,并用于输入剪枝。

从输出层迭代到输入层。我们将所有输出维度的分数设为单位分数,即 A L , i = 1 , i = 0 , 1 , ⋯   , n L − 1 ∣ 5 {A}_{L, i} = 1, i = 0,1,\cdots ,{n}_{L} - 1 \mid 5 AL,i=1,i=0,1,,nL15 ,并按如下方式计算分数: B l − 1 , i , j = A l , j E l , j N l + 1 , j ,    A l − 1 , i = ∑ j = 0 n l B l − 1 , i , j ,    l = L , L − 1 , ⋯   , 1. (9) {B}_{l - 1, i, j} = {A}_{l, j}\frac{{E}_{l, j}}{{N}_{l + 1, j}},\;{A}_{l - 1, i} = \mathop{\sum }\limits_{{j = 0}}^{{n}_{l}}{B}_{l - 1, i, j},\;l = L, L - 1,\cdots ,1. \tag{9} Bl1,i,j=Al,jNl+1,jEl,j,Al1,i=j=0nlBl1,i,j,l=L,L1,,1.(9) 比较 E l , i , j {E}_{l, i, j} El,i,j B l , i , j {B}_{l, i, j} Bl,i,j 我们发现 B l , i , j {B}_{l, i, j} Bl,i,j 更准确地反映了边的重要性。在图6中,我们比较了基于两个方程 y = exp ⁡ ( sin ⁡ ( π x 1 ) + x 2 2 ) y = \exp \left( {\sin \left( {\pi {x}_{1}}\right) + {x}_{2}^{2}}\right) y=exp(sin(πx1)+x22) y = ( x 1 2 + x 2 2 ) 2 + ( x 3 2 + x 4 2 ) 2 y = {\left( {x}_{1}^{2} + {x}_{2}^{2}\right) }^{2} + {\left( {x}_{3}^{2} + {x}_{4}^{2}\right) }^{2} y=(x12+x22)2+(x32+x42)2 训练的 KAN,并可视化了重要性分数为 E E E (L1 范数) 或 B B B (归因分数) 的 KAN。对于第一个方程,归因分数揭示了比 L1 范数更清晰的图形,因为第一层中的许多活跃边由于后续边不活跃而未能对最终输出做出贡献。归因分数考虑到了这一点,从而产生了更有意义的图形。对于第二个方程 y = ( x 1 2 + x 2 2 ) 2 + ( x 3 2 + x 4 2 ) 2 y = {\left( {x}_{1}^{2} + {x}_{2}^{2}\right) }^{2} + {\left( {x}_{3}^{2} + {x}_{4}^{2}\right) }^{2} y=(x12+x22)2+(x32+x42)2 ,我们可以从符号方程中看出所有四个变量的重要性是相等的。归因分数正确反映了所有四个变量的相等重要性,而 L1 范数则错误地暗示 x 3 {x}_{3} x3 x 4 {x}_{4} x4 x 1 {x}_{1} x1 x 2 {x}_{2} x2 更重要。

基于归因分数修剪输入 在真实数据集中,输入维度可能很大,但只有少数变量可能是相关的。为了解决这个问题,我们提出根据归因分数修剪掉无关特征,以便我们能够专注于最相关的特征。用户可以应用 prune_input 仅保留最相关的变量。例如,如果在函数 y = ∑ i = 0 99 x i 2 / 2 i , x i ∈ [ − 1 , 1 ] y = \mathop{\sum }\limits_{{i = 0}}^{{99}}{x}_{i}^{2}/{2}^{i},{x}_{i} \in \left\lbrack {-1,1}\right\rbrack y=i=099xi2/2i,xi[1,1] 中有 100 个输入特征按相关性递减排序,并且在训练后,只有前五个特征显示出显著更高的归因分数,那么 prune_input 方法将仅保留这五个特征。修剪后的网络变得紧凑且易于解释,而原始的 KAN 由于有 100 个输入而过于密集,难以进行直接解释。

4.2 从 KAN 中识别模块化结构

尽管归因分数提供了有关哪些边或节点重要的有价值见解,但它并未揭示模块结构,即重要的边和节点是如何连接的。在这一部分,我们旨在通过检查两种类型的模块性:解剖模块性和功能模块性,来揭示经过训练的 KANs 和 MLPs 的模块结构。

图 7:通过神经元交换在神经网络中诱导解剖模块性。该方法涉及为神经元分配空间坐标并对其进行置换,以最小化整体连接成本。对于两个任务(左:多任务奇偶性,右:层次多数投票),神经元交换在 KANs(上)和 MLPs(下)中均适用于多任务奇偶性。

4.2.1 解剖模块性

解剖模块化是指空间上相互靠近的神经元之间的连接比远离的神经元之间的连接更强的倾向。尽管人工神经网络缺乏物理空间坐标,但引入物理空间的概念已被证明可以增强可解释性 [51, 52]。我们采用了 [51, 52] 中的神经元交换方法,该方法在保持网络功能的同时缩短连接。我们称该方法为 auto_swap。通过神经元交换揭示的解剖模块结构使得在图 7 中展示的两个任务的模块易于识别,甚至可以通过视觉识别。 (1) 多任务稀疏奇偶性; (2) 层次多数投票。对于多任务稀疏奇偶性,我们有 10 个输入位 x i ∈ { 0 , 1 } , i = 1 , 2 , ⋯   , 10 {x}_{i} \in \{ 0,1\} , i = 1,2,\cdots ,{10} xi{0,1},i=1,2,,10 和输出 y j = x 2 j − 1 ⊕ x 2 j , j = 1 , ⋯   , 5 {y}_{j} = {x}_{{2j} - 1} \oplus {x}_{2j}, j = 1,\cdots ,5 yj=x2j1x2j,j=1,,5 ,其中 ⊕ \oplus 表示模 2 加法。该任务表现出模块化,因为每个输出仅依赖于一部分输入。auto_swap 成功识别了 KANs 和 MLPs 的模块,其中 KAN 发现了更简单的模块。对于层次多数投票,输入位为 9 x i ∈ { 0 , 1 } , i = 1 , ⋯   , 9 {x}_{i} \in \{ 0,1\} , i = 1,\cdots ,9 xi{0,1},i=1,,9 ,输出为 y = maj ⁡ ( maj ⁡ ( x 1 , x 2 , x 3 ) , maj ⁡ ( x 4 , x 5 , x 6 ) , maj ⁡ ( x 7 , x 8 , x 9 ) ) y = \operatorname{maj}\left( {\operatorname{maj}\left( {{x}_{1},{x}_{2},{x}_{3}}\right) ,\operatorname{maj}\left( {{x}_{4},{x}_{5},{x}_{6}}\right) ,\operatorname{maj}\left( {{x}_{7},{x}_{8},{x}_{9}}\right) }\right) y=maj(maj(x1,x2,x3),maj(x4,x5,x6),maj(x7,x8,x9)) ,其中 maj 代表多数投票(如果两个或三个输入为 1,则输出 1,否则输出 0)。KAN 在 auto_swap 之前就揭示了模块结构,而在 auto_swap 之后图示变得更加有序。MLP 从第一层权重的模式中显示出一些模块结构,表明变量之间存在相互作用,但无论是否使用 auto_swap,整体模块结构仍然不清晰。

4.2.2 功能模块化

功能模块化涉及神经网络所表示的整体功能。考虑一个 Oracle 网络,其中内部细节如权重和隐藏层激活不可访问(过于复杂以至于无法分析),我们仍然可以通过输入和输出的前向和后向传递收集有关功能模块化的信息。我们定义了三种类型的功能模块化(见图 8 (a)),主要基于 [84]。

可分离性:一个函数 f f f 是可加分离的,如果 f ( x 1 , x 2 , ⋯ x n ) = g ( x 1 , … , x k ) + h ( x k + 1 , … , x n ) . (10) f\left( {{x}_{1},{x}_{2},\cdots {x}_{n}}\right) = g\left( {{x}_{1},\ldots ,{x}_{k}}\right) + h\left( {{x}_{k + 1},\ldots ,{x}_{n}}\right) . \tag{10} f(x1,x2,xn)=g(x1,,xk)+h(xk+1,,xn).(10) 注意 ∂ 2 f ∂ x i ∂ x j = 0 \frac{{\partial }^{2}f}{\partial {x}_{i}\partial {x}_{j}} = 0 xixj2f=0 1 ≤ i ≤ k , k + 1 ≤ j ≤ n 1 \leq i \leq k, k + 1 \leq j \leq n 1ik,k+1jn 时。为了检测可分离性,我们可以计算 Hessian 矩阵 H ≡ ∇ T ∇ f ( H i j = ∂ 2 f ∂ x i ∂ x j ) \mathbf{H} \equiv {\nabla }^{T}\nabla f\left( {{\mathbf{H}}_{ij} = \frac{{\partial }^{2}f}{\partial {x}_{i}\partial {x}_{j}}}\right) HTf(Hij=xixj2f) 并检查块结构。如果 H i j = 0 {\mathbf{H}}_{ij} = 0 Hij=0 对于

图 8:在 KANs 中检测功能模块化。(a) 我们研究三种类型的功能模块化:可分离性(加性或乘性)、一般可分离性和对称性。(b) 递归应用这些测试将一个函数转换为树。在这里,函数可以是符号函数(顶部)、KANs(中间)或 MLPs(底部)。KANs 和 MLPs 在训练结束时都产生正确的树图,但显示出不同的训练动态。

所有 1 ≤ i ≤ k 1 \leq i \leq k 1ik k + 1 ≤ j ≤ n k + 1 \leq j \leq n k+1jn ,那么我们知道 f f f 是可加分离的。对于乘性可分离性,我们可以通过取对数将其转换为可加分离性: f ( x 1 , x 2 , ⋯ x n ) = g ( x 1 , … , x k ) × h ( x k + 1 , … , x n ) (11) f\left( {{x}_{1},{x}_{2},\cdots {x}_{n}}\right) = g\left( {{x}_{1},\ldots ,{x}_{k}}\right) \times h\left( {{x}_{k + 1},\ldots ,{x}_{n}}\right) \tag{11} f(x1,x2,xn)=g(x1,,xk)×h(xk+1,,xn)(11) log ⁡ ∣ f ( x 1 , x 2 , ⋯   , x n ) ∣ = log ⁡ ∣ g ( x 1 , … , x k ) ∣ + log ⁡ ∣ h ( x k + 1 , … , x n ) ∣ \log \left| {f\left( {{x}_{1},{x}_{2},\cdots ,{x}_{n}}\right) }\right| = \log \left| {g\left( {{x}_{1},\ldots ,{x}_{k}}\right) }\right| + \log \left| {h\left( {{x}_{k + 1},\ldots ,{x}_{n}}\right) }\right| logf(x1,x2,,xn)=logg(x1,,xk)+logh(xk+1,,xn) 为了检测乘性可分离性,我们定义 H i j ≡ ∂ 2 log ⁡ ∣ f ∣ ∂ x i ∂ x j {\mathbf{H}}_{ij} \equiv \frac{{\partial }^{2}\log \left| f\right| }{\partial {x}_{i}\partial {x}_{j}} Hijxixj2logf ,并检查块结构。用户可以调用 test_separability 来测试一般可分离性。

广义可分离性:一个函数 f f f 具有广义可分离性,如果 f ( x 1 , x 2 , ⋯ x n ) = F ( g ( x 1 , … , x k ) + h ( x k + 1 , … , x n ) ) . (12) f\left( {{x}_{1},{x}_{2},\cdots {x}_{n}}\right) = F\left( {g\left( {{x}_{1},\ldots ,{x}_{k}}\right) + h\left( {{x}_{k + 1},\ldots ,{x}_{n}}\right) }\right) . \tag{12} f(x1,x2,xn)=F(g(x1,,xk)+h(xk+1,,xn)).(12) 为了检测广义可分离性,我们计算 ∂ f ∂ x i = ∂ F ∂ g ∂ g ∂ x i ( 1 ≤ i ≤ k ) , ∂ f ∂ x j = ∂ F ∂ h ∂ h ∂ x i ( k + 1 ≤ j ≤ n ) (13) \frac{\partial f}{\partial {x}_{i}} = \frac{\partial F}{\partial g}\frac{\partial g}{\partial {x}_{i}}\left( {1 \leq i \leq k}\right) ,\frac{\partial f}{\partial {x}_{j}} = \frac{\partial F}{\partial h}\frac{\partial h}{\partial {x}_{i}}\left( {k + 1 \leq j \leq n}\right) \tag{13} xif=gFxig(1ik),xjf=hFxih(k+1jn)(13) ∂ f / ∂ x i ∂ f / ∂ x j = ∂ F / ∂ g ∂ F / ∂ h ∂ g / ∂ x i ∂ h / ∂ x j = ∂ g / ∂ x i ∂ h / ∂ x j = g x i ( x 1 , x 2 , ⋯ x k ) × 1 h x j ( x k + 1 , ⋯   , x n ) . \frac{\partial f/\partial {x}_{i}}{\partial f/\partial {x}_{j}} = \frac{\partial F/\partial g}{\partial F/\partial h}\frac{\partial g/\partial {x}_{i}}{\partial h/\partial {x}_{j}} = \frac{\partial g/\partial {x}_{i}}{\partial h/\partial {x}_{j}} = {g}_{{x}_{i}}\left( {{x}_{1},{x}_{2},\cdots {x}_{k}}\right) \times \frac{1}{{h}_{{x}_{j}}\left( {{x}_{k + 1},\cdots ,{x}_{n}}\right) }. f/xjf/xi=F/hF/gh/xjg/xi=h/xjg/xi=gxi(x1,x2,xk)×hxj(xk+1,,xn)1. 我们使用了 ∂ F ∂ g = ∂ F ∂ h \frac{\partial F}{\partial g} = \frac{\partial F}{\partial h} gF=hF 。注意 ∂ f / ∂ x i ∂ f / ∂ x j \frac{\partial f/\partial {x}_{i}}{\partial f/\partial {x}_{j}} f/xjf/xi 是乘性可分离的,可以通过上述提出的可分离性测试来检测。用户可以调用 test_general_separability 来检查加性或乘性可分离性。

广义对称性:一个函数在第一个 k k k 变量中具有广义对称性,如果 f ( x 1 , x 2 , ⋯   , x n ) = g ( h ( x 1 , ⋯   , x k ) , x k + 1 , ⋯   , x n ) . (14) f\left( {{x}_{1},{x}_{2},\cdots ,{x}_{n}}\right) = g\left( {h\left( {{x}_{1},\cdots ,{x}_{k}}\right) ,{x}_{k + 1},\cdots ,{x}_{n}}\right) . \tag{14} f(x1,x2,,xn)=g(h(x1,,xk),xk+1,,xn).(14) 我们表示 y = ( x 1 , ⋯   , x k ) \mathbf{y} = \left( {{x}_{1},\cdots ,{x}_{k}}\right) y=(x1,,xk) z = ( x k + 1 , ⋯   , x n ) \mathbf{z} = \left( {{x}_{k + 1},\cdots ,{x}_{n}}\right) z=(xk+1,,xn) 。这个属性被称为广义对称性,因为 f f f 保持相同的值,只要 h h h 保持不变,而与 x 1 , ⋯   , x k {x}_{1},\cdots ,{x}_{k} x1,,xk 的个体值无关。我们计算 f f f 关于 y : ∇ y f = ∂ g ∂ h ∇ y h \mathbf{y} : {\nabla }_{\mathbf{y}}f = \frac{\partial g}{\partial h}{\nabla }_{\mathbf{y}}h y:yf=hgyh 的梯度。由于 ∂ g ∂ h \frac{\partial g}{\partial h} hg 是一个标量函数,它不会改变 ∇ y h {\nabla }_{\mathbf{y}}h yh 的方向。因此, ∇ y f ^ ≡ ∇ y f ∣ ∇ y f ∣ \widehat{{\nabla }_{\mathbf{y}}f} \equiv \frac{{\nabla }_{\mathbf{y}}f}{\left| {\nabla }_{\mathbf{y}}f\right| } yf yfyf 的方向独立于 z \mathbf{z} z ,即, ∇ z ( ∇ y f ^ ) = 0 (15) {\nabla }_{\mathbf{z}}\left( \widehat{{\nabla }_{\mathbf{y}}f}\right) = 0 \tag{15} z(yf )=0(15) 这是对称性的条件。用户可以调用 test_symmetry 方法来检查对称性。

树转换器 三种功能模块化形式构成一个层次结构:对称性是最一般的,广义可分离性是中间的,而可分离性是最具体的。从数学上讲, Separability ⊂ Generalized Separability ⊂ Generalized Symmetry (16) \text{Separability} \subset \text{Generalized Separability} \subset \text{Generalized Symmetry} \tag{16} SeparabilityGeneralized SeparabilityGeneralized Symmetry(16) 为了获得模块结构的最大层次,我们递归地应用广义对称性检测,形成尽可能小的 k = 2 k = 2 k=2 个变量的组,并扩展到所有 k = n k = n k=n 个变量。例如,让我们考虑一个8变量函数 f ( x 1 , ⋯   , x 8 ) = ( ( x 1 2 + x 2 2 ) 2 + ( x 3 2 + x 4 2 ) 2 ) 2 + ( ( x 5 2 + x 6 2 ) 2 + ( x 7 2 + x 8 2 ) 2 ) 2 , (17) f\left( {{x}_{1},\cdots ,{x}_{8}}\right) = {\left( {\left( {x}_{1}^{2} + {x}_{2}^{2}\right) }^{2} + {\left( {x}_{3}^{2} + {x}_{4}^{2}\right) }^{2}\right) }^{2} + {\left( {\left( {x}_{5}^{2} + {x}_{6}^{2}\right) }^{2} + {\left( {x}_{7}^{2} + {x}_{8}^{2}\right) }^{2}\right) }^{2}, \tag{17} f(x1,,x8)=((x12+x22)2+(x32+x42)2)2+((x52+x62)2+(x72+x82)2)2,(17) 它具有四个 k = 2 k = 2 k=2 广义对称性,涉及组 ( x 1 , x 2 ) , ( x 3 , x 4 ) , ( x 5 , x 6 ) \left( {{x}_{1},{x}_{2}}\right) ,\left( {{x}_{3},{x}_{4}}\right) ,\left( {{x}_{5},{x}_{6}}\right) (x1,x2),(x3,x4),(x5,x6) ( x 7 , x 8 ) \left( {{x}_{7},{x}_{8}}\right) (x7,x8) ;两个 k = 2 k = 2 k=2 广义对称性,涉及组 ( x 1 , x 2 , x 3 , x 4 ) \left( {{x}_{1},{x}_{2},{x}_{3},{x}_{4}}\right) (x1,x2,x3,x4) ( x 5 , x 6 , x 7 , x 8 ) \left( {{x}_{5},{x}_{6},{x}_{7},{x}_{8}}\right) (x5,x6,x7,x8) 。因此,每个 k = 4 k = 4 k=4 组包含两个 k = 2 k = 2 k=2 组,展示了一个层次结构。对于每个广义对称性,我们还可以测试该广义对称性是否进一步广义可分离或可分离。用户可以使用方法 plot_tree 来获取一个函数的树图(该函数可以是任何 Python 表达式、神经网络等)。对于神经网络模型,用户只需调用 model.tree()。树图可以有 ‘tree’(默认)或 ‘box’ 的样式。

示例图8 (b) 提供了两个例子。当将精确的符号函数输入到 plot_tree 时,获得了真实的树图。我们特别关注树转换器是否适用于神经网络。在这些简单的情况下,如果经过充分训练,KAN和MLP都可以找到正确的图。图8 (b) (底部) 显示了在 KAN 和 MLP 训练过程中树图的演变。特别有趣的是观察神经网络如何逐渐学习到正确的模块化结构。在第一种情况下 f ( x 1 , x 2 , x 3 , x 4 ) = ( x 1 2 + x 2 2 ) 2 + ( x 3 2 + x 4 2 ) 2 f\left( {{x}_{1},{x}_{2},{x}_{3},{x}_{4}}\right) = {\left( {x}_{1}^{2} + {x}_{2}^{2}\right) }^{2} + {\left( {x}_{3}^{2} + {x}_{4}^{2}\right) }^{2} f(x1,x2,x3,x4)=(x12+x22)2+(x32+x42)2 ,KAN 和 MLP 都逐渐获得了更多的归纳偏见(它们的中间状态是不同的),直到它们达到正确的结构。在第二种情况下 f ( x 1 , x 2 , x 3 ) = sin ⁡ ( x 1 ) / x 2 2 + x 3 2 f\left( {{x}_{1},{x}_{2},{x}_{3}}\right) = \sin \left( {x}_{1}\right) /\sqrt{{x}_{2}^{2} + {x}_{3}^{2}} f(x1,x2,x3)=sin(x1)/x22+x32 ,两个模型最初检测到所有三个变量的乘法可分性,显示出比正确结构更高的对称性。随着训练的进行,两个模型“意识到”:为了更好地拟合数据(损失变得更低),这样的高对称结构不再能够满足,应该放宽到一个不那么严格的结构。另一个观察是 KAN 有一个在 MLP 中找不到的中间结构。我们想提到两个警告:(1)结果可能依赖于种子和/或阈值。(2)所有测试依赖于二阶导数,这可能不够稳健,因为模型仅在零阶信息上进行训练。对抗性构造如 f ϵ ( x ) = f ( x ) + ϵ sin ⁡ ( x ϵ ) {f}_{\epsilon }\left( x\right) = f\left( x\right) + \epsilon \sin \left( \frac{x}{\epsilon }\right) fϵ(x)=f(x)+ϵsin(ϵx) 可能会导致问题,因为尽管 ∣ f ϵ ( x ) − f ( x ) ∣ → 0 \left| {{f}_{\epsilon }\left( x\right) - f\left( x\right) }\right| \rightarrow 0 fϵ(x)f(x)0 如同 ϵ → 0 \epsilon \rightarrow 0 ϵ0 ∣ f ϵ ′ ′ ( x ) − f ′ ′ ( x ) ∣ → ∞ \left| {{f}_{\epsilon }^{\prime \prime }\left( x\right) - {f}^{\prime \prime }\left( x\right) }\right| \rightarrow \infty fϵ′′(x)f′′(x) 如同 ϵ → 0 \epsilon \rightarrow 0 ϵ0 。尽管在实践中这种极端情况不太可能发生,但平滑性对于确保我们方法的成功是必要的。

图9:促进符号回归的三种技巧。技巧A(顶部):检测和利用模块化结构。技巧B(中间):稀疏连接初始化。技巧C(底部):假设检验。

4.3 从 KAN 中识别符号公式

符号公式是最具信息量的,因为一旦已知,它们清晰地揭示了重要特征和模块结构。在 Liu 等人 [57] 的研究中,作者展示了一系列示例,从中可以提取符号公式,并在需要时使用一些先验知识。借助上述新工具(特征重要性、模块结构和符号公式),用户可以利用这些新工具轻松与 KAN 进行交互和协作,使符号回归变得更加简单。我们在下面介绍三个技巧,如图 9 所示。

技巧 A:发现并利用模块结构 我们可以首先训练一个通用网络并探测其模块性。一旦识别出模块结构,我们就用这个模块结构作为归纳偏置初始化一个新模型。例如,考虑函数 f ( q , v , B , m ) = q v B / m f\left( {q, v, B, m}\right) = {qvB}/m f(q,v,B,m)=qvB/m

我们首先初始化一个大型 KAN(假设其表达能力足够强),以合理的准确度拟合数据集。训练后,从训练好的 KAN 中提取树图(参见第 4.2 节),该图显示了乘法可分性。然后,我们可以将模块结构构建到第二个 KAN 中(参见第 3.2 节),对其进行训练,然后将所有 1D 函数符号化以推导出公式。

技巧 B:稀疏初始化 符号公式通常对应于具有稀疏连接的 KAN(见图 5 (b)),因此稀疏初始化 KAN 更好地与符号公式的归纳偏置对齐。否则,密集初始化的 KAN 需要仔细的正则化以促进稀疏性。稀疏初始化可以通过将参数 “sparse_init=True” 传递给 KAN 初始化器来实现。例如,对于函数 f ( q , E , v , B , θ ) = q ( E + v B sin ⁡ θ ) f\left( {q, E, v, B,\theta }\right) = q\left( {E + {vB}\sin \theta }\right) f(q,E,v,B,θ)=q(E+vBsinθ) ,稀疏初始化的 KAN 与最终训练的 KAN 非常相似,仅需在训练中进行小幅调整。相比之下,密集初始化则需要大量训练以去除不必要的边。

Trick C: 假设检验 当面对多个合理的假设时,我们可以尝试所有假设(分支到“平行宇宙”)以测试哪个假设是最准确和/或最简单的。为了促进假设检验,我们建立了一个检查点系统,该系统在进行任何更改(例如训练、剪枝)时自动保存模型版本。例如,考虑函数 f ( m 0 , v , c ) = m 0 / 1 − ( v / c ) 2 f\left( {{m}_{0}, v, c}\right) = {m}_{0}/\sqrt{1 - {\left( v/c\right) }^{2}} f(m0,v,c)=m0/1(v/c)2 。我们从一个随机初始化的 KAN 开始,版本为 0.0。在训练后,它演变为版本 0.1,此时它在 β = v / c \beta = v/c β=v/c γ = 1 / 1 − ( v / c ) 2 \gamma = 1/\sqrt{1 - {\left( v/c\right) }^{2}} γ=1/1(v/c)2 上均被激活。假设仅需要 β \beta β γ \gamma γ 。我们首先将 γ \gamma γ 的边设置为零,并训练模型,获得 6.5 × 10 − 4 {6.5} \times {10}^{-4} 6.5×104 测试 RMSE(版本 0.2)。为了测试替代假设,我们希望恢复到分支点(版本 0.1) - 我们调用 model.rewind(‘0.1’),这将模型回滚到版本 0.1。为了指示调用了回滚,版本 0.1 被重命名为版本 1.1。现在我们将 β \beta β 的边设置为零,训练模型,获得 2.0 × 10 − 6 {2.0} \times {10}^{-6} 2.0×106 测试 RMSE(版本变为 1.2)。比较版本 0.2 和 1.2 表明第二个假设更好,因为在相同复杂度下损失更低(两个假设都有两个非零边)。

5 应用

前面的部分主要集中在回归问题上,以便于教学。在本节中,我们将 KAN 应用于发现物理概念,例如守恒量、拉格朗日量、隐藏对称性和本构定律。这些例子说明了本文提出的工具如何有效地融入现实科学研究,以应对这些复杂任务。

5.1 发现守恒量

图10:使用KAN发现二维谐振子的守恒量。

守恒量是随时间保持不变的物理量。例如,自由下落的球体将其重力势能转化为动能,而总能量(两种能量的总和)保持不变(假设空气阻力可以忽略不计)。守恒量至关重要,因为它们通常与物理系统中的对称性相对应,并且可以通过降低系统的维度来简化计算。传统上,使用纸和铅笔推导守恒量可能非常耗时,并且需要广泛的领域知识。最近,已经探索了机器学习技术来发现守恒量 [55, 53, 54, 58, 32, 89]。

我们遵循Liu等人 [53] 的方法,该方法推导出守恒量必须满足的微分方程,从而将寻找守恒量的问题转化为微分方程求解。他们使用多层感知器(MLP)对守恒量进行参数化。我们基本上遵循他们的程序,但用KAN替代了MLP。具体而言,他们考虑一个由方程 d z d t = f ( z ) \frac{d\mathbf{z}}{dt} = \mathbf{f}\left( \mathbf{z}\right) dtdz=f(z) 支配的状态变量 z ∈ R d \mathbf{z} \in {\mathbb{R}}^{d} zRd 的动态系统。函数 H ( z ) H\left( \mathbf{z}\right) H(z) 成为守恒量的必要和充分条件是对于所有 z \mathbf{z} z ,有 f ( z ) ⋅ ∇ H ( z ) = 0 \mathbf{f}\left( \mathbf{z}\right) \cdot \nabla H\left( \mathbf{z}\right) = 0 f(z)H(z)=0 。例如,在一维谐振子中,相空间由位置和动量 z = ( x , p ) \mathbf{z} = \left( {x, p}\right) z=(x,p) 特征化,演化方程为 d ( x , p ) / d t = ( p , − x ) d\left( {x, p}\right) /{dt} = \left( {p, - x}\right) d(x,p)/dt=(p,x) 。能量 H = 1 2 ( x 2 + p 2 ) H = \frac{1}{2}\left( {{x}^{2} + {p}^{2}}\right) H=21(x2+p2) 是一个守恒量,因为 f ( z ) ⋅ ∇ H ( z ) = ( p , − x ) ⋅ ( x , p ) = 0 \mathbf{f}\left( \mathbf{z}\right) \cdot \nabla H\left( \mathbf{z}\right) = \left( {p, - x}\right) \cdot \left( {x, p}\right) = 0 f(z)H(z)=(p,x)(x,p)=0 。我们使用KAN对 H H H 进行参数化,并用损失函数 ℓ = ∑ i = 1 N ∣ f ( z ( i ) ) ⋅ ∇ ^ H ( z ( i ) ) ∣ 2 \ell = \mathop{\sum }\limits_{{i = 1}}^{N}{\left| \mathbf{f}\left( {\mathbf{z}}^{\left( i\right) }\right) \cdot \widehat{\nabla }H\left( {\mathbf{z}}^{\left( i\right) }\right) \right| }^{2} =i=1N f(z(i)) H(z(i)) 2 进行训练,其中 ∇ ^ \widehat{\nabla } 是归一化梯度, z ( i ) {\mathbf{z}}^{\left( i\right) } z(i) 是从超立方体 [ − 1 , 1 ] d {\left\lbrack -1,1\right\rbrack }^{d} [1,1]d 中均匀抽取的 i th  {i}^{\text{th }} ith  数据点。

我们选择二维谐振子来测试 KANs,其特征为 ( x , y , p x , p y ) \left( {x, y,{p}_{x},{p}_{y}}\right) (x,y,px,py) 。它有三个守恒量:(1) 沿 x x x 方向的能量: H 1 = 1 2 ( x 2 + p x 2 ) {H}_{1} = \frac{1}{2}\left( {{x}^{2} + {p}_{x}^{2}}\right) H1=21(x2+px2) ;(2) 沿 y y y 方向的能量: H 2 = 1 2 ( y 2 + p y 2 ) {H}_{2} = \frac{1}{2}\left( {{y}^{2} + {p}_{y}^{2}}\right) H2=21(y2+py2) ;(3) 角动量 H 3 = x p y − y p x {H}_{3} = x{p}_{y} - y{p}_{x} H3=xpyypx 。我们使用三个不同的随机种子训练 [ 4 , [ 0 , 2 ] , 1 ] \left\lbrack {4,\left\lbrack {0,2}\right\rbrack ,1}\right\rbrack [4,[0,2],1] KANs,如图 10 所示,分别对应于 H 1 , H 2 {H}_{1},{H}_{2} H1,H2 H 3 {H}_{3} H3

5.2 发现拉格朗日量

图 11:使用 KANs 学习单摆(上)和均匀场中的相对论质量(下)的拉格朗日量。

在物理学中,拉格朗日力学是基于静态作用原理的经典力学的一种表述。它使用相空间和一个光滑函数 L \mathcal{L} L 来描述机械系统,该函数称为拉格朗日量。对于许多系统, L = T − V \mathcal{L} = T - V L=TV ,其中 T T T V V V 分别表示系统的动能和势能。相空间通常由 ( q , q ˙ ) \left( {\mathbf{q},\dot{\mathbf{q}}}\right) (q,q˙) 描述,其中 q \mathbf{q} q q ˙ \dot{\mathbf{q}} q˙ 分别表示坐标和速度。运动方程可以通过欧拉-拉格朗日方程从拉格朗日量推导出来: d d t ( ∂ L ∂ q ˙ ) = ∂ L ∂ q \frac{d}{dt}\left( \frac{\partial \mathcal{L}}{\partial \dot{\mathbf{q}}}\right) = \frac{\partial \mathcal{L}}{\partial \mathbf{q}} dtd(q˙L)=qL ,或等效地 q ¨ = ( ∇ q ˙ ∇ q ˙ T L ) − 1 [ ∇ q L − ( ∇ q ∇ q ˙ T q ˙ ) ] (18) \ddot{\mathbf{q}} = {\left( {\nabla }_{\dot{\mathbf{q}}}{\nabla }_{\dot{\mathbf{q}}}^{T}\mathcal{L}\right) }^{-1}\left\lbrack {{\nabla }_{\mathbf{q}}\mathcal{L} - \left( {{\nabla }_{\mathbf{q}}{\nabla }_{\dot{\mathbf{q}}}^{T}\dot{\mathbf{q}}}\right) }\right\rbrack \tag{18} q¨=(q˙q˙TL)1[qL(qq˙Tq˙)](18) 鉴于拉格朗日量的基本作用,一个有趣的问题是我们是否可以从数据中推断出拉格朗日量。根据 [19],我们训练一个拉格朗日神经网络以从 ( q , q ˙ \mathbf{q},\dot{\mathbf{q}} q,q˙ ) 预测 q ¨ \ddot{\mathbf{q}} q¨ 。拉格朗日神经网络使用多层感知器(MLP)来参数化 L ( q , q ˙ ) \mathcal{L}\left( {\mathbf{q},\dot{\mathbf{q}}}\right) L(q,q˙) ,并计算公式 (18) 以预测瞬时加速度 "。然而,拉格朗日神经网络面临两个主要挑战:(1)由于公式 (18) 中的二阶导数和矩阵求逆,拉格朗日神经网络的训练可能不稳定。(2)拉格朗日神经网络缺乏可解释性,因为多层感知器本身并不容易解释。我们使用 KANs 来解决这些问题。

为了应对第一个挑战,我们注意到当 Hessian ( ∇ q ˙ ∇ q ˙ T L ) − 1 {\left( {\nabla }_{\dot{\mathbf{q}}}{\nabla }_{\dot{\mathbf{q}}}^{T}\mathcal{L}\right) }^{-1} (q˙q˙TL)1 的特征值接近零时,矩阵求逆会变得问题重重。为了解决这个问题,我们将 ( ∇ q ˙ ∇ q ˙ T L ) \left( {{\nabla }_{\dot{\mathbf{q}}}{\nabla }_{\dot{\mathbf{q}}}^{T}\mathcal{L}}\right) (q˙q˙TL) 初始化为一个正定矩阵(或在一维中为一个正数)。由于 ( ∇ q ˙ ∇ q ˙ T L ) \left( {{\nabla }_{\dot{\mathbf{q}}}{\nabla }_{\dot{\mathbf{q}}}^{T}\mathcal{L}}\right) (q˙q˙TL) 是经典力学中的质量 m m m ,而动能通常是 T = 1 2 m q ˙ 2 T = \frac{1}{2}m{\dot{\mathbf{q}}}^{2} T=21mq˙2 ,将这一先验知识编码到 KANs 中比编码到 MLPs 中更为直接(使用第 3.3 节中介绍的 kanpiler)。kanpiler 可以将符号公式 T T T 转换为 KAN(如图 11 所示)。我们使用这个转换后的 KAN 进行初始化并继续训练,与随机初始化相比,结果显示出更大的稳定性。训练后,可以对每个边应用符号回归以提取符号公式,从而解决第二个挑战。

我们在图 11 中展示了两个一维示例,一个是单摆,另一个是均匀场中的相对论质量。编译后的 KANs 显示在左侧,边 q ˙ \dot{q} q˙ 显示为二次函数,边 q q q 显示为零函数。

单摆 q ˙ \dot{q} q˙ 部分仍然是一个二次函数 T ( q ˙ ) = 1 2 q ˙ 2 T\left( \dot{q}\right) = \frac{1}{2}{\dot{q}}^{2} T(q˙)=21q˙2 ,而 q q q 部分学习成为余弦函数,如 V ( q ) = 1 − cos ⁡ ( q ) V\left( q\right) = 1 - \cos \left( q\right) V(q)=1cos(q) 所示。在图11的顶部,来自 suggest_symbolic 的结果显示了与样条曲线最佳匹配的前五个函数,考虑了适应度和简洁性。正如预期的那样,余弦函数和二次函数出现在列表的顶部。

均匀场中的相对论质量 经过训练后,动能部分偏离了 T = T = T= 1 2 q ˙ 2 \frac{1}{2}{\dot{q}}^{2} 21q˙2 ,因为对于一个相对论粒子, T r = ( 1 − q ˙ 2 ) − 1 / 2 − 1 {T}_{r} = {\left( 1 - {\dot{q}}^{2}\right) }^{-1/2} - 1 Tr=(1q˙2)1/21 。在图11(底部),符号回归成功找到了 V ( q ) = q V\left( q\right) = q V(q)=q ,但由于其组合性质,未能识别 T r {T}_{r} Tr ,因为我们的符号回归仅搜索简单函数。通过假设第一个函数组合是二次的,我们创建了另一个 [ 1 , 1 , 1 ] \left\lbrack {1,1,1}\right\rbrack [1,1,1] KAN 来拟合 T r {T}_{r} Tr ,并将第一个函数设定为使用 fix_symbolic 的二次函数,仅训练第二个可学习函数。训练后,我们看到真实值 x − 1 / 2 {x}^{-1/2} x1/2 出现在前五个候选者中。然而, x 1 / 2 {x}^{1/2} x1/2 的拟合效果稍好一些,正如更高的 R 平方值所示。这表明符号回归对噪声敏感(由于学习不完善),而先验知识对于正确判断至关重要。例如,知道动能应该在速度接近光速时发散,有助于确认 x − 1 / 2 {x}^{-1/2} x1/2 是正确的项,因为 x 1 / 2 {x}^{1/2} x1/2 并未表现出预期的发散。

5.3 发现隐藏的对称性

图12:使用MLPs和KANs重新发现施瓦茨希尔德黑洞的隐藏对称性。(a) MLP学习到的 Δ t ( r ) {\Delta t}\left( r\right) Δt(r) 是一个全局光滑解;(b) KAN学习到的 Δ t ( r ) {\Delta t}\left( r\right) Δt(r) 是一个域壁解;© KAN在域壁处显示出损失峰值;(d) KAN可以用于将MLP解微调到接近机器精度。

菲利普·安德森著名地指出:“说物理学是对称性的研究,这一说法只是稍微夸大了事实”,强调了对称性发现对加深我们理解和更有效地解决问题的重要性。

然而,对称性有时并不明显,而是隐藏的,只有通过应用某些坐标变换才能揭示。例如,在施瓦茨希尔德发现他命名的黑洞度量后,佩恩莱夫、古尔斯特朗和勒梅特花了17年才揭示其隐藏的平移对称性。他们证明,通过巧妙的坐标变换,空间部分可以变得平移不变,从而加深了我们对黑洞的理解[65]。刘和泰格马克[56]展示了古尔斯特朗-佩恩莱夫变换可以通过在几分钟内训练一个MLP来发现。然而,他们并没有获得极高的精度(即机器精度)来解决该问题。我们尝试使用KAN重新审视这个问题。

假设在时空(t, x, y, z)中存在一个质量为 2 M = 1 {2M} = 1 2M=1 的施瓦茨希尔德黑洞,位于 x = y = z = 0 x = y = z = 0 x=y=z=0 ,半径为 r s = 2 M = 1 {r}_{s} = {2M} = 1 rs=2M=1 。施瓦茨希尔德度量描述了空间和时间如何在其周围扭曲: g μ ν = ( 1 − 2 M r 0 0 0 0 − 1 − 2 M x 2 ( r − 2 M ) r 2 − 2 M x y ( r − 2 M ) r 2 − 2 M x z ( r − 2 M ) r 2 0 − 2 M x y ( r − 2 M ) r 2 − 1 − 2 M y 2 ( r − 2 M ) r 2 − 2 M y z ( r − 2 M ) r 2 0 − 2 M x z ( r − 2 M ) r 2 − 2 M y z ( r − 2 M ) r 2 − 1 − 2 M z 2 ( r − 2 M ) r 2 ) (19) {\mathbf{g}}_{\mu \nu } = \left( \begin{matrix} 1 - \frac{2M}{r} & 0 & 0 & 0 \\ 0 & - 1 - \frac{{2M}{x}^{2}}{\left( {r - {2M}}\right) {r}^{2}} & - \frac{2Mxy}{\left( {r - {2M}}\right) {r}^{2}} & - \frac{2Mxz}{\left( {r - {2M}}\right) {r}^{2}} \\ 0 & - \frac{2Mxy}{\left( {r - {2M}}\right) {r}^{2}} & - 1 - \frac{{2M}{y}^{2}}{\left( {r - {2M}}\right) {r}^{2}} & - \frac{2Myz}{\left( {r - {2M}}\right) {r}^{2}} \\ 0 & - \frac{2Mxz}{\left( {r - {2M}}\right) {r}^{2}} & - \frac{2Myz}{\left( {r - {2M}}\right) {r}^{2}} & - 1 - \frac{{2M}{z}^{2}}{\left( {r - {2M}}\right) {r}^{2}} \end{matrix}\right) \tag{19} gμν= 1r2M00001(r2M)r22Mx2(r2M)r22Mxy(r2M)r22Mxz0(r2M)r22Mxy1(r2M)r22My2(r2M)r22Myz0(r2M)r22Mxz(r2M)r22Myz1(r2M)r22Mz2 (19) 应用古尔斯特朗-佩恩莱夫变换 t ′ = t + 2 M ( 2 u + ln ⁡ ( u − 1 u + 1 ) ) , u ≡ r 2 M , x ′ = x {t}^{\prime } = t + {2M}\left( {{2u} + \ln \left( \frac{u - 1}{u + 1}\right) }\right) , u \equiv \sqrt{\frac{r}{2M}},{x}^{\prime } = x t=t+2M(2u+ln(u+1u1)),u2Mr ,x=x y ′ = y , z ′ = z {y}^{\prime } = y,{z}^{\prime } = z y=y,z=z ,新坐标下的度量变为: g μ ν ′ = ( 1 − 2 M r − 2 M r x r − 2 M r y r − 2 M r z r − 2 M r x r − 1 0 0 − 2 M r y r 0 − 1 0 − 2 M r z r 0 0 − 1 ) (20) {\mathbf{g}}_{\mu \nu }^{\prime } = \left( \begin{matrix} 1 - \frac{2M}{r} & - \sqrt{\frac{2M}{r}}\frac{x}{r} & - \sqrt{\frac{2M}{r}}\frac{y}{r} & - \sqrt{\frac{2M}{r}}\frac{z}{r} \\ - \sqrt{\frac{2M}{r}}\frac{x}{r} & - 1 & 0 & 0 \\ - \sqrt{\frac{2M}{r}}\frac{y}{r} & 0 & - 1 & 0 \\ - \sqrt{\frac{2M}{r}}\frac{z}{r} & 0 & 0 & - 1 \end{matrix}\right) \tag{20} gμν= 1r2Mr2M rxr2M ryr2M rzr2M rx100r2M ry010r2M rz001 (20) 该模型在空间部分表现出平移不变性(右下角 3 × 3 3 \times 3 3×3 块是欧几里得度量)。Liu & Tegmark [56] 使用多层感知器(MLP)学习从 (t, x, y, z) 到 ( t ′ , x ′ , y ′ , z ′ ) \left( {{t}^{\prime },{x}^{\prime },{y}^{\prime },{z}^{\prime }}\right) (t,x,y,z) 的映射。定义雅可比矩阵 J ≡ ∂ ( t ′ , x ′ , y ′ , z ′ ) ∂ ( t , x , y , z ) , g \mathbf{J} \equiv \frac{\partial \left( {{t}^{\prime },{x}^{\prime },{y}^{\prime },{z}^{\prime }}\right) }{\partial \left( {t, x, y, z}\right) },\mathbf{g} J(t,x,y,z)(t,x,y,z),g 转换为 g ′ = J − T g J − 1 {\mathbf{g}}^{\prime } = {\mathbf{J}}^{-T}\mathbf{g}{\mathbf{J}}^{-1} g=JTgJ1 。我们取 g ′ {\mathbf{g}}^{\prime } g 的右下角 3 × 3 3 \times 3 3×3 块,并将其与欧几里得度量的差异计算,以获得均方误差(MSE)损失。通过对 MLP 进行梯度下降来最小化损失。为了简化,他们假设已知 x ′ = x , y ′ = y , z ′ = z {x}^{\prime } = x,{y}^{\prime } = y,{z}^{\prime } = z x=x,y=y,z=z ,并仅使用一个 MLP(1 个输入和 1 个输出)来预测半径 r r r 的时间差 Δ t ( r ) = t ′ − t = 2 M ( 2 u + ln ⁡ ( u − 1 u + 1 ) ) , u ≡ r 2 M {\Delta t}\left( r\right) = {t}^{\prime } - t = {2M}\left( {{2u} + \ln \left( \frac{u - 1}{u + 1}\right) }\right) , u \equiv \sqrt{\frac{r}{2M}} Δt(r)=tt=2M(2u+ln(u+1u1)),u2Mr

MLP 和 KAN 找到了不同的解决方案。我们训练了一个 MLP 和一个 KAN 来最小化这个损失函数,结果如图 12 所示。由于任务具有 1 个输入维度和 1 个输出维度,KAN 实际上简化为一个样条。我们最初预计 KAN 会优于 MLP,因为样条在低维设置中被认为是更优的 [63]。然而,尽管 MLP 可以达到 10 − 8 {10}^{-8} 108 的损失,KAN 却在 10 − 3 {10}^{-3} 103 的损失上停滞不前,尽管进行了网格细化。事实证明,KAN 和 MLP 学习了两种不同的解决方案:虽然 MLP 找到了一个全局平滑的解决方案(图 12 (a)),但 KAN 学习了一个域壁解决方案(图 12 (b))。域壁解决方案有一个奇点,将整个曲线分成两个部分。左侧部分正确学习了 Δ t ( r ) {\Delta t}\left( r\right) Δt(r) ,而右侧部分学习了 − Δ t ( r ) - {\Delta t}\left( r\right) Δt(r) ,这也是一个有效的解决方案,但与左侧部分的差异在于符号相反。在奇点处出现了损失尖峰(图 12 ©)。有人可能会将此视为 KAN 的一个特征,因为域壁解决方案在自然界中普遍存在。然而,如果将其视为缺陷,KAN 仍然可以通过添加正则化(以减少样条振荡)或尝试不同的随机种子(大约 1/3 的随机种子能找到全局平滑解决方案)来获得全局平滑的解决方案。

KANs 可以实现极高的精度。尽管 MLP 找到了全局平滑解并达到了 10 − 8 {10}^{-8} 108 损失,但损失仍远未达到机器精度。我们发现,无论是延长训练时间还是增加 MLP 的规模,都没有显著降低损失。因此,我们转向 KANs,作为一维样条,它们可以通过细化网格(在给定无限数据的情况下)实现任意精度。我们首先使用 MLP 作为教师,生成监督对 (x, y) 来训练 KAN 以拟合监督数据。通过这种方式,KAN 被初始化为全局平滑解。然后,我们通过将网格区间的数量增加到 1000 来迭代地细化 KAN。最终,经过微调的 KANs 达到了 10 − 15 {10}^{-{15}} 1015 的损失,接近机器精度(图 12 (d))。

5.4 学习本构定律

本构定律通过建模材料如何响应外部力或变形来定义材料的行为和性质。本构定律最简单的形式之一是胡克定律 [34],它线性地关联弹性材料的应变和应力。本构定律涵盖了广泛的材料,包括弹性材料 [80, 68]、塑性材料 [64] 和流体 [8]。传统上,这些定律是基于理论和实验研究从第一原理推导出来的 [79, 81, 6, 29]。然而,最近的进展引入了数据驱动的方法,利用机器学习从专门的数据集中发现和完善这些定律 [73, 91, 59, 60]。

图 13:通过与 KANs 互动发现本构定律(压力张量 P P P 和应变张量 F F F 之间的关系)。顶部:预测对角元素 P 11 {P}_{11} P11 ;底部:预测非对角元素 P 12 {P}_{12} P12

我们遵循 NCLaw [59] 中弹性部分的标准符号和实验设置,并将本构定律定义为一个参数化函数 E θ ( F ) → P {\mathcal{E}}_{\theta }\left( \mathbf{F}\right) \rightarrow \mathbf{P} Eθ(F)P ,其中 F \mathbf{F} F 表示变形张量, P \mathbf{P} P 表示第一 Piola-Kirchhoff 应力张量,以及 θ \theta θ 表示本构定律中的参数。

许多各向同性材料在变形较小时具有线性本构定律: P l = μ ( F + F T − 2 I ) + λ ( Tr ⁡ ( F ) − 3 ) I . (21) {\mathbf{P}}_{l} = \mu \left( {\mathbf{F} + {\mathbf{F}}^{T} - 2\mathbf{I}}\right) + \lambda \left( {\operatorname{Tr}\left( \mathbf{F}\right) - 3}\right) \mathbf{I}. \tag{21} Pl=μ(F+FT2I)+λ(Tr(F)3)I.(21) 然而,当变形增大时,非线性效应开始显现。例如,Neo-Hookean 材料具有以下本构定律: P = μ ( F F T − I ) + λ log ⁡ ( det ⁡ ( F ) ) I , (22) \mathbf{P} = \mu \left( {{\mathbf{{FF}}}^{T} - \mathbf{I}}\right) + \lambda \log \left( {\det \left( \mathbf{F}\right) }\right) \mathbf{I}, \tag{22} P=μ(FFTI)+λlog(det(F))I,(22) 其中 μ \mu μ λ \lambda λ 是所谓的 Lamé 参数,由所谓的杨氏模量 Y Y Y 和泊松比 ν \nu ν 确定,如 μ = Y 2 ( 1 + ν ) , λ = Y ν ( 1 + ν ) ( 1 − 2 ν ) \mu = \frac{Y}{2\left( {1 + \nu }\right) },\lambda = \frac{Y\nu }{\left( {1 + \nu }\right) \left( {1 - {2\nu }}\right) } μ=2(1+ν)Y,λ=(1+ν)(12ν)Yν 所示。为简便起见,我们选择 Y = 1 Y = 1 Y=1 ν = 0.2 \nu = {0.2} ν=0.2 ,因此 μ = 5 12 ≈ 0.42 \mu = \frac{5}{12} \approx {0.42} μ=1250.42 λ = 5 18 ≈ 0.28 \lambda = \frac{5}{18} \approx {0.28} λ=1850.28

假设我们正在处理 Neo-Hookean 材料,我们的目标是使用 KANs 从 F \mathbf{F} F 张量预测 P \mathbf{P} P 张量。假设我们不知道它们是 Neo-Hookean 材料,但我们有先验知识,即线性本构定律在小变形下大致有效。由于对称性,证明我们可以从 F \mathbf{F} F 的 9 个矩阵元素准确预测 P 11 {P}_{11} P11 P 12 {P}_{12} P12 就足够了。我们希望将线性本构定律编译成 KANs,这些 KANs 是 P 11 = 2 μ ( F 11 − 1 ) + λ ( F 11 + F 22 + F 33 − 3 ) {P}_{11} = {2\mu }\left( {{F}_{11} - 1}\right) + \lambda \left( {{F}_{11} + {F}_{22} + {F}_{33} - 3}\right) P11=2μ(F111)+λ(F11+F22+F333) P 12 = μ ( F 12 + F 21 ) {P}_{12} = \mu \left( {{F}_{12} + {F}_{21}}\right) P12=μ(F12+F21) 。我们希望从训练好的 KANs 中提取 Neo-Hookean 定律,这些定律是 P 11 = μ ( F 11 2 + F 12 2 + F 13 2 − 1 ) + λ log ⁡ ( det ⁡ ( F ) ) {P}_{11} = \mu \left( {{F}_{11}^{2} + {F}_{12}^{2} + {F}_{13}^{2} - 1}\right) + \lambda \log \left( {\det \left( \mathbf{F}\right) }\right) P11=μ(F112+F122+F1321)+λlog(det(F)) P 12 = μ ( F 11 F 21 + F 12 F 22 + F 13 F 23 ) {P}_{12} = \mu \left( {{F}_{11}{F}_{21} + {F}_{12}{F}_{22} + {F}_{13}{F}_{23}}\right) P12=μ(F11F21+F12F22+F13F23) 。我们通过独立于 U [ δ i j − w , δ i j + w ] ( w = 0.2 ) U\left\lbrack {{\delta }_{ij} - w,{\delta }_{ij} + w}\right\rbrack \left( {w = {0.2}}\right) U[δijw,δij+w](w=0.2) F i j {F}_{ij} Fij 采样生成一个合成数据集,并使用 Neo-Hookean 本构定律计算 P \mathbf{P} P 。我们与 KANs 的交互如图 13 所示。在这两种情况下,我们最终成功找到了真实的符号公式,得益于一些归纳偏见。然而,关键的收获并不是我们能够重新发现确切的符号公式——鉴于先验知识扭曲了这一过程——而是在现实世界的场景中,答案未知,用户可以根据先验知识进行猜测,pykan 包使得测试或结合先验知识变得容易。

预测 P 11 {P}_{11} P11 在第一步中,我们将线性本构定律 P 11 = 2 μ ( F 11 − 1 ) + λ ( F 11 + {P}_{11} = {2\mu }\left( {{F}_{11} - 1}\right) + \lambda \left( {{F}_{11} + }\right. P11=2μ(F111)+λ(F11+ F 22 + F 33 − 3 ) \left. {{F}_{22} + {F}_{33} - 3}\right) F22+F333) 编译为 KAN,使用 kanpiler,导致 10 − 2 {10}^{-2} 102 损失。在第二步中,我们对 KAN 进行扰动,使其变得可训练(通过颜色变化从红色变为紫色来指示;红色表示纯符号部分,而紫色表示符号和样条部分均处于活动状态)。在第三步中,我们训练扰动后的模型直到收敛,得到 6 × 10 − 3 6 \times {10}^{-3} 6×103 损失。在第四步中,假设行列式是一个关键辅助变量,我们使用 expand_width(对于 KAN)和 augment_input(对于数据集)来包含行列式 ∣ F ∣ \left| F\right| F 。在第五步中,我们训练 KAN 直到收敛,得到 2 × 10 − 4 2 \times {10}^{-4} 2×104 损失。在第六步中,我们将 KAN 转换为符号形式,以获得一个符号公式 P 11 = {P}_{11} = P11= 0.42 ( F 11 2 + F 12 2 + F 13 2 − 1 ) + 0.28 log ⁡ ( ∣ F ∣ ) {0.42}\left( {{F}_{11}^{2} + {F}_{12}^{2} + {F}_{13}^{2} - 1}\right) + {0.28}\log \left( \left| F\right| \right) 0.42(F112+F122+F1321)+0.28log(F) ,其损失为 3 × 10 − 11 3 \times {10}^{-{11}} 3×1011

预测 P 12 {P}_{12} P12 我们在有和没有将线性本构定律作为先验知识的情况下进行了实验。有先验知识:在第一步中,我们将线性本构定律编译为 KAN,导致损失为 10 − 2 {10}^{-2} 102 。然后我们执行一系列操作,包括扩展(第二步)、扰动(第三步)、训练(第四步)、剪枝(第五步)以及最后的符号化(第六步)。先验知识的影响显而易见,因为最终的 KAN 仅识别出对线性本构定律的微小修正项。最终的 KAN 被符号化为 P 12 = 0.42 ( F 12 + F 21 ) + 0.44 F 13 F 23 − 0.03 F 21 2 + 0.02 F 12 2 {P}_{12} = {0.42}\left( {{F}_{12} + {F}_{21}}\right) + {0.44}{F}_{13}{F}_{23} - {0.03}{F}_{21}^{2} + {0.02}{F}_{12}^{2} P12=0.42(F12+F21)+0.44F13F230.03F212+0.02F122 ,其损失为 7 × 10 − 3 7 \times {10}^{-3} 7×103 ,仅比线性本构定律稍好。在没有先验知识的情况下:在第一步中,我们随机初始化 KAN 模型。在第二步中,我们使用正则化训练 KAN。在第三步中,我们将 KAN 剪枝为一个更紧凑的模型。在第四步中,我们对 KAN 进行符号化,得到 P 12 = 0.42 ( F 11 F 21 + F 12 F 22 + F 13 F 23 ) {P}_{12} = {0.42}\left( {{F}_{11}{F}_{21} + {F}_{12}{F}_{22} + {F}_{13}{F}_{23}}\right) P12=0.42(F11F21+F12F22+F13F23) ,这与精确公式非常接近,实现了 6 × 10 − 9 6 \times {10}^{-9} 6×109 的损失。比较这两种情况——一种有先验知识,一种没有先验知识——揭示了一个令人惊讶的结果:在这个例子中,先验知识似乎是有害的,这可能是因为线性本构定律可能接近一个(不好的)局部最小值,模型很难逃脱。然而,我们可能不应该将这个结论随机推广到更复杂的任务和更大的网络。对于更复杂的任务,通过梯度下降找到局部最小值可能足够具有挑战性,因此一个近似的初始解是可取的。此外,更大的网络可能足够过度参数化,以消除不好的局部最小值,确保所有局部最小值都是全局的并且相互连接。

图14:KAN 在软件 1.0 和 2.0 之间进行插值。(a) KAN 在可解释性(软件 1.0)和可学习性(软件 2.0)之间取得平衡。(b) KAN 在可解释性尺度平面上的帕累托前沿。我们从 KAN 中获得的解释量取决于问题规模和可解释性方法。

6 相关工作

科尔莫哥洛夫-阿诺德网络(KANs)受到科尔莫哥洛夫-阿诺德表示定理(KART)的启发,最近由刘等人提出 [57]。尽管长期以来 KART 与网络之间的联系被认为无关紧要 [30],但刘等人将原始的两层网络推广到任意深度,并展示了其在科学导向任务中的潜力,考虑到其准确性和可解释性。后续研究探索了 KANs 在多个领域的应用,包括图 [12, 22, 38, 99]、偏微分方程 [87, 78] 和算子学习 [1, 78, 67]、表格数据 [70]、时间序列 [85, 28, 93, 27]、人类活动识别 [49, 50]、神经科学 [96, 33]、量子科学 [40, 46, 4]、计算机视觉 [17, 7, 44, 16, 76, 10]、核学习 [101]、核物理 [48]、电气工程 [69]、生物学 [71]。刘等人使用 B-样条对一维函数进行参数化,其他研究探索了各种激活函数,包括小波 [11, 76]、径向基函数 [47]、傅里叶级数 [92]、有限基 [35, 82]、雅可比基函数 [2]、多项式基函数 [75]、有理函数 [3]。还提出了 KANs 的其他技术,包括正则化 [5]、Kansformer(结合变换器和 KAN) [15]、自适应网格更新 [72]、联邦学习 [98]、卷积 KANs [10]。关于 KANs 是否真的在各个领域超越其他神经网络(尤其是 MLPs)仍存在持续的争论 [7, 16, 42, 77, 97],这表明尽管 KANs 在机器学习任务中显示出潜力,但仍需进一步发展以超越最先进的模型。

物理定律的机器学习 KANs 的一个主要目标是帮助从数据中发现新的物理定律。先前的研究表明,机器学习可以用于学习各种类型的物理定律,包括运动方程 [90, 13, 43, 20]、守恒定律 [55, 53, 54, 58, 32, 89]、对称性 [39, 56, 94]、相变 [88, 14]、拉格朗日和哈密顿 [19, 31] 以及符号回归 [18, 61, 23, 74] 等等。然而,使神经网络可解释通常需要领域特定的知识,这限制了它们的普适性。我们希望 KANs 能够演变为物理发现的通用基础模型。

机制可解释性旨在理解神经网络在基本层面上的运作 [21, 62, 86, 25, 66, 100, 51, 24, 45, 26]。该领域的一些研究专注于设计本质上可解释的模型 [24] 或提出明确促进可解释性的训练方法 [51]。KANs 属于这一类别,因为科尔莫哥洛夫-阿诺德定理将高维函数分解为一组一维函数,这些函数比高维函数显著更易于解释。

7 讨论

KAN 在软件 1.0 和 2.0 之间进行插值。Kolmogorov-Arnold 网络(KANs)与其他神经网络(软件 2.0,Andrej Karpathy 创造的术语)之间的关键区别在于它们更高的可解释性,这使得用户可以进行操作,类似于传统软件(软件 1.0)。然而,KANs 并不完全是传统软件,因为它们具有 (1) 学习能力(好),使其能够从数据中学习新知识,以及 (2) 降低的可解释性(坏),因为随着网络规模的增加,它们变得不那么可解释和可控。图 14 (a) 可视化了软件 1.0、软件 2.0 和 KANs 在可解释性-学习能力平面上的位置,说明了 KANs 如何在这两种范式之间平衡权衡。本文的目标是提出各种工具,使 KANs 更像软件 1.0,同时利用软件 2.0 的学习能力。

效率提升。原始的 pykan 包 [57] 效率较低。我们已经整合了一些技术来提高其效率。

  1. 高效的样条评估。受到高效 KAN [9] 的启发,我们通过避免不必要的输入扩展来优化样条评估。对于具有 L L L 层、每层 N N N 个神经元和网格大小 G G G 的 KAN,内存使用量已从 O ( L N 2 G ) O\left( {L{N}^{2}G}\right) O(LN2G) 降低到 O ( L N G ) O\left( {LNG}\right) O(LNG)

  2. 仅在需要时启用符号分支。KAN 层包含一个样条分支和一个符号分支。符号分支的计算时间远远超过样条分支,因为它无法并行化(需要灾难性的双重循环)。然而,在许多应用中,符号分支是多余的,因此我们可以在可能的情况下跳过它,从而显著减少运行时间,尤其是在网络较大时。

  3. 仅在需要时保存中间激活。为了绘制 KAN 图,必须保存中间激活。最初,激活默认被保存,这导致运行时间变慢和内存使用过多。现在,我们仅在需要时保存中间激活(例如,用于绘图或在训练中应用正则化)。用户可以通过一行代码启用这些效率改进:model.speed( )。

  4. GPU 加速。最初,由于问题的规模较小,所有模型都在 CPU 上运行。我们现在已使模型兼容 GPU 6。例如,使用 Adam 在 CPU 上训练一个 [ 4 , 100 , 100 , 100 , 1 ] \left\lbrack {4,{100},{100},{100},1}\right\rbrack [4,100,100,100,1] 需要整整一天(在实施 1,2,3 之前),但现在在 CPU 上只需 20 秒,在 GPU 上不到一秒。然而,KAN 在效率上仍然落后于 MLP,尤其是在大规模时。社区一直在努力对 KAN 的效率进行基准测试和改进,效率差距已显著缩小 [36]。

由于本文的目标是使 KAN 更像软件 1.0,在面对 1.0(具有交互性和多功能性)与 2.0(高效且特定)之间的权衡时,我们优先考虑交互性和多功能性,而非效率。例如,我们在模型内存储缓存数据(这会消耗额外的内存),因此用户可以简单地调用 model.plot( ) 来生成 KAN 图,而无需手动进行前向传播以收集数据。

可解释性 尽管 KANs 中可学习的一元函数比 MLPs 中的权重矩阵更具可解释性,但可扩展性仍然是一个挑战。随着 KAN 模型的扩展,即使所有样条函数在个体上都是可解释的,管理这些一维函数的组合输出也变得越来越困难。因此,KAN 可能仅在网络规模相对较小时保持可解释性(图 14 (b),粗红线)。需要注意的是,可解释性依赖于内在因素(与模型本身相关)和外在因素(与可解释性方法相关)。高级可解释性方法应该能够在不同层次上处理可解释性。例如,通过使用符号回归、模块发现和特征归因来解释 KANs(图 14 (b),细红线),可解释性与规模的帕累托前沿超出了 KAN 单独能够实现的范围。未来研究的一个有希望的方向是开发更高级的可解释性方法,以进一步推动当前的帕累托前沿。

未来工作 本文介绍了一个将 KANs 与科学知识相结合的框架,主要集中在小规模的与物理相关的示例上。未来的两个有希望的方向包括将该框架应用于更大规模的问题,并将其扩展到物理以外的其他科学学科。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐