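Meta-learning, as the name suggests, aims to give a network the ability to learn how to learn. Like a person, it should be able to recognize and distinguish many new samples after seeing only a few of them, which is an advantage in the many scenarios where samples are hard to collect.
Training and evaluating a meta-learning network works as follows. First, the network is trained on a relatively large dataset such as miniImageNet under the 5-way 5-shot and 5-way 1-shot settings to obtain a model. The trained model is then loaded and evaluated on new classes. In my understanding this is very similar to transfer learning. During training the learning rate is adjusted as the loss changes, and training stops once the learning rate no longer changes.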
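To make "5-way 5-shot" and "5-way 1-shot" concrete, here is a minimal sketch of how one training episode could be sampled. The function name `sample_episode`, its parameters, and the toy label list are illustrative assumptions and are not taken from the code used in this post.

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way=5, n_shot=5, n_query=5, rng=None):
    """Return (support_idxs, query_idxs) for one N-way K-shot episode.

    `labels` holds the class label of every sample in the dataset.
    Hypothetical helper for illustration only.
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with enough samples can take part in an episode.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_shot + n_query]
    classes = rng.sample(eligible, n_way)

    support, query = [], []
    for c in classes:
        picked = rng.sample(by_class[c], n_shot + n_query)
        support.extend(picked[:n_shot])   # used to build the class prototype
        query.extend(picked[n_shot:])     # used to compute loss and accuracy
    return support, query


# Toy dataset: 20 classes with 30 samples each, sampled as 5-way 1-shot.
labels = [c for c in range(20) for _ in range(30)]
support, query = sample_episode(labels, n_way=5, n_shot=1, n_query=5,
                                rng=random.Random(0))
```

Each episode draws fresh classes, so the network is repeatedly asked to solve a small classification task from only a few labelled samples, which is the same situation it faces on new classes at test time.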
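Concretely, the meta-learning networks discussed here consist of a feature encoder and a distance computation on the encoded features.
The prototypical network is one of the simpler meta-learning networks. It mainly consists of an encoder (Protonet), printed below, and a distance metric.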
Protonet(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (4): Flatten()
  )
)
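The printed structure above is four identical convolution blocks followed by a flatten step. A model with that layout could be defined roughly as follows. This is a sketch under assumptions: the helper `conv_block`, the constructor arguments, and the 84 x 84 input size (a common choice for miniImageNet) are not stated in this post, and `nn.Flatten` stands in for whatever flatten layer the original code uses.

```python
import torch
import torch.nn as nn


def conv_block(in_channels, out_channels):
    # Conv -> BatchNorm -> ReLU -> 2x2 max pooling, as in each printed block.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )


class Protonet(nn.Module):
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
            nn.Flatten(),
        )

    def forward(self, x):
        return self.encoder(x)


model = Protonet()
model.eval()
with torch.no_grad():
    # Assumed input: a batch of two 84 x 84 RGB images.
    features = model(torch.randn(2, 3, 84, 84))
# Each pooling halves the spatial size: 84 -> 42 -> 21 -> 10 -> 5,
# so every image becomes a 64 * 5 * 5 = 1600-dimensional feature vector.
```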
The second component is the distance metric (distance computation):
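import torch


def euclidean(x, y):
    '''
    Compute the pairwise squared euclidean distance between two tensors
    (no square root is taken)
    '''
    # x: N x D
    # y: M x D
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    if d != y.size(1):
        raise ValueError("x and y must have the same feature dimension")
    # broadcast both tensors to N x M x D so that every pair (x_i, y_j) is compared
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)  # N x M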
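A small worked example may help to see what this distance function returns. The sketch below re-implements the same broadcasting idea under the hypothetical name `squared_euclidean` so that it runs on its own; the input values are made up for illustration.

```python
import torch


def squared_euclidean(x, y):
    # x: N x D, y: M x D -> N x M matrix of squared distances
    return ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(dim=2)


x = torch.tensor([[0.0, 0.0], [1.0, 1.0]])              # 2 query points
y = torch.tensor([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])  # 3 prototypes
d = squared_euclidean(x, y)
# d[i, j] is the squared distance between x[i] and y[j]:
# tensor([[ 0., 25.,  1.],
#         [ 2., 13.,  1.]])
```

Note that the entry for (0, 0) and (3, 4) is 25 rather than 5, because the result is the squared distance.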
Loss computation of the network
from torch.nn import Module


class PrototypicalLoss(Module):
    '''
    Loss class deriving from Module for the prototypical loss function defined below
    '''
    def __init__(self, n_support, dist_func, reg):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support
        # `cosine` is a pairwise distance function that is not shown in this post
        if dist_func == "cosine":
            self.dist_func = cosine
        elif dist_func == "euclidean":
            self.dist_func = euclidean
        else:
            # fail early instead of storing None, which would only fail later in forward()
            raise ValueError("dist_func must be 'cosine' or 'euclidean'")
        self.reg = reg  # weight of the L2 regularization term

    def forward(self, input, target, weights):
        return prototypical_loss(input, target, self.n_support, weights=weights,
                                 dist_func=self.dist_func, lambda_reg=self.reg)
import torch
from torch.nn import functional as F


def prototypical_loss(input, target, n_support, weights, dist_func=euclidean, lambda_reg=0.05):
    '''
    Inspired by https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
    Compute the barycentres by averaging the features of n_support
    samples for each class in target, computes then the distances from each
    samples' features to each one of the barycentres, computes the
    log_probability for each n_query samples for each one of the current
    classes, of appertaining to a class c, loss and accuracy are then computed
    and returned
    Args:
    - input: the model output for a batch of samples
    - target: ground truth for the above batch of samples
    - n_support: number of samples to keep in account when computing
      barycentres, for each one of the current classes
    - weights: iterable of model parameters used for the L2 regularization term
    - dist_func: pairwise distance function between query features and barycentres
    - lambda_reg: weight of the regularization term
    '''
    target_cpu = target.to('cpu')
    input_cpu = input.to('cpu')

    def supp_idxs(c):
        # FIXME when torch will support where as np
        return target_cpu.eq(c).nonzero()[:n_support].squeeze(1)

    # FIXME when torch.unique will be available on cuda too
    classes = torch.unique(target_cpu)  # non-repeated classes (i.e. types of ground truth)
    n_classes = len(classes)

    # FIXME when torch will support where as np
    # assuming n_query, n_target constants
    n_query = target_cpu.eq(classes[0].item()).sum().item() - n_support

    support_idxs = list(map(supp_idxs, classes))

    # one prototype per class: the mean of its support features (similar to a centroid)
    prototypes = torch.stack([input_cpu[idx_list].mean(0) for idx_list in support_idxs])
    # FIXME when torch will support where as np
    query_idxs = torch.stack(list(map(lambda c: target_cpu.eq(c).nonzero()[n_support:], classes))).view(-1)

    query_samples = input_cpu[query_idxs]
    dists = dist_func(query_samples, prototypes)

    # a smaller distance to a prototype means a higher probability for that class
    log_p_y = F.log_softmax(-dists, dim=1).view(n_classes, n_query, -1)

    target_inds = torch.arange(0, n_classes)
    target_inds = target_inds.view(n_classes, 1, 1)
    target_inds = target_inds.expand(n_classes, n_query, 1).long()

    # loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
    # --------------------------
    reg = 0
    for param in weights:
        param = param.to('cpu')
        reg += torch.sum(0.5*(param**2))  # L2 regularization
        # reg += torch.sum(torch.abs(param))  # L1 regularization
    loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean() + lambda_reg*reg
    # --------------------------
    _, y_hat = log_p_y.max(2)
    # squeeze only the last dimension so the shapes still match when n_query == 1
    acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

    return loss_val, acc_val
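The core of the loss can be seen on toy data without any encoder. The sketch below is a simplified illustration, not the function above: it leaves out the L2 regularization term, uses made-up embeddings clustered around well-separated points, and relies on `F.nll_loss`, which gives the same data term as the `gather`-based expression.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_classes, n_support, n_query, dim = 3, 5, 4, 8

# Toy embeddings (assumption): class c is clustered around (10 * c, ..., 10 * c).
centers = 10.0 * torch.arange(n_classes).float().view(n_classes, 1, 1)
support = centers + torch.randn(n_classes, n_support, dim)
query = centers + torch.randn(n_classes, n_query, dim)

# 1. One prototype per class: the mean of its support embeddings.
prototypes = support.mean(dim=1)                       # n_classes x dim

# 2. Squared euclidean distance from every query sample to every prototype.
flat_query = query.view(-1, dim)                       # (n_classes * n_query) x dim
dists = ((flat_query.unsqueeze(1) - prototypes.unsqueeze(0)) ** 2).sum(dim=2)

# 3. Softmax over negative distances gives the class probabilities.
log_p_y = F.log_softmax(-dists, dim=1)

# 4. Negative log-likelihood of the true class, plus accuracy.
targets = torch.arange(n_classes).repeat_interleave(n_query)
loss = F.nll_loss(log_p_y, targets)
acc = (log_p_y.argmax(dim=1) == targets).float().mean()
```

Because the toy clusters are far apart, every query sample is closest to its own prototype, so the accuracy is 1 and the loss is close to 0. With a real encoder, training pushes the embeddings toward this situation.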
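Postscript
Among meta-learning methods, the matching network uses cosine distance, whereas DeepEMD uses the Earth Mover's Distance (EMD). In theory DeepEMD should perform better. I will keep updating this post when I have time.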
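For comparison with the euclidean version, a pairwise cosine distance could look like the sketch below. The `cosine` function referenced by the loss class is not shown in this post, so the name `cosine_distance` and the choice of "1 minus cosine similarity" are assumptions rather than the original implementation.

```python
import torch
import torch.nn.functional as F


def cosine_distance(x, y, eps=1e-8):
    # x: N x D, y: M x D -> N x M matrix of (1 - cosine similarity)
    x_norm = F.normalize(x, dim=1, eps=eps)
    y_norm = F.normalize(y, dim=1, eps=eps)
    return 1.0 - x_norm @ y_norm.t()


x = torch.tensor([[1.0, 0.0], [0.0, 2.0]])
y = torch.tensor([[3.0, 0.0], [0.0, -1.0]])
d = cosine_distance(x, y)
# Same direction -> 0, orthogonal -> 1, opposite direction -> 2:
# tensor([[0., 1.],
#         [1., 2.]])
```

Unlike the squared euclidean distance, this measure ignores the length of the feature vectors and only compares their directions.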