《Few-Shot Learning with Graph Neural Networks》——少样本学习与图神经网络

一、简介

关于 Few-shot Learning(小样本学习),详细 可参考综述【1】,本文为叙述方便,现简要概括如下:

  • 所谓 Few-shot Learning 就是小样本学习,直观的解释就是样本比较少的机器学习,【1】中指出它要解决的问题是:

机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习。

  • Few-shot Learning 是 Meta-learning(元学习) 在监督学习领域上的一个应用。其训练过程大致是这样的:

Few-shot 的训练集中包含了很多的类别,每个类别中有多个样本。在训练阶段,所使用的训练数据由两部分组成:第一部分为 support set,它是由训练集中随机抽取 C 个类别,每个类别 K 个样本(总共C * K 个数据)构成的;第二部分称为 test set ,它是从刚才抽取的 C 个分类的剩余数据中抽取一批测试样本作为模型的预测对象。这两部分数据合成为一个训练数据(task data),训练的目标就是要求模型能从 C*K 个数据中分辨出这 C 个类别,这样的任务被称为 C-way K-shot 问题。它的“有监督”体现在其 Loss 是构建在 test set 的预测分类与其对应 ground-truth 之间的差别上。

  • Few-shot Learning 模型大致可分为三类:Mode Based,Metric Based 和 Optimization Based,如图1,在此,我就不做具体解释了,可参【1】
    图1
    图1、【1】中所述的三种 Few-shot Learning 模型

本博文是对【2】的解读,【2】采用的模型不同于上述三种,它将 Graph Neural Networks(图神经网络,GNN)应用到 Few-shot Learning 中:它将训练数据中每一幅 Image 映射为 Graph 上的一个 Vertex(顶点),通过训练,得到 Graph 中 Vertex 之间的 Adjacency Matrix,并利用它进行分类推断。
关于 GNN(图神经网络,Graph Neural Networks)可以参考【3】,为了叙述方便,简要介绍如下:

GNN 是对非欧空间(Non-Euclidean Space)中适合用 Graph 表达的数据,进行表达学习(Representation Learning)的神经网络模型。我们一般进行深度学习的数据,比如:Image、Text、Video 等,都是欧氏空间(Euclidean Space)中数据,比如Image,就可以看成是规则网格(regular grid)上的点构成的数据,在其上应用CNN(卷积网络),可获得数据后面隐藏的表达(Latent Representation),而一般的 Graph 结构,无法直接应用CNN,需要特殊的图卷积操作,才能得到其背后隐藏的图结构,如图2:
图2
图2、2D 卷积 与 图卷积

GNN图网络是对图的学习,它不同与数据本身的学习,是对数据集所体现出来的图结构表达的学习,其概念要比普通的机器学习要间接一些,也要复杂和难懂一些。为了搞明白GNN图网络的思想,我特地找来 GNN 的一个应用实例——【2】,作为 GNN 学习的范例。

二、代码实现过程分析

图神经网络的概念比一般网络要间接,仅通读【2】并不能很好地把握文章的思想精华,结合其代码实现会有助于文章概念的理解。我在GitHub上找到一个基于 Pytorch 的实现【4】,以下将结合这份代码,来研究 GNN 是如何进行 Few-shot Learning 的。

2.1 数据的准备

先看代码,完整的代码请参考【4】,现摘抄部分代码如下:

class self_DataLoader(Dataset):
    def __init__(self, root, train=True, dataset='cifar100', seed=1, nway=5):
        super(self_DataLoader, self).__init__()

        self.seed = seed
        self.nway = nway
        self.num_labels = 100
        self.input_channels = 3
        self.size = 32

        self.transform = tv.transforms.Compose([
            tv.transforms.ToTensor(),
            tv.transforms.Normalize([0.5071, 0.4866, 0.4409], 
                [0.2673, 0.2564, 0.2762])
            ])

        self.full_data_dict, self.few_data_dict = self.load_data(root, train, dataset)

        print('full_data_num: %d' % count_data(self.full_data_dict))
        print('few_data_num: %d' % count_data(self.few_data_dict))

    def load_data(self, root, train, dataset):
        if dataset == 'cifar100':
            few_selected_label = random.Random(self.seed).sample(range(self.num_labels), self.nway)
            print('selected labeled', few_selected_label)

            full_data_dict = {
   }
            few_data_dict = {
   }

            d = CIFAR100(root, train=train, download=True)

            for i, (data, label) in enumerate(d):

                data = self.transform(data)

                if label in few_selected_label:
                    data_dict = few_data_dict
                else:
                    data_dict = full_data_dict

                if label not in data_dict:
                    data_dict[label] = [data]
                else:
                    data_dict[label].append(data)
            print(i + 1)
        else:
            raise NotImplementedError

        return full_data_dict, few_data_dict

    def load_batch_data(self, train=True, batch_size=16, nway=5, num_shots=1):
        if train:
            data_dict = self.full_data_dict
        else:
            data_dict = self.few_data_dict

        x = []
        label_y = [] # fake label: from 0 to (nway - 1)
        one_hot_y = [] # one hot for fake label
        class_y = [] # real label

        xi = []
        label_yi = []
        one_hot_yi = []
        

        map_label2class = []

        ### the format of x, label_y, one_hot_y, class_y is 
        ### [tensor, tensor, ..., tensor] len(label_y) = batch size
        ### the first dimension of tensor = num_shots

        for i in range(batch_size):

            # sample the class to train
            sampled_classes = random.sample(data_dict.keys(), nway)

            positive_class = random.randint(0, nway - 1)

            label2class = torch.LongTensor(nway)

            single_xi = []
            single_one_hot_yi = []
            single_label_yi = []
            single_class_yi = []


            for j, _class in enumerate(sampled_classes):
                if j == positive_class:
                    ### without loss of generality, we assume the 0th 
                    ### sampled  class is the target class
                    sampled_data = random.sample(data_dict[_class], num_shots+1)

                    x.append(sampled_data[0])
                    label_y.append(torch.LongTensor([j]))

                    one_hot = torch.zeros(nway)
                    one_hot[j] = 1.0
                    one_hot_y.append(one_hot)

                    class_y.append(torch.LongTensor([_class]))

                    shots_data = sampled_data[1:]
                else:
                    shots_data = random.sample(data_dict[_class], num_shots)

                single_xi += shots_data
                single_label_yi.append(torch.LongTensor([j]).repeat(num_shots))
                one_hot = torch.zeros(nway)
                one_hot[j] = 1.0
                single_one_hot_yi.append(one_hot.repeat(num_shots, 1))

                label2class[j] = _class

            shuffle_index = torch.randperm(num_shots*nway)
            xi.append(torch.stack(single_xi, dim=0)[shuffle_index])
            label_yi.append(torch.cat(single_label_yi, dim=0)[shuffle_index])
            one_hot_yi.append(torch.cat(single_one_hot_yi, dim=0)[shuffle_index])

            map_label2class.append(label2class)

        return [torch.stack(x, 0), torch.cat(label_y, 0), torch.stack(one_hot_y, 0), \
            torch.cat(class_y, 0), torch.stack(xi, 0), torch.stack(label_yi, 0), \
            torch.stack(one_hot_yi, 0), torch.stack(map_label2class, 0)]


    def load_tr_batch(self, batch_size=16, nway=5, num_shots=1):
        return self.load_batch_data(True, batch_size, nway, num_shots)

    def load_te_batch(self, batch_size=16, nway=5, num_shots=1):
        return self.load_batch_data(False, batch_size, nway, num_shots)

    def get_data_list(self, data_dict):
        data_list = []
        label_list = []
        for i in data_dict.keys():
            for data in data_dict[i]:
                data_list.append(data)
                label_list.append(i)

        now_time = time.time()

        random.Random(now_time).shuffle(data_list)
        random.Random(now_time).shuffle(label_list)

        return data_list, label_list

    def get_full_data_list(self):
        return self.get_data_list(self.full_data_dict)

    def get_few_data_list(self):
        return self.get_data_list(self.few_data_dict)

这段代码的类图如下:

self_DatasetLoader
Dataset

self_DatasetLoader 继承自 torch.utils.data 的 Dataset,其数据源来自:cifar100,cifar100是 pytorch 集成的数据源之一,可以直接下载下来,它包括 100 个分类,每个分类由 500 幅 3 ∗ 32 ∗ 32 3*32*32

  • 35
    点赞
  • 90
    收藏
    觉得还不错? 一键收藏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值