经常用到,但是又经常把这个接口忘掉,torch文档里面又不是很直接,所以记录一下。
大概两种用途吧:
- 一些特定的算法,需要半路取出模型某个部分的梯度 (e.g., adversarial training)
- debug用,试试看loss反传是否异常 (某些特定部分的梯度大小是否异常等)
比方说:
torch.autograd.grad(loss,self.encoder.embeddings.speaker_embed.parameters(),retain_graph=True)
查看模型encoder.embeddings.speaker_embed
这个部分的参数梯度,必须要是iterable所以,用了.parameters()
。记得设置retain_graph=True
,这样才能二次反传。