.sum()
函数主要有两个作用,一个是用来求和,一个是用来降维。而在这里是用到了降维的作用。
Pytorch进行梯度的计算,只能对标量进行梯度计算,例如
y
=
x
2
+
x
+
1
y = x^2 +x +1
y=x2+x+1这是一个标量,是能够进行梯度计算的,而例如
y
=
[
x
1
,
x
2
]
2
+
[
x
1
,
x
2
]
+
[
1
,
1
]
y=[x_1, x_2]^2 +[x_1, x_2] +[1, 1]
y=[x1,x2]2+[x1,x2]+[1,1]这是二维的,pytorch并不能进行梯度反向传播计算梯度,所以我们需要使用sum
进行降维处理,变成
y
=
x
1
2
+
x
2
2
+
x
1
+
x
2
+
1
y=x_1^2 + x_2^2 +x_1+x_2 +1
y=x12+x22+x1+x2+1,对于多元函数便能计算偏微分,求梯度了。
例子如下,y_hat和y是多维的,所以先要sum再backward:
X = X.reshape((1, 1, 6, 8))
Y = Y.reshape((1, 1, 6, 7))
lr = 3e-2 # Learning rate
for i in range(10):
Y_hat = conv2d(X)
l = (Y_hat - Y) ** 2
conv2d.zero_grad()
l.sum().backward()
# Update the kernel
conv2d.weight.data[:] -= lr * conv2d.weight.grad
if (i + 1) % 2 == 0:
print(f'epoch {i + 1}, loss {l.sum():.3f}')
print(conv2d.weight.data.reshape((1, 2)))