目录
一、简介
关于 Few-shot Learning(小样本学习),详细 可参考综述【1】,本文为叙述方便,现简要概括如下:
- 所谓 Few-shot Learning 就是小样本学习,直观的解释就是样本比较少的机器学习,【1】中指出它要解决的问题是:
机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习。
- Few-shot Learning 是 Meta-learning(元学习) 在监督学习领域上的一个应用。其训练过程大致是这样的:
Few-shot 的训练集中包含了很多的类别,每个类别中有多个样本。在训练阶段,所使用的训练数据由两部分组成:第一部分为 support set,它是由训练集中随机抽取 C 个类别,每个类别 K 个样本(总共C * K 个数据)构成的;第二部分称为 test set ,它是从刚才抽取的 C 个分类的剩余数据中抽取一批测试样本作为模型的预测对象。这两部分数据合成为一个训练数据(task data),训练的目标就是要求模型能从 C*K 个数据中分辨出这 C 个类别,这样的任务被称为 C-way K-shot 问题。它的“有监督”体现在其 Loss 是构建在 test set 的预测分类与其对应 ground-truth 之间的差别上。
- Few-shot Learning 模型大致可分为三类:Mode Based,Metric Based 和 Optimization Based,如图1,在此,我就不做具体解释了,可参【1】
图1、【1】中所述的三种 Few-shot Learning 模型
本博文是对【2】的解读,【2】采用的模型不同于上述三种,它将 Graph Neural Networks(图神经网络,GNN)应用到 Few-shot Learning 中:它将训练数据中每一幅 Image 映射为 Graph 上的一个 Vertex(顶点),通过训练,得到 Graph 中 Vertex 之间的 Adjacency Matrix,并利用它进行分类推断。
关于 GNN(图神经网络,Graph Neural Networks)可以参考【3】,为了叙述方便,简要介绍如下:
GNN 是对非欧空间(Non-Euclidean Space)中适合用 Graph 表达的数据,进行表达学习(Representation Learning)的神经网络模型。我们一般进行深度学习的数据,比如:Image、Text、Video 等,都是欧氏空间(Euclidean Space)中数据,比如Image,就可以看成是规则网格(regular grid)上的点构成的数据,在其上应用CNN(卷积网络),可获得数据后面隐藏的表达(Latent Representation),而一般的 Graph 结构,无法直接应用CNN,需要特殊的图卷积操作,才能得到其背后隐藏的图结构,如图2:
图2、2D 卷积 与 图卷积
GNN图网络是对图的学习,它不同与数据本身的学习,是对数据集所体现出来的图结构表达
的学习,其概念要比普通的机器学习要间接一些,也要复杂和难懂一些。为了搞明白GNN图网络的思想,我特地找来 GNN 的一个应用实例——【2】,作为 GNN 学习的范例。
二、代码实现过程分析
图神经网络的概念比一般网络要间接,仅通读【2】并不能很好地把握文章的思想精华,结合其代码实现会有助于文章概念的理解。我在GitHub上找到一个基于 Pytorch 的实现【4】,以下将结合这份代码,来研究 GNN 是如何进行 Few-shot Learning 的。
2.1 数据的准备
先看代码,完整的代码请参考【4】,现摘抄部分代码如下:
class self_DataLoader(Dataset):
def __init__(self, root, train=True, dataset='cifar100', seed=1, nway=5):
super(self_DataLoader, self).__init__()
self.seed = seed
self.nway = nway
self.num_labels = 100
self.input_channels = 3
self.size = 32
self.transform = tv.transforms.Compose([
tv.transforms.ToTensor(),
tv.transforms.Normalize([0.5071, 0.4866, 0.4409],
[0.2673, 0.2564, 0.2762])
])
self.full_data_dict, self.few_data_dict = self.load_data(root, train, dataset)
print('full_data_num: %d' % count_data(self.full_data_dict))
print('few_data_num: %d' % count_data(self.few_data_dict))
def load_data(self, root, train, dataset):
if dataset == 'cifar100':
few_selected_label = random.Random(self.seed).sample(range(self.num_labels), self.nway)
print('selected labeled', few_selected_label)
full_data_dict = {
}
few_data_dict = {
}
d = CIFAR100(root, train=train, download=True)
for i, (data, label) in enumerate(d):
data = self.transform(data)
if label in few_selected_label:
data_dict = few_data_dict
else:
data_dict = full_data_dict
if label not in data_dict:
data_dict[label] = [data]
else:
data_dict[label].append(data)
print(i + 1)
else:
raise NotImplementedError
return full_data_dict, few_data_dict
def load_batch_data(self, train=True, batch_size=16, nway=5, num_shots=1):
if train:
data_dict = self.full_data_dict
else:
data_dict = self.few_data_dict
x = []
label_y = [] # fake label: from 0 to (nway - 1)
one_hot_y = [] # one hot for fake label
class_y = [] # real label
xi = []
label_yi = []
one_hot_yi = []
map_label2class = []
### the format of x, label_y, one_hot_y, class_y is
### [tensor, tensor, ..., tensor] len(label_y) = batch size
### the first dimension of tensor = num_shots
for i in range(batch_size):
# sample the class to train
sampled_classes = random.sample(data_dict.keys(), nway)
positive_class = random.randint(0, nway - 1)
label2class = torch.LongTensor(nway)
single_xi = []
single_one_hot_yi = []
single_label_yi = []
single_class_yi = []
for j, _class in enumerate(sampled_classes):
if j == positive_class:
### without loss of generality, we assume the 0th
### sampled class is the target class
sampled_data = random.sample(data_dict[_class], num_shots+1)
x.append(sampled_data[0])
label_y.append(torch.LongTensor([j]))
one_hot = torch.zeros(nway)
one_hot[j] = 1.0
one_hot_y.append(one_hot)
class_y.append(torch.LongTensor([_class]))
shots_data = sampled_data[1:]
else:
shots_data = random.sample(data_dict[_class], num_shots)
single_xi += shots_data
single_label_yi.append(torch.LongTensor([j]).repeat(num_shots))
one_hot = torch.zeros(nway)
one_hot[j] = 1.0
single_one_hot_yi.append(one_hot.repeat(num_shots, 1))
label2class[j] = _class
shuffle_index = torch.randperm(num_shots*nway)
xi.append(torch.stack(single_xi, dim=0)[shuffle_index])
label_yi.append(torch.cat(single_label_yi, dim=0)[shuffle_index])
one_hot_yi.append(torch.cat(single_one_hot_yi, dim=0)[shuffle_index])
map_label2class.append(label2class)
return [torch.stack(x, 0), torch.cat(label_y, 0), torch.stack(one_hot_y, 0), \
torch.cat(class_y, 0), torch.stack(xi, 0), torch.stack(label_yi, 0), \
torch.stack(one_hot_yi, 0), torch.stack(map_label2class, 0)]
def load_tr_batch(self, batch_size=16, nway=5, num_shots=1):
return self.load_batch_data(True, batch_size, nway, num_shots)
def load_te_batch(self, batch_size=16, nway=5, num_shots=1):
return self.load_batch_data(False, batch_size, nway, num_shots)
def get_data_list(self, data_dict):
data_list = []
label_list = []
for i in data_dict.keys():
for data in data_dict[i]:
data_list.append(data)
label_list.append(i)
now_time = time.time()
random.Random(now_time).shuffle(data_list)
random.Random(now_time).shuffle(label_list)
return data_list, label_list
def get_full_data_list(self):
return self.get_data_list(self.full_data_dict)
def get_few_data_list(self):
return self.get_data_list(self.few_data_dict)
这段代码的类图如下:
self_DatasetLoader 继承自 torch.utils.data 的 Dataset,其数据源来自:cifar100,cifar100是 pytorch 集成的数据源之一,可以直接下载下来,它包括 100 个分类,每个分类由 500 幅 3 ∗ 32 ∗ 32 3*32*32