pytorch中的.mm

本文介绍了PyTorch中线性层Linear的实现,重点讲解了前向传播过程中x.mm()函数的作用。该函数用于矩阵乘法,注意输入参数的顺序和类型。在使用时,权重w需要作为Module的Parameter,输入x需转换为变量。通过示例代码展示了Linear层如何对输入进行处理并得到输出。
摘要由CSDN通过智能技术生成

在pytorch中的前向计算中经常会碰到x.mm()这个函数,先看一段代码

class Linear(nn.Module):
    def __init__(self,in_feature,out_feature):
        super(Linear,self).__init__() 
        self.w=nn.Parameter(t.randn(in_feature,out_feature)) 
        self.b=nn.Parameter(t.randn(out_feature))
    
    def forward(self,x):
        x=x.mm(self.w)
        return x+self.b.expand_as(x)
    
layer=Linear(4,3)
input=V(t.randn(2,4))
output=layer(input)
output

输出结果:

tensor([[-2.3766, -2.5417,  3.6651],
        [-0.4582, -0.3701,  1.0399]], grad_fn=<AddBackward0>)

在前向计算中x.mm的作用是将参数和变量相乘,但是这里有两个要注意的地方:
1是w和x相乘的时候的顺序是x*w
2是w和x的类型,x必须是tensor,如果w需要训练的话则需要将其加入到Module中的parameter(参数迭代器)中,特别注意一点的是:在pytorch中tensor是不能进行训练的。而且input送入网络计算之前,需要将tensor变成变量。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

骑猪的骑士

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值