对比学习中,SimSiam 对 z 变量使用了停止梯度回传(stop-gradient),而由 z 生成的变量 p 仍保留梯度,验证如下(鬼知道我为什么要验证……):
import torch
import torch.nn as nn
def f(x):
    """Stand-in for a projector stage: shift the input by one."""
    shifted = x + 1.0
    return shifted
def h(z):
    """Stand-in for a predictor stage: shift the input by one."""
    out = z + 1.0
    return out
def d(p, z):
    """SimSiam-style negative similarity with stop-gradient on z.

    z is detached so the loss backpropagates through p only; returns
    the negated per-sample dot product averaged over the batch.
    """
    z_sg = z.detach()  # stop gradient flow through z
    per_sample = (p * z_sg).sum(dim=1)
    return -per_sample.mean()
def save_grad(name, store=None):
    """Build a backward hook that records the incoming gradient.

    Args:
        name: key under which the gradient tensor is stored.
        store: optional dict to write into; when None, falls back to the
            module-level ``grads`` dict (original behavior). The lookup
            happens when the hook fires, matching the original's
            late-bound global access.

    Returns:
        A hook suitable for ``Tensor.register_hook``.
    """
    def hook(grad):
        # Resolve the target dict lazily so the global `grads` need not
        # exist until the hook actually runs.
        target = grads if store is None else store
        target[name] = grad
    return hook
if __name__ == '__main__':
grads = {}
x = torch.tensor([[2.,4,6],
[1,3,5]],requires_grad=True)
x1 = x.sigmoid()
x2 = x.relu()
print(x1,x2)
z1 = f(x1)
z2 = f(x2)
p1 = h(z1)
p2 = h(z2)
p1.register_hook(save_grad('p1'))
z1.register_hook(save_grad('z1'))
l = d(p1,z1) * 0.5 + d(p2,z2) * 0.5
z2.register_hook(save_grad('z2'))
# p2 = p2.detach()
# p2.register_hook(save_grad('p2'))
# print(x.grad,x1.grad,z1