1. retain graph
通常计算完backward之后会将计算图free掉以节省空间,但是如果需要做两次backwark,且有共用的graph时,需要先retain graph然后再进行backward。
2. detach
假设一个变量在graph中位置1时参与反向传播,需要更新权重;在位置2时不需要反向传播,只是当做一个常量来参与运算。此时,在位置2计算之前需要进行detach操作。
class ContentLoss(nn.Module):
def __init__(self, target, weight):
super(ContentLoss, self).__init__()
self.target = target.detach() * weight
# 因为这里只是需要target这个数值,这个数值是一种状态,不计入计算树中。
# 这里单纯将其当做常量对待,因此用了detach则在backward中计算梯度时不对target之前所在的计
算图存在任何影响。
self.weight = weight
self.criterion = nn.MSELoss()