关于MNIST线性模型矩阵顺序问题

大家在学习mnist项目时,一般先从了解线性模型开始,很多教程给出的线性模型如下图:

然而,代码中的线性模型却是这样:

y = tf.matmul(x,W) + b

学过线性代数的都知道,矩阵是不存在交换律的,为什么模型中公式和代码中的公式权值矩阵W和特征矩阵却是相反呢?

要弄明白这个问题,首先要搞清楚线性模型本质:设X=(x1,x2,..xn)为特征向量,W=(w1,w2...wn)为权值向量,b为偏移量,线性模型的任务是找出这样一个W,b,使得对于所有X,经过运算w1*x1+w2*x2...+wn*xn+b后,得出其正确的预测值。比如,在MNIST中,假如X是数字1的手写字体特征数据,采用one-hot编码,经过该运算后应得出[0,1,0,0,0,0,0,0,0,0](当然由于经过softmax后得出的是一个概率值,此处只是举个例子)。所以线性模型关键点在于w1*x1+w2*x2...+wn*xn+b这个运算,而XW^{T}或者WX^{T}都能达到此目的。

每次读100条MNIST数据集中的图形数据和labels,将其shape打印出来,如下图:

代码中之所以用那种形式是因为MNIST数据集中的X的shape是1*784的(而我们习惯了特征用列向量表示,所以才会疑惑),因此相应的用XW^{T}这种乘法。由于是one-hot编码,所以W^{T}的shape为784*10,所以线性运算的结果是1*10,是一个行向量,而MNIST数据集中的labels刚好也是1*10的。所以,为了简便,就使用代码中那种乘法了,否则还要经过两次转置,很麻烦。
学习机器学习算法时经常遇到矩阵乘法,而且特征有时是行向量有时又是列向量,经常被搞迷糊,所以记录一下提醒自己,遇到问题要冷静分析。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值