requires_grad
import torch

# setting requires_grad = True tells autograd to track operations on x so the gradient dy/dx can be computed later
x = torch.ones(5, requires_grad = True)
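Printing the tensor confirms that autograd tracking is switched on:
print(x)   # tensor([1., 1., 1., 1., 1.], requires_grad=True)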
gradient
# whenever we apply an operation to x, autograd tracks it and attaches a grad_fn attribute to the result for future backpropagation
y = x + 2
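For example, y now carries a grad_fn recording the addition:
print(y)   # tensor([3., 3., 3., 3., 3.], grad_fn=<AddBackward0>)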
# case 1: z is a scalar
z = y*y*2
z = z.mean()
z.backward() # this step computes the gradient dz/dx; afterwards x has a populated grad attribute
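To check the math (my addition): z = mean(2 * (x + 2)^2), so each entry of the gradient is dz/dx_i = 4 * (x_i + 2) / 5 = 2.4 when x_i = 1:
print(x.grad)   # tensor([2.4000, 2.4000, 2.4000, 2.4000, 2.4000])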
# case 2: z is not a scalar
z = y*y*2
# when z is not a scalar, backward() needs a vector v with the same shape as z (here 5 elements); it computes the vector-Jacobian product
v = torch.tensor([0.1, 1.0, 0.001, 0.01, 0.1], dtype = torch.float32)
z.backward(v)
print(x.grad)
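As a sanity check (my own sketch, zeroing x.grad first so the result from case 1 does not accumulate into it): z_i = 2 * (x_i + 2)^2, so dz_i/dx_i = 4 * (x_i + 2) = 12 here, and the vector-Jacobian product leaves x.grad equal to 12 * v elementwise.
x.grad.zero_()   # clear the gradient left over from case 1
z = y*y*2
z.backward(v)
print(x.grad)    # elementwise 12 * v -> tensor([ 1.2000, 12.0000,  0.0120,  0.1200,  1.2000])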
Prevent tracking gradients
# Three Options:
## Option 1: modify the tensor in place so autograd stops tracking it
x.requires_grad_(False)
## Option 2: create a new tensor that shares the data but is detached from the graph
y = x.detach()
## Option 3: wrap the operation in a no_grad context
with torch.no_grad():
    y = x + 2
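A quick check of what these options do (my own sketch, re-creating x so option 1 above does not interfere): tensors produced via detach() or inside no_grad() carry no grad_fn and are not tracked.
x = torch.ones(5, requires_grad = True)
y = x.detach()
print(y.requires_grad)   # False: y shares x's data but is cut off from the graph
with torch.no_grad():
    y = x + 2
print(y.requires_grad)   # False: no graph is built inside the context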
Attention!
weights = torch.ones(4, requires_grad = True)

for epoch in range(3):
    model_output = (weights * 3).sum()
    model_output.backward()
    # now we can read weights.grad
    print(weights.grad)
    # epoch 0 prints tensor([3., 3., 3., 3.])
    # epoch 1 prints tensor([6., 6., 6., 6.]) because backward() accumulates into weights.grad
    # epoch 2 prints tensor([9., 9., 9., 9.])

# What should the correct behavior be?
# In each epoch we compute the gradient, update the weights, and then compute a fresh gradient;
# previous gradient values must not accumulate. Therefore, before running the next iteration,
# reset the gradient to zero at the end of the loop body:
weights.grad.zero_()   # call this inside the loop, after the weight update
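The same pattern appears when an optimizer handles the update. A minimal sketch of the corrected loop, assuming torch.optim.SGD with an arbitrary learning rate (my own illustrative choice, not from the notes above):
import torch

weights = torch.ones(4, requires_grad = True)
optimizer = torch.optim.SGD([weights], lr = 0.01)

for epoch in range(3):
    model_output = (weights * 3).sum()
    model_output.backward()   # fills weights.grad
    optimizer.step()          # updates weights using weights.grad
    optimizer.zero_grad()     # resets weights.grad so the next epoch starts from zero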