在pytorch中的前向计算中经常会碰到x.mm()这个函数,先看一段代码
class Linear(nn.Module):
def __init__(self,in_feature,out_feature):
super(Linear,self).__init__()
self.w=nn.Parameter(t.randn(in_feature,out_feature))
self.b=nn.Parameter(t.randn(out_feature))
def forward(self,x):
x=x.mm(self.w)
return x+self.b.expand_as(x)
layer=Linear(4,3)
input=V(t.randn(2,4))
output=layer(input)
output
输出结果:
tensor([[-2.3766, -2.5417, 3.6651],
[-0.4582, -0.3701, 1.0399]], grad_fn=<AddBackward0>)
在前向计算中x.mm的作用是将参数和变量相乘,但是这里有两个要注意的地方:
1是w和x相乘的时候的顺序是x*w
2是w和x的类型,x必须是tensor,如果w需要训练的话则需要将其加入到Module中的parameter(参数迭代器)中,特别注意一点的是:在pytorch中tensor是不能进行训练的。而且input送入网络计算之前,需要将tensor变成变量。