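When you need to pull out some intermediate outputs of a model and, for example, backpropagate through them, you can use PyTorch's hook mechanism.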
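After you call register_forward_hook(get) on a module, the function get (any name works) is called automatically every time that module's forward pass runs. It receives three arguments: the module itself, its input (a tuple of the positional arguments passed to forward), and its output. The output argument holds the result the hooked module computed from input, so inside the hook you can store it or run further operations on it.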
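A minimal sketch of what the hook receives, using a single nn.Linear layer as a stand-in for any module (the layer sizes and the function names `show` and `double` are illustrative assumptions). It also shows that a forward hook returning a non-None value replaces the module's output, and that the handle returned by register_forward_hook can detach the hook again.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
seen = {}

def show(module, input, output):
    # module: the layer this hook is attached to
    # input:  a tuple with the positional arguments passed to forward
    # output: what the layer computed from that input
    seen["module"] = module
    seen["input"] = input
    seen["output"] = output
    # returning None leaves the output unchanged

handle = layer.register_forward_hook(show)
x = torch.randn(2, 4)
y = layer(x)
handle.remove()  # detach the hook once it is no longer needed

print(seen["module"] is layer)           # True
print(isinstance(seen["input"], tuple))  # True
print(seen["output"] is y)               # True

def double(module, input, output):
    # a non-None return value replaces the layer's output
    return output * 2

h = layer.register_forward_hook(double)
y2 = layer(x)
h.remove()
print(torch.allclose(y2, layer(x) * 2))  # True
```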
Here I collect the selected features in an attribute of the model, self.features. The method below is defined on the model class:
def get_output(self):
    # Register forward hooks on selected layers of self.frontend (call this once)
    self.features = []

    def get(module, input, output):
        # Called automatically each time a hooked layer runs its forward pass
        self.features.append(output)

    # Children of an nn.Sequential are named by their index, as strings
    for name, module in self.frontend.named_children():
        if name in ['6', '8', '11', '15', '18', '22']:
            module.register_forward_hook(get)
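register_forward_hook returns a handle whose remove() method detaches the hook. If get_output() is called twice, every hook is registered again and each feature would then be appended twice per forward pass. One way to guard against that is to keep the handles and remove them before re-registering. A sketch under stated assumptions: the class name `HookedNet`, the tiny two-conv frontend, the `hooks` attribute and the `layer_names` parameter are all hypothetical.

```python
import torch
import torch.nn as nn

class HookedNet(nn.Module):
    # Hypothetical stand-in: only `frontend` mirrors the model in the text
    def __init__(self):
        super().__init__()
        self.frontend = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
            nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
        )
        self.features = []
        self.hooks = []  # handles returned by register_forward_hook

    def get_output(self, layer_names=("1", "3")):
        # Remove hooks from any earlier call so each layer is hooked only once
        for h in self.hooks:
            h.remove()
        self.hooks = []
        self.features = []

        def get(module, input, output):
            self.features.append(output)

        for name, module in self.frontend.named_children():
            if name in layer_names:
                self.hooks.append(module.register_forward_hook(get))

    def forward(self, x):
        self.features = []
        return self.frontend(x)

net = HookedNet()
net.get_output()
net.get_output()  # calling twice does not stack duplicate hooks
net(torch.randn(1, 3, 16, 16))
print(len(net.features))  # 2
```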
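Note that the hook function runs automatically on every forward pass, so forward() must first clear self.features. Otherwise outputs from earlier batches keep accumulating and the GPU memory they occupy is never released.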
def forward(self, x):
    # Drop the features collected during the previous forward pass
    self.features = []
    # frontend: VGG layers
    x = self.frontend(x)
    # backend: dilated convolutions
    x = self.backend(x)
    x = self.output_layer(x)
    return x
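To see why forward() resets self.features, here is a sketch with a hypothetical small model: Linear layers stand in for the VGG frontend and the dilated-convolution backend, and the `reset` flag exists only for the comparison. Without the reset, outputs pile up across calls, and each stored tensor keeps its autograd graph (and the activations it saved) alive, which is where the extra GPU memory goes.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, reset):
        super().__init__()
        self.reset = reset
        self.frontend = nn.Sequential(
            nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU()
        )
        self.backend = nn.Linear(8, 4)
        self.output_layer = nn.Linear(4, 1)
        self.features = []
        for name, module in self.frontend.named_children():
            if name in ("1", "3"):
                # append returns None, so the layer's output is left unchanged
                module.register_forward_hook(lambda m, i, o: self.features.append(o))

    def forward(self, x):
        if self.reset:
            self.features = []
        x = self.frontend(x)
        x = self.backend(x)
        return self.output_layer(x)

counts = {}
x = torch.randn(2, 8)
for reset in (False, True):
    net = Net(reset)
    for _ in range(3):  # three forward passes, like three training iterations
        net(x)
    counts[reset] = len(net.features)
print(counts)  # {False: 6, True: 2}
```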
With that in place, the features are also accessible from outside the model:
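model.get_output()  # register the hooks once, before the training loop
for i, (img, target) in enumerate(train_loader):
    data_time.update(time.time() - end)  # timing bookkeeping from the surrounding training script
    img = img.cuda()  # wrapping in Variable is unnecessary since PyTorch 0.4
    output = model(img)
    for feature in model.features:
        print(feature)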
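Because the hooked outputs are stored without being detached, they stay part of the autograd graph, so a loss computed from them backpropagates into the earlier layers. A minimal sketch, assuming a tiny hypothetical frontend and head; the auxiliary term on the hooked feature and its 0.1 weight are made up purely to show that gradients flow through it. If you only want to inspect or log a feature, store `output.detach()` in the hook instead.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

frontend = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())
head = nn.Linear(8, 1)
features = []

def save(module, input, output):
    # keep the tensor attached to the graph so it can be backpropagated through
    features.append(output)

handle = frontend[1].register_forward_hook(save)

x = torch.randn(4, 8)
target = torch.randn(4, 1)

features.clear()
pred = head(frontend(x))
main_loss = F.mse_loss(pred, target)
# hypothetical auxiliary loss on the intermediate feature
aux_loss = features[0].pow(2).mean()
(main_loss + 0.1 * aux_loss).backward()

print(features[0].requires_grad)            # True
print(frontend[0].weight.grad is not None)  # True
handle.remove()
```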