def query(self, x, y, predict=False):
    """
    Compute the nearest neighbors of the input queries in memory.

    Arguments:
        x: 2-D matrix of queries (batch_size x dims). It is passed through
           ``self.query_proj`` and L2-normalized row-wise before lookup.
        y: Correct labels for the queries. The loss compares them with
           ``unsqueeze(y, dim=1)`` against the (batch_size x top_k) neighbor
           values, which only lines up row-by-row if ``y`` is 1-D of size
           (batch_size,) -- NOTE(review): confirm against callers.
        predict: If True, the loss is not computed and ``loss`` is None.

    Returns:
        y_hat: ``self.values`` indexed by the top-1 (most similar) key of
            each query -- the nearest-neighbor prediction.
        softmax_score: (batch_size x top_k) matrix, the softmax over the
            temperature-scaled similarities of each query's top-k keys.
        loss: ``MemoryLoss`` applied to the positive / negative neighbor
            scores, or None when ``predict`` is True.

    ``self.update`` is called on every invocation, including when
    ``predict`` is True.
    """
    # Unpacking enforces a 2-D input; only batch_size is used afterwards.
    batch_size, _ = x.size()
    query = F.normalize(self.query_proj(x), dim=1)

    # Similarity of each query against every memory key, then keep the
    # top_k most similar keys per query. (Treated as cosine similarity;
    # presumes self.keys_var rows are L2-normalized -- TODO confirm.)
    scores = torch.matmul(query, torch.t(self.keys_var))
    cosine_similarity, topk_indices_var = torch.topk(scores, self.top_k, dim=1)

    # Normalize across the top_k neighbors of each query. dim=1 is explicit:
    # the implicit-dim form is deprecated and picks dim=1 for 2-D input anyway.
    softmax_score = F.softmax(self.softmax_temperature * cosine_similarity, dim=1)

    # Prediction: the value stored under the single most similar key.
    topk_indices = topk_indices_var.detach().data
    y_hat_indices = topk_indices[:, 0]
    y_hat = self.values[y_hat_indices]

    loss = None
    if not predict:
        # Memory values of all top-k neighbors, shaped (batch_size x top_k).
        # view() (unlike resize_) raises instead of silently truncating data
        # if self.values holds more than one element per memory slot.
        topk_size = topk_indices.size(1)
        topk_values = self.values[topk_indices].view(batch_size, topk_size)

        # 1.0 where a neighbor's stored value equals the query's label.
        correct_mask = torch.eq(topk_values, torch.unsqueeze(y.data, dim=1)).float()
        correct_mask_var = ag.Variable(correct_mask, requires_grad=False)

        # Largest masked similarity among matching (positive) and
        # non-matching (negative) neighbors; masked-out entries become 0.
        pos_score, _ = torch.topk(torch.mul(cosine_similarity, correct_mask_var), 1, dim=1)
        neg_score, _ = torch.topk(torch.mul(cosine_similarity, 1 - correct_mask_var), 1, dim=1)

        # Zero the positive score for rows with no correct value in the top-k.
        has_positive = 1.0 - torch.eq(torch.sum(correct_mask_var, dim=1), 0.0).float()
        pos_score = torch.mul(pos_score, torch.unsqueeze(has_positive, dim=1))

        loss = MemoryLoss(pos_score, neg_score, self.margin)

    # Write this batch into memory (also runs when predict is True).
    self.update(query, y, y_hat, y_hat_indices)
    return y_hat, softmax_score, loss