- 加载数据和标签(如手写数字识别中的数据就是一张图片,而一张图片在计算机里面是一个由整数组成的矩阵,它的标签是当前这个图片到底是数字几)
- 设计网络结构(如是共有几层神经网络,每层输入输出的矩阵格式,每层的激活函数是怎样的,当前层到底是卷积层还是循环神经网络层等等)
- 设计损失(误差)函数(以手写数字识别为例,神经网络会输出一个值表明当前这个图片到底是几,但是一开始肯定不准确。所以我们需要计算神经网络输出值与标签之间的误差。而误差计算有很多种我们需要告诉pytorch怎么计算误差。然后pytorch就会自动根据我们设定的误差计算方法去自动调整神经网络的参数)
- 设置用于自动调整神经网络参数的优化器(在上一步我们提到了pytorch会自动帮我们自动调整神经网络的参数,但是具体用哪种优化器去调整参数呢?这需要我们告诉pytorch,因此我们需要用代码设置具体用哪个优化器调整参数)
- 使用数据和标签训练神经网络
这个神经网络已经训练好可以用于解决你的问题了。(你只用输入数据,然后用神经网络输出结果即可)
摘录于 https://zhuanlan.zhihu.com/p/76901725
通用深度学习模型实现流程
最新推荐文章于 2022-04-10 18:24:24 发布