【pytorch】torch.nn.GroupNorm的使用
torch.nn.GroupNorm字面意思是分组做Normalization,官方说明在这里。torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)计算公式E[x]是x的均值;Var[x]是标准差;gama和beta是训练参数,如果不想使用,可以通过参数affine
torch.nn.GroupNorm
字面意思是分组做Normalization,官方说明在这里。
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)
计算公式

E[x]是x的均值;
Var[x]是标准差;
gama和beta是训练参数,如果不想使用,可以通过参数affine=False设置。默认为True;
eposilon是输入参数,防止Var为0,默认值为1e-05,可以通过参数eps修改。
输入张量要求
输入的张量至少是2维的,其中第一维度为Channel,后面的维度为特征数据。
使用示例
GroupNorm 是将第一维度的Channels按group分,然后每个group按照上面的计算公式做计算。
比如,
input shape = (4,5)
gn = GroupNorm (2,4)
output = gn(input)
那么output就是将4个channel的数据分为2组,前1-2channel为一组,并按公式计算;后3-4channel为一组,并按公式计算;但是这里输出的shape还是(4,5)
GroupNorm 不会改变输入张量的shape,它只是按照group做normalization
三维,四维以上都一样,比如这里的input shape =(4,1,2,3,4,5),GroupNorm 的作用仅仅针对第一维度的channel。
报错
如果GroupNorm 输入的channel num与输入不一致,则会报错RuntimeError: Expected number of channels in input to be divisible by num_groups
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)