nn.Module 是 PyTorch 中所有神经网络的基类
nn.Module是 PyTorch 框架中构建神经网络的核心基类。它帮助你组织网络结构、管理模型参数、定义前向传播逻辑,并为模型的训练和推理提供便利。
nn.Module 是 PyTorch 中所有神经网络的基类,几乎所有的神经网络模型都需要从 nn.Module 继承。它封装了神经网络的基本结构,包括层的定义、参数管理、前向传播等功能,使得构建、训练和优化深度学习模型变得方便。以下是 nn.Module 的主要功能和作用:
1. 网络层的容器
nn.Module 是一个用于存储和组织网络层的容器。通过继承 nn.Module,你可以将各种网络层(如卷积层、全连接层、批量归一化层等)添加到模型中。它能够自动将这些层的参数注册到模型中,方便后续调用和训练。
例如:
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3) # 定义一个卷积层
self.fc1 = nn.Linear(16*6*6, 10) # 定义一个全连接层
2. 前向传播逻辑(forward 方法)
在继承 nn.Module 后,你需要重写 forward 方法,以定义数据如何在模型中流动。这个方法是必须实现的,表示输入如何通过各个网络层并最终得到输出。
例如:
python
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3)
self.fc1 = nn.Linear(16*6*6, 10)
def forward(self, x):
x = self.conv1(x) # 数据流过卷积层
x = x.view(x.size(0), -1) # 展平
x = self.fc1(x) # 数据流过全连接层
return x # 输出结果
3. 参数管理
nn.Module 能够自动跟踪所有层中的参数,如权重和偏置。这些参数会被存储在 model.parameters() 中,方便后续用于优化器和梯度更新。
例如:
model = MyNet()
for param in model.parameters():
print(param) # 打印模型中的权重和偏置
4. 模型训练与推理模式
nn.Module 提供了 train() 和 eval() 方法,分别用于设置模型为训练模式和推理(评估)模式。这会影响到一些特定层(如 Dropout 和 BatchNorm)的行为。
model.train():启用训练模式,Dropout 等层会保留随机性。model.eval():启用推理模式,Dropout 会关闭,BatchNorm 的均值和方差会固定。
model.train() # 训练模式
model.eval() # 推理模式
5. 模块嵌套
nn.Module 还支持模块的嵌套。你可以在一个 nn.Module 中嵌套其他 nn.Module,从而构建复杂的网络。
例如:
class Block(nn.Module):
def __init__(self):
super(Block, self).__init__()
self.conv = nn.Conv2d(3, 16, 3)
def forward(self, x):
return self.conv(x)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.block1 = Block()
self.block2 = Block()
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
return x
6. 模型参数的保存与加载
通过 nn.Module,你可以方便地保存和加载模型的权重。
- 保存权重:
torch.save(model.state_dict(), "model.pth")- 加载权重:
model.load_state_dict(torch.load("model.pth"))
总结
nn.Module 是 PyTorch 框架中构建神经网络的核心基类。它帮助你组织网络结构、管理模型参数、定义前向传播逻辑,并为模型的训练和推理提供便利。
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)