最近在写一个自己的网络,也是基于Resnet50改造,但是没法直接用预训练过的resent更新参数。直接报错预训练里的参数没有我加的层数,我一开始觉得无所谓,直接拿来训练,发现训练的精度都上不去,毫无起色,对于网络训练参数初始化有了更深的思考。后来经过寻找类似的程序和网上找帖子。终于解决了。

第一种方法

def resnet50(pretrained = False,**kwargs):
	model = ResNet(Bottleneck,[3,4,6,3],**kwargs)
	if pretrained:
		#model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))  #开始的样子,直接加载预训练好的参数
		pretrained_state_dict = model_zoo.load_url(model_urls['resnet50'])#旧参数
		now_state_dict = model.state_dict()#新参数
		now_state_dict.update(pretrained_state_dict)#更新
		model.load_state_dict(now_state_dict)#新网络加载预训练的参数
	return model

if __name__=='__main__':
	net = resnet50(True)
	print(net)

按照上边的操作即可。

第二种方法

#加载model
resnet50 = models.resnet50(pretrained=True)
cnn = CNN(Bottleneck, [3, 4, 6, 3])
#读取参数
pretrained_dict = resnet50.state_dict()
model_dict = cnn.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict =  {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
cnn.load_state_dict(model_dict)
# print(resnet50)
print(cnn)

这个方法借鉴于:https://blog.csdn.net/whut_ldz/article/details/78882532这篇博客,博主也是很厉害的

我一开始在网上找的方法就是第二种,经过测试,没问题,参数加载进去,程序结果随着训练正常上升或者下降。

第一种方法我只是测试了正常输出网络,(自己一开始正常输出都报错,现在可以正常输出网络了),应该没问题、可以正常训练了。

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐