大家在学习mnist项目时,一般先从了解线性模型开始,很多教程给出的线性模型如下图:
然而,代码中的线性模型却是这样:
y = tf.matmul(x,W) + b
学过线性代数的都知道,矩阵是不存在交换律的,为什么模型中公式和代码中的公式权值矩阵W和特征矩阵却是相反呢?
要弄明白这个问题,首先要搞清楚线性模型本质:设X=(x1,x2,..xn)为特征向量,W=(w1,w2...wn)为权值向量,b为偏移量,线性模型的任务是找出这样一个W,b,使得对于所有X,经过运算w1*x1+w2*x2...+wn*xn+b后,得出其正确的预测值。比如,在MNIST中,假如X是数字1的手写字体特征数据,采用one-hot编码,经过该运算后应得出[0,1,0,0,0,0,0,0,0,0](当然由于经过softmax后得出的是一个概率值,此处只是举个例子)。所以线性模型关键点在于w1*x1+w2*x2...+wn*xn+b这个运算,而X或者都能达到此目的。
每次读100条MNIST数据集中的图形数据和labels,将其shape打印出来,如下图:
代码中之所以用那种形式是因为MNIST数据集中的X的shape是1*784的(而我们习惯了特征用列向量表示,所以才会疑惑),因此相应的用X这种乘法。由于是one-hot编码,所以的shape为784*10,所以线性运算的结果是1*10,是一个行向量,而MNIST数据集中的labels刚好也是1*10的。所以,为了简便,就使用代码中那种乘法了,否则还要经过两次转置,很麻烦。
学习机器学习算法时经常遇到矩阵乘法,而且特征有时是行向量有时又是列向量,经常被搞迷糊,所以记录一下提醒自己,遇到问题要冷静分析。