半监督学习之Mean teachers
半监督学习Mean teachers网络整体的架构包括两个部分student model和teacher model:student model的网络参数通过学习,梯度下降获得。teacher model的网络参数通过student model的网络参数的moving average得到。student model的网络参数更新方法:通过损失函数的梯度下降更新参数得到。其中损失函数包括两个部分:第
半监督学习Mean teachers
网络整体的架构包括两个部分student model和teacher model:
-
student model的网络参数通过学习,梯度下降获得。
-
teacher model的网络参数通过student model的网络参数的moving average得到。
student model的网络参数更新方法:
通过损失函数的梯度下降更新参数得到。
其中损失函数包括两个部分:
第一部分是有监督损失函数,保证有标签训练数据拟合;
第二部分是无监督损失函数,主要是保证student model的预测结果和teacher model的预测结果尽量的相似。因为teacher model的参数是student model的网络参数的moving average,所以,对于任何新来的数据,预测结果都不应该有太大的抖动。
如果如果模型是正确的,那么前后两个模型的预测标签应该是接近的,并且变化较小的,那么使模型向使两个模型预测结果接近的方向移动,就是向groudtruth model移动。
teacher model的网络参数的更新方法:
通过student model网络参数的moving average得到
θt′=αθt−1′+(1−α)θt\theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t}θt′=αθt−1′+(1−α)θt
基本流程
假设有一批训练样本X1,X2,其中X1使有标签数据(对应标签是z1),X2使无标签数据。具体的训练过程如下:
-
把这一批样本作为student网络输入,然后分别得到输出的标签:ys1,ys2;
-
构造对于有标签数据X1的损失函数,有标签分类损失函数L1(z1,ys1);
-
把这批数据作为teacher model的输入,得到输出的标签yt1,yt2;
-
构造无监督损失函数L2,论文中采用MSE损失函数:J(x,θ)=Ex,η′,η[∣∣f(x,θ′,η′)−f(x,θ,η)∣∣2]J(x, \theta)=E_{x, \eta ^{\prime}}, \eta \left[ ||f(x, \theta ^{\prime}, \eta ^{\prime})-f(x, \theta , \eta)||^{2}\right]J(x,θ)=Ex,η′,η[∣∣f(x,θ′,η′)−f(x,θ,η)∣∣2]
-
总损失函数L1+L2梯度下降,更新student model的网络参数,通过moving average更新teacher model的网络参数θt′=αθt−1′+(1−α)θt\theta_{t}^{\prime}= \alpha \theta _{t-1}^{\prime}+(1- \alpha)\theta _{t}θt′=αθt−1′+(1−α)θt

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐
所有评论(0)