invariance.py
class ExemplarMemory(Function):
    """Similarity to an exemplar memory, with a momentum update of that memory.

    Forward returns ``inputs.mm(em.t())``, the dot product of every input row
    with every memory row.

    Backward does two things:

    * returns the gradient with respect to ``inputs`` (``grad_outputs.mm(em)``);
    * as a side effect, pulls the memory row of each target towards its input
      and re-normalises that row to unit L2 norm::

          em[y] <- alpha * em[y] + (1 - alpha) * x
          em[y] <- em[y] / ||em[y]||

    This is a new-style autograd Function with static methods. Prefer calling::

        ExemplarMemory.apply(inputs, targets, em, alpha)

    The older call form ``ExemplarMemory(em, alpha=alpha)(inputs, targets)`` is
    still accepted through ``__init__`` and ``__call__`` for backward
    compatibility.
    """

    def __init__(self, em, alpha=0.01):
        # Holder for (em, alpha) so the legacy "instantiate, then call" form
        # keeps working. Function.__init__ is deliberately not called: in current
        # PyTorch it only emits a deprecation warning about instantiating
        # autograd Functions.
        self.em = em
        self.alpha = alpha

    def __call__(self, inputs, targets):
        # Legacy entry point: route through the static autograd machinery.
        return type(self).apply(inputs, targets, self.em, self.alpha)

    @staticmethod
    def forward(ctx, inputs, targets, em, alpha):
        """Return ``inputs.mm(em.t())`` and stash what backward needs.

        Args:
            inputs: 2-D tensor whose row width matches ``em``'s row width.
            targets: index tensor with one memory row per input row.
            em: 2-D memory tensor; it is modified in place during backward.
            alpha: weight kept on the old memory row in the update.
        """
        ctx.save_for_backward(inputs, targets)
        # em is mutated in place during backward, so keep a plain reference
        # instead of saving it (a saved tensor would trip autograd's version check).
        ctx.em = em
        ctx.alpha = alpha
        return inputs.mm(em.t())

    @staticmethod
    def backward(ctx, grad_outputs):
        """Return the gradient w.r.t. ``inputs``, then update the memory in place.

        ``targets``, ``em`` and ``alpha`` receive no gradient.
        """
        inputs, targets = ctx.saved_tensors
        em, alpha = ctx.em, ctx.alpha
        grad_inputs = None
        if ctx.needs_input_grad[0]:
            # Computed before this step's memory update below.
            grad_inputs = grad_outputs.mm(em)
        # no_grad keeps the manual update out of autograd, even under
        # create_graph=True or when em is an nn.Parameter.
        with torch.no_grad():
            # Sequential loop on purpose: a label that appears several times in
            # the batch is updated once per occurrence, in order.
            for x, y in zip(inputs, targets):
                em[y] = alpha * em[y] + (1. - alpha) * x
                # clamp avoids 0/0 -> NaN when the blended row is all zeros.
                em[y] /= em[y].norm().clamp(min=1e-12)
        return grad_inputs, None, None, None
# Invariance learning loss
class InvNet(nn.Module):
def __init__(self, num_features, num_classes, beta=0.05, knn=6, alpha=0.01):
    """Store the hyperparameters and allocate the exemplar memory.

    Args:
        num_features: width of each memory row (columns of ``self.em``).
        num_classes: number of memory rows (rows of ``self.em``).
        beta: temperature factor.
        knn: k used for neighborhood invariance.
        alpha: base memory update rate. The visible ``forward`` multiplies it
            by ``epoch`` before passing it to ``ExemplarMemory``.
    """
    super(InvNet, self).__init__()
    if torch.cuda.is_available():
        self.device = torch.device('cuda')
    else:
        self.device = torch.device('cpu')
    self.num_features, self.num_classes = num_features, num_classes
    self.alpha = alpha  # memory update rate
    self.beta = beta  # temperature factor
    self.knn = knn  # neighbours used for neighborhood invariance
    # Exemplar memory: one zero-initialised row of num_features per class.
    # NOTE(review): registered as an nn.Parameter, yet its rows are rewritten
    # in place by ExemplarMemory's backward rather than by gradient descent;
    # confirm it is not meant to be optimised.
    memory_shape = (num_classes, num_features)
    self.em = nn.Parameter(torch.zeros(memory_shape))
def forward(self, inputs, targets, epoch=None):
alpha = self.alpha * epoch
inputs = ExemplarMemory(self.em, alpha=alpha)(inputs, targets)
inputs