Straight to the conclusion: even after model.eval(), gradient propagation still works. In other words, eval() only fixes BatchNorm's statistics (and disables Dropout); it does not turn off autograd.
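A quick way to convince yourself before the longer script: the snippet here is a minimal sketch (the BatchNorm layer and tensor shapes are made up for illustration and are not part of the script that follows) showing that backward() still populates .grad on an eval()-ed module.

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(2)   # stand-in layer; BatchNorm is what eval() mainly affects
bn.eval()                # switches to running statistics, does NOT disable autograd
x = torch.rand(1, 2, 5, 5, requires_grad=True)
bn(x).sum().backward()
print(bn.weight.grad)    # a tensor, not None: the eval()-ed layer still gets gradients
print(x.grad)            # gradients also flow back through it to the input

The longer script below runs the same check inside a MAML-style inner/outer loop: net is kept in eval() mode the whole time, and after loss.backward() its parameters still come out with gradients.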
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from collections import OrderedDict


# "Outer" network: kept in eval() mode; the point of the experiment is to show
# that its parameters still receive gradients when backward() is called.
class g(nn.Module):
    def __init__(self):
        super(g, self).__init__()
        self.k1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)
        self.k = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1, bias=False)

    def forward(self, z):
        z = self.k(F.relu(self.k1(z)))
        return z


# "Inner" network: its forward takes an explicit weight dict so that MAML-style
# fast weights can be plugged in through the functional API.
class g2(nn.Module):
    def __init__(self):
        super(g2, self).__init__()
        self.k1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)
        self.k = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=False)

    def forward(self, z, weights):
        z = F.conv2d(z, weights["k1.weight"], stride=1, padding=1)
        z = F.relu(z)
        z = F.conv2d(z, weights["k.weight"], stride=1, padding=1)
        return z
# Toy data: two inputs (support / query) and a constant target.
c = 2
h = 5
w = 5
num = 255.
gpu_id = 1   # assumes a second GPU is available; set to 0 on a single-GPU machine

z = (torch.rand(1, c, h, w) * num).cuda(gpu_id)
z1 = (torch.rand(1, c, h, w) * num).cuda(gpu_id)
z2 = (torch.ones(1, 1, h, w) * num).cuda(gpu_id)

net = g().cuda(gpu_id).eval()   # the eval()-ed network under test
net2 = g2().cuda(gpu_id)        # the network whose weights are meta-optimized
ls = nn.L1Loss()

meta_lr = 0.01
task_num = 1
update_lr = 0.01
update_step = 5
meta_optim = optim.Adam(net2.parameters(), lr=meta_lr)

# Initial (slow) weights of net2, plus one gradient accumulator per inner step.
# Note: building the accumulators with list multiplication ([...] * (update_step - 1))
# would make every entry the same dict object, so a comprehension is used instead.
weights = OrderedDict((name, param) for (name, param) in net2.named_parameters())
meta_grads = [{name: 0 for (name, _) in net2.named_parameters()}
              for _ in range(update_step - 1)]
for i in range(task_num):
    # First inner-loop (support) step, starting from the slow weights.
    q = net2(z, weights)
    r = net(q)
    loss = ls(r, z2)
    # (create_graph is left at its default, so the inner gradients are detached
    # and the update is effectively first-order.)
    grads = torch.autograd.grad(loss, weights.values())
    fast_weights = OrderedDict(
        (name, param - update_lr * grad)
        for ((name, param), grad) in zip(weights.items(), grads))
    print("*******************weights************************")
    print(weights)
    print("*******************fast_weights************************")
    print(fast_weights)
    print("**************************************************")

    for k in range(update_step - 1):
        # Further inner-loop steps on the fast weights.
        q = net2(z, fast_weights)
        r = net(q)
        loss = ls(r, z2)
        grads = torch.autograd.grad(loss, fast_weights.values())
        # 3. theta_pi = theta_pi - train_lr * grad
        fast_weights = OrderedDict(
            (name, param - update_lr * grad)
            for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Query-set loss, differentiated w.r.t. the original (slow) weights and
        # accumulated per inner step and per task.
        q = net2(z1, fast_weights)
        r = net(q)
        loss = ls(r, z2)
        grads = torch.autograd.grad(loss, weights.values())
        for ((name, _), grad_q) in zip(meta_grads[k].items(), grads):
            meta_grads[k][name] = meta_grads[k][name] + grad_q
# Register hooks on net2's parameters so that, during the dummy backward pass
# below, each parameter's gradient is replaced by the accumulated meta-gradient
# of the last inner step (the usual hook trick for MAML meta-updates).
hooks = []
for (name, v) in net2.named_parameters():
    def get_closure():
        key = name
        def replace_grad(grad):
            return meta_grads[-1][key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

print("************net*************")
# Dummy forward/backward: it triggers the hooks for the meta-update of net2 and,
# at the same time, shows that net (still in eval() mode) gets gradients.
q = net2(z1, fast_weights)
r = net(q)
loss = ls(r, z2)
meta_optim.zero_grad()
loss.backward()
for name, v in net.named_parameters():
    print(name, v.grad)   # non-None: eval() did not block gradient propagation
meta_optim.step()
# Remove the hooks before the next training phase.
for hook in hooks:
    hook.remove()