接下来以线性回归模型对pytorch框架进行一个介绍
在pytorch中对于任何一个模型都要先定义一个模型的内,在这个模型类中super下和前向传播中将会记录数据经过哪些层(卷积层,池化层等),这样就可以建立一个基本的模型
接下来指定参数和损失函数
首先是学习的次数和学习率,其次要指定优化器,pytorch中有很多种优化器可以选择,这里选择最基本的SGD就行,其次要指定损失函数,pytorch中也提供了非常多的损失函数,也可以自行选择。
然后定义一下参数
然后是训练参数,这里要注意每次把优化器清零和进行参数更新
然后预测一下(就是直接前向传播一下)
然后是保存和加载模型
可以发现模型是以字典的形式进行保存的