《MInst数据集实战》
1、自己从头开始写
1.1原理解读:
内容:输入28*28=784的数字图片向量,第一个隐藏层将784用矩阵运算压缩到200,第二层不压缩,仍保持200个参数。最后一层提取出10个参数,作为10分类,到底是哪个数字的问题。
第一步:创建网络,创建三个线性层,每个线性层都有参数w和b,注意w和b的矩阵维度,和申明需要梯度计算。因为最后是10分类,所以最后w和b都有10的维度。
第二部:创建预测值x新=relu(x旧*w的转置+b)。
第三部:定义优化器(迭代计算)。
- 优化目标是三组全连接层的变量w1、w2、w3、b1、b2、b3。
- 并设置学习率,可设为0.001。
- 定义loss函数为交叉熵函数。Nn.crossentropyloss与F. cross_entropy功能一样,都已经包含softmax函数了。
- 定义迭代次数epochs。
- Forward函数为之前定义好的网络输出结果,即输出(原始数据data),输出(网络计算过的logits)
- 这里是调用Nn.crossentropyloss函数,并计算loss(这里把这个函数赋值给了criteon)。输入给loss的是(logit计算的10分类)和(目标的10分类one-hot编码)。