简单粗暴PyTorch之模型创建与nn.module(重要!!!)
模型创建与nn.module
一、模型的创建
模型创建的步骤:
模型构建的两个要素
- 构建子模块
- 拼接子模块
1.1 构建子模块
以之前上传的纸币二分类代码为例,对模型的构建进行讲解。
在70行设置断点,查看如何构建LeNet网络。
step into 进入
进入到LeNet的__init__函数构建子模块,创建了两个卷积层与三个全连接层。运行完最后一个子模块self.fc3就会跳出返回,模型的初始化便完成了。
1.2 拼接子模块
模块在__init__构建好之后,神魔时候拼接实现前向传播呢。
下次使用模型在训练时,直接跳转到,训练时模型的使用
step into进入95行代码,进入到了module.py的call函数,因为LeNet继承于nn.module
进入到547行forword,查看在哪里实现了前向传播,进入到了LeNet函数的forword函数,具体实现前向传播,每一层网络的计算
得到分类结果out,返回out,得到了outputs输出
二、nn.Module
nn.Module在torch.nn下
nn.module通过八个有序字典来管理模型
• parameters: 存储管理nn.Parameter类,如权值、偏置
• modules : 存储管理nn.Module类,如卷积层、池化层
• buffers:存储管理缓冲属性,如BN层中的running_mean
• ***_hooks:存储管理钩子函数
debug进入到LeNet的__init__函数
因为继承与nn.Module,进入14行,到nn.module的__init__函数下
进入到construct函数,这里实现了八个有序字典的初始化
可以看到字典_modules下存贮的是卷积层、全连接层
因为卷积层也是继承于nn.module,也有八个字典,可以看到,parameters下面存储的是权重与偏置
每次进行类属性的赋值,会被module中的setattr函数拦截,判断属于parameters还是modules
nn.Module总结
• 一个module可以包含多个子module
• 一个module相当于一个运算,必须实现forward()函数
• 每个module都有8个字典管理它的属性