Replacing the gradient with 0 leaves the variable unchanged — the prints before and after optimizer.step() are identical:
0: xtensor([1.0000, 2.0000, 3.0000, 4.5000], device='cuda:0', requires_grad=True)
0: xtensor([1.0000, 2.0000, 3.0000, 4.5000], device='cuda:0', requires_grad=True)
Replacing the gradient with 1 makes the variable change: the print before optimizer.step() still shows the initial values,
0: xtensor([1.0000, 2.0000, 3.0000, 4.5000], device='cuda:0', requires_grad=True)
while after the step every element of x has dropped by roughly lr = 1.
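The drop of about 1 per element follows from Adam's first update: with an all-ones gradient, the bias-corrected first and second moments are both 1, so the step size is lr · 1/(√1 + ε) ≈ lr = 1. A minimal sketch that checks this directly — it runs on CPU and sets x.grad by hand instead of going through the backward hook, which is an assumption for illustration rather than the original script:

import torch
import torch.optim as optim

# One Adam step with lr=1 and an all-ones gradient should move every
# element of x by roughly -1 (up to Adam's eps).
x = torch.tensor([1.0, 2.0, 3.0, 4.5], requires_grad=True)
optimizer = optim.Adam([x], lr=1)
x.grad = torch.ones_like(x)   # stand-in for the hook replacing the gradient with 1
optimizer.step()
print(x)                      # approximately tensor([0.0000, 1.0000, 2.0000, 3.5000])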
The full code is as follows:
import torch
import torch.optim as optim

torch.autograd.set_detect_anomaly(True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

x = torch.tensor([1.0, 2.0, 3.0, 4.5], dtype=torch.float32,
                 requires_grad=True, device=device)
y1 = x * 2
y2 = x ** 3
print(f'x{x}')
print(f'y1{y1}')
print(f'y2{y2}')

optimizer = optim.Adam([x], lr=1)

def replace_gradients_zero(grad):
    # The condition is all True, so every gradient element is replaced with 0.
    return torch.where(torch.ones_like(grad, dtype=torch.bool),
                       torch.zeros_like(grad), grad)

def replace_gradients_one(grad):
    # The condition is all True, so every gradient element is replaced with 1.
    return torch.where(torch.ones_like(grad, dtype=torch.bool),
                       torch.ones_like(grad), grad)

# Register the hook once, outside the loop; it fires each time backward()
# computes x's gradient. Swap in replace_gradients_one to see x change.
for param in [x]:
    param.register_hook(replace_gradients_zero)

for i in range(1):
    print(f'{i}: x{x}')
    optimizer.zero_grad()
    loss = (y1 - y2).sum()
    loss.backward()
    optimizer.step()
    print(f'{i}: x{x}')
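Because the condition passed to torch.where is all True, the hooks above replace every gradient element. The same pattern also supports partial masking, which is the more common use: freeze some elements of a tensor while letting the rest update. A minimal sketch under that assumption (the freeze_mask below, which keeps the first two elements fixed, is a made-up example and not part of the original script):

import torch
import torch.optim as optim

# Freeze only selected elements of x by zeroing their gradients in the hook.
x = torch.tensor([1.0, 2.0, 3.0, 4.5], requires_grad=True)
freeze_mask = torch.tensor([True, True, False, False])

def mask_gradients(grad):
    # Where freeze_mask is True, replace the gradient with 0; elsewhere keep it.
    return torch.where(freeze_mask, torch.zeros_like(grad), grad)

x.register_hook(mask_gradients)
optimizer = optim.Adam([x], lr=1)

loss = (x * 2 - x ** 3).sum()
loss.backward()
optimizer.step()
print(x)   # the first two elements stay at 1.0 and 2.0; the last two move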