运行环境
使用环境:python3.8
平台:Windows10
IDE:PyCharm
书中片段
如果不理解上图各个变量的意思,可以参看下面这幅图
理解
accuracy函数中的
return (y_hat.argmax(dim=1) == y).float().mean().item()
功能如下:
- 比较y_hat(预测概率)中最大值索引与真实情况y异同
- 将得到的异同情况进行提取为数字形式,相同则计为1,不同计为0
- 对异同情况取平均值,得到预估准确率,值得注意的是,在这个例子的情况下,错一个,对一个,所以准确率是(0+1)/2 = 0.5,而不是这个0.5
进一步思考
为表示清楚return (y_hat.argmax(dim=1) == y).float().mean().item()
这一步的作用,将其分解为一些步骤进行debug
附言
由于本文首先记载于博主的OneNote中,在撰写CSDN博客时涂鸦功能无法实现,故而在本文最后加上OneNote中的截图,旨在帮助小伙伴们理解。