addmm_与addmm的区别

两者唯一区别:addmm_()函数可以在原对象的基础上进行修改,而addmm()函数则没有该功能。
我在看视频学习敲代码的时候,误将

dist.addmm_(1, -2, inputs, inputs.t())

写成了:

dist.addmm(1, -2, inputs, inputs.t())

导致后期程序怎么也跑不出正常一点的结果,为了找到错误所在,真的是煞费苦心了。希望大家不要犯和我一样的错误,敲代码时擦亮眼睛,遇到不懂的函数,绝不模棱两可!

接下里介绍此行代码的含义:

dist.addmm(a1, a2, inputs, inputs.t())

a1 * dist + a2 * inputs * inputs.t()        (后面是矩阵乘法,行乘列)

非常不好意思,误删了某位同学的评论,在此补上解答的内容:

问题:这行代码用于何种场景?

解答:建议你找几个矩阵举个例子,把它们画出来,就大体上知道是什么原理,主要是用于计算两个矩阵所有行元素之间的距离平方。(通常行元素是指一个样本的特征向量,初始A、B矩阵行数也就是样本数可以不同,但列数,也就是特征向量的长度需要一致)

自己学习时画的图比较复杂,这里我用文字讲解:
1、dist是由矩阵A、B计算得来的,假设A的形状为(a,b),B的形状为(m,n)。
2、A矩阵先将各个元素分别求平方(形状不变),后按行求和(形状:a*1),然后列向扩展m列(形状:a*m),这里扩展的意思就是重复,重复m列。
3、同理,B矩阵先将各个元素分别求平方(形状不变),后按行求和(形状:m*1),然后列向扩展a列(形状:m*a),这里扩展的意思就是重复,重复a列,最后再转置(形状:a*m)。
4、将最后得到的A’矩阵和B’矩阵按元素相加,得到dist矩阵(形状:a*m)。这个时候得到的每一个元素值,比如说位置在dist(c,d)上的元素值,就是原A矩阵c行平方和的值,加上原B矩阵d行平方和的值。
5、这里开始讲到addmm的功能,dist.addmm_(1, -2, A, B.T),意思是 dist-2A*B.T 。B.T是指B的转置
6、A*B.T采用的是矩阵乘法得到矩阵C(形状:a*m),这要求A和B它们的列是要相同的。矩阵C(x,y)处的元素指的是A的第x行元素和B的第y行元素,对应元素之间相乘后再相加(结合矩阵乘法可想象该过程)。
7、计算 dist-2A*B.T,即得到最后的新的dist矩阵(形状:a*m),矩阵dist(e,f)处的元素指的是,(A矩阵e行平方和的值 + B矩阵f行平方和的值)- 2*(A矩阵e行元素和B矩阵f行元素 对应元素之间相乘后再全部相加得到的值),其实就是,(A矩阵e行第i个值的平方 + B矩阵f行第i个值的平方 - 2倍的它们的乘积)对于i从0到b的累加求和,最后得到的也就是A的各个行元素(即每个样本特征)与B的各个行元素(即每个样本特征)之间的距离的平方。
8、最后说一点,之所以不求根号,不直接求距离是因为没必要增加计算复杂度,这个距离的平方也可以起到相同的作用,各个值之间的差距也更大,便于比较。

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我有明珠一颗

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值