pytorch
数据准备和使用
模型的定义
熟练掌握训练过程和结果可视化
训练方法(sgd\adam)和测试方法
- 1、读取数据的指令
torch.utils.data
class torch.utils.data.Dataset
表示dataset的抽象类
所有其他数据集都应该进行子类化,所有子类应该是override __len__和__getitem__,前者提供数据集大小,后者支持整数索引,范围从0到len(self)
- 2、搭建网络模块指令
torch.nn
在pytorch里编写网络结构,几乎所有的层结构都来自于torch.nn函数
- 3、训练指令
(1)优化方法的选择
torch.optim
torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成复杂的方法
如何使用optimizer
为了使用torch.optimizer,你需要构建一个optimizer对象,这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。
(2)学习率调整
torch.optim
torch.optim是一个实现了各种优化算法的库。
如何使用optimizer
为了使用optimizer,你需要构建一个optimizer对象。这个对象能够保持当前的参数状态并且基于计算得到的梯度进行参数的更新。
(3)网络参数初始化的选择
torch.nn.init
torch.nn.init.calculate_gain
- 4、测试指令
设置网络不跟新参数,同时输出相应的测试结果即可
- 5、可视化方法
网络结构的可视化
- 6、模型的保存和加载指令
(1)保存
torch.save[source]
torch.save(obj, f ,pickle_module=<module 'pickle' from '/home/conda/lib/python3.5/pickle.py'> )
保存一个对象到一个硬盘文件上参考Recommended approach for saving a model参数:
obj保存对象
f 类文件对象(返回文件描述符)或一个保存文件名的字符串
pickle_module- 用于pickling元数据和对象的模块
pickle_protocol指定pickle protocal可以覆盖默认参数
保存整个模型的结构和参数
或仅保存模型的参数
(2)加载
torch.load
加载整个模型的结构和参数
或者仅加载模型的参数