- 场景
网络训练好后,在用网络进行推理的过程,想多输出模型中间层的结果。
- 方法
更改中间变量的名称,直接在return中加上该变量。
- 示例
在domain adversarial neural network代码结构中,如果在模型训练时,我们只返回了D,h的结果用于后续损失计算,在推理时我们突然想看编码器输出特征z。 原先的代码:
def forward(self, x: torch.Tensor, _lambda):
x = self.parts_before(x)
x_ = self.grl(x)
x_ = self.domain_classifier(x_)
x = self.class_predictor(x)
return x_, x
推理时的代码:
def forward(self, x: torch.Tensor, _lambda):
x_latent = self.parts_before(x)
x_ = self.grl(x_latent)
x_dc = self.domain_classifier(x_)
x_cp = self.class_predictor(x)
return x_dc, x_cp,x_latent
这样推理时同样可以用之前训练好的模型。本质上只是对网络中间层的重命名,不影响网络结构就行。至于返回个数,在外部调用函数时多加个返回值就行。