class Linear_fw(nn.Linear): #used in MAML to forward input with fast weight
def __init__(self, in_features, out_features):
super(Linear_fw, self).__init__(in_features, out_features)
self.weight.fast = None #Lazy hack to add fast weight link
self.bias.fast = None
def forward(self, x):
if self.weight.fast is not None and self.bias.fast is not None:
out = F.linear(x, self.weight.fast, self.bias.fast)
#weight.fast (fast weight) is the temporaily adapted weight
else:
out = super(Linear_fw, self).forward(x)
return out
out = super(Linear_fw, self).forward(x)这一句让人感到困惑,但如果理解父类,其实它是有意义的。
它的工作原理应该是这样的:super(Linear_fw, self).forward(x)通过super关键字调用Linear_fw,类的forward()方法,这是引用类继承的父类的一种更优雅的方式。