torch eval() and gradient backpropagation

Straight to the conclusion: even after model.eval(), gradients still backpropagate through the model. In other words, eval() only switches layers such as BatchNorm (and Dropout) to inference behaviour, e.g. freezing the running statistics; it does not disable gradient computation, which is what torch.no_grad() or requires_grad=False are for.
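To see the behaviour in isolation, here is a minimal sketch (a hypothetical toy module, separate from the script below): a network containing BatchNorm is put into eval() mode, a loss is backpropagated through it, and its parameters still receive gradients; only torch.no_grad() (or setting requires_grad to False) actually suppresses gradient tracking.

import torch
import torch.nn as nn

# Hypothetical toy model, only to isolate the effect of eval().
demo = nn.Sequential(nn.Conv2d(2, 2, kernel_size=3, padding=1), nn.BatchNorm2d(2))
demo.eval()                         # only switches BatchNorm/Dropout to inference behaviour

x = torch.rand(1, 2, 5, 5)
demo(x).sum().backward()
print(demo[0].weight.grad is None)  # False: gradients are still computed in eval mode

with torch.no_grad():               # this is what actually disables autograd
    y = demo(x)
print(y.requires_grad)              # False: no graph was recorded under no_grad()

The full script below shows the same thing inside a first-order, MAML-style inner/outer loop, where net is kept in eval() mode throughout.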

from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from collections import OrderedDict

# The network that will be frozen with .eval(); note its parameters remain
# part of the autograd graph.
class g(nn.Module):
    def __init__(self):
        super(g, self).__init__()
        self.k1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)
        self.k = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1, bias=False)

    def forward(self, z):
        z = self.k(F.relu(self.k1(z)))
        return z

# The meta-learned network; forward() takes an explicit weights dict so that
# inner-loop updates can be applied functionally without touching the module's
# own parameters.
class g2(nn.Module):
    def __init__(self):
        super(g2, self).__init__()
        self.k1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)
        self.k = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)

    def forward(self, z, weights):
        z = F.conv2d(z, weights["k1.weight"], stride=1, padding=1)
        z = F.relu(z)
        z = F.conv2d(z, weights["k.weight"], stride=1, padding=1)
        return z


c = 2
h = 5
w = 5
num = 255.
gpu_id = 1
# Variable() is a legacy no-op wrapper in current PyTorch; plain tensors behave identically.
z = Variable(torch.rand(1, c, h, w).float() * num).cuda(gpu_id)    # support input
z1 = Variable(torch.rand(1, c, h, w).float() * num).cuda(gpu_id)   # query input
z2 = Variable(torch.ones(1, 1, h, w).float() * num).cuda(gpu_id)   # regression target
net = g().cuda(gpu_id).eval()    # the point of the post: eval() does NOT block gradients
net2 = g2().cuda(gpu_id)

ls = nn.L1Loss()

# Meta-learning hyper-parameters
meta_lr = 0.01
task_num = 1
update_lr = 0.01
update_step = 5
meta_optim = optim.Adam(net2.parameters(), lr=meta_lr)

# print("lossb",lossb)
weights = OrderedDict(
        (name, param ) for (name, param) in net2.named_parameters())

meta_grads = [{name: 0 for (name, _) in net2.named_parameters()}]*(update_step-1)
# print(weights)
# print("ASdasd",weights.values())
for i in range(task_num):

    # First inner-loop update on the support input z.
    q = net2(z, weights)
    r = net(q)    # net is in eval() mode, yet the graph still runs through it
    loss = ls(r, z2)

    grads = torch.autograd.grad(loss, weights.values())
    # theta' = theta - update_lr * grad (first-order: grads are not built with create_graph)
    fast_weights = OrderedDict(
        (name, param - update_lr * grad) for ((name, param), grad) in zip(weights.items(), grads))
    print("*******************weights************************")
    print(weights)
    print("*******************fast_weights************************")
    print(fast_weights)
    print("**************************************************")

    for k in range(update_step - 1):
        # Further inner-loop update: theta' = theta' - update_lr * grad
        q = net2(z, fast_weights)
        r = net(q)
        loss = ls(r, z2)
        grads = torch.autograd.grad(loss, fast_weights.values())
        fast_weights = OrderedDict(
            (name, param - update_lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Meta (query) loss on z1, differentiated w.r.t. the original slow weights
        # and accumulated per step across tasks.
        q = net2(z1, fast_weights)
        r = net(q)
        loss = ls(r, z2)
        grads = torch.autograd.grad(loss, weights.values())
        for ((name, _), grad) in zip(meta_grads[k].items(), grads):
            meta_grads[k][name] = meta_grads[k][name] + grad

# Hooks that replace each net2 parameter's incoming gradient with the accumulated
# meta-gradient from the last inner step during the backward() below.
hooks = []
for (k, v) in net2.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return meta_grads[-1][key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

print("************net*************")
q = net2(z1, fast_weights)
r = net(q)
loss = ls(r, z2)
meta_optim.zero_grad()
loss.backward()
# net was put into eval() mode, yet its parameters receive gradients here:
for k, v in net.named_parameters():
    print(k, v.grad)
meta_optim.step()
# Remove the hooks before the next training phase
for hook in hooks:
    hook.remove()
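Running the script, the final loop prints a gradient tensor for every parameter of net even though net was switched to eval(), while the hooks make loss.backward() write the accumulated meta-gradients into net2's parameters for meta_optim.step() to apply. This matches the conclusion at the top: eval() only changes module behaviour (BatchNorm statistics, Dropout); it does not turn off gradient tracking.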
