小样本学习Few-Shot Learning——孪生网络Siamese Networks、匹配网络Matching Networks、原型网络Prototypical Networks 的简单总结

小样本学习旨在解决在有限标注样本下训练模型的问题。本文介绍了几种模型,包括孪生网络、匹配网络和原型网络。孪生网络通过比较同类和不同类样本对的embedding向量来学习;匹配网络引入注意力机制,对支持集中的样本加权求和;原型网络则通过计算支持集中样本的平均embedding作为类别原型,进行预测。这些方法提高了模型在未见过的类别上的泛化能力。
摘要由CSDN通过智能技术生成

1. 小样本学习 Few-Shot Learning

1.1 小样本学习要解决的问题

以图片分类这个任务举例,使用神经网络模型的传统做法是:先使用大量带标签的猫和狗的图片训练模型,然后让训练好的模型给不在训练集中的猫和狗的图片做分类,去预测输入的图片是猫还是狗。
而在很多领域的现实应用中,并没有足够的带标签的图片可供模型训练,可能每个类别只有几十个、甚至几个带标签的样本,此时我们希望模型可以根据这些少量样本就学到该类别的关键知识,以对不在训练集中的图片做分类。

1.2 小样本学习模型的训练方法

我们既然希望模型可以根据少量样本学习分类,那在训练阶段就要锻炼模型的这个能力。此处借用李宏毅老师的课程PPT截图进行说明。

数据集可用来做训练任务(Training Tasks)和测试任务(Testing Tasks),训练任务中有很多个子任务(Task1、Task2……),每个子任务中都有两个小数据集,一个是support set(Train),用来让模型学习知识;另一个是query set(Test),用来检验模型学习知识的能力。

如上图,在训练时的Task1中,只使用一张猫的图片和一张狗的图片训练模型,然后让模型预测其他两张图片,分类出猫和狗;在训练时的Task2中,只使用一张苹果的图片和一张橙子的图片训练模型,然后让模型预测其他两张图片,分类出苹果和橙子……这样的子任务会有很多,如果模型在每次子任务中都表现的很好,就说明模型有了这样一个能力:根据support set中的少量样本学习有用的知识,然后去分类该领域的图片,如果分类效果好,就说明模型使用少量样本去学习知识的能力很强。此处体现了,小样本学习是希望模型可以自己学会如何去学习知识

与传统深度学习不同的地方是,小样本学习模型应用的领域是之前未曾接触过的领域。比如上图中,测试任务中的自行车和汽车,模型在训练过程中是从未看到的,只能根据测试任务的support set中的一张自行车图片和一张汽车图片去学习自行车和汽车的特点。而传统的深度学习,是使用大量的猫和狗的图片训练好模型后,模型只去分类猫和狗的图片。

小样本学习中有三个典型的模型:孪生网络、匹配网络、原型网络。下面做一些简单介绍。

2. 孪生网络 Siamese Networks

论文题目:Siamese Neural Networks for One-shot Image Recognition,2015
论文地址:http://www.cs.cmu.edu/~rsalakhu/papers/oneshot1.pdf

2.1 主要思想

利用相同样本对和不同样本对的之间的区别,训练出一个神经网络模型,使同类样本生成的embedding向量相近,不同样本的embedding向量远离。

2.2 模型结构

输入:两张图片: ( x 1 , 1 … x 1 , N 1 ) 、 ( x 2 , 1 … x 2 , N 1 ) (x_{1,1} \ldots x_{1,N1})、(x_{2,1} \dots x_{2,N1}) (x1,1x1,N1)(x2,1x2,N1)
输出:两张图片是同一类别的预测值。该值越大,表示输入的两张图片越有可能是同一类别。

inference:将query image和support set中的N×K个images逐一配对输入模型,得到N×K个预测值,将query image归为预测值最大的一类。(N 是support set中的类别个数,K 是support set中每一类的样本数)

3. 匹配网络 Matching Networks

论文题目:Matching Networks for One Shot Learning,2016
论文地址:https://proceedings.neurips.cc/paper/2016/file/90e1357833654983612fb05e3ec9148c-Paper.pdf

3.1 主要思想

首先对support set和query set进行embedding,然后用query image对support set中的每个样本计算注意力:

其中 x ^ \hat{x} x^ 是query image, x i x_i xi 是support set 中的样本,c是余弦距离。query image使用编码器 f f f 进行编码得到embedding,support set中的image使用编码器 g g g 编码得到embedding。

最后把每个类别根据注意力得分进行线性加权:

3.2 和孪生网络的区别

和siamese network区别:不是直接取最高的预测值,而是将同类预测值相加,取最高的一类。
当实验场景是one-shot时,除了网络结构的部分都和siamese network一样。

4. 原型网络Prototypical Networks

论文题目:Prototypical Networks for Few-shot Learning,2017
论文地址:https://proceedings.neurips.cc/paper/2017/file/cb8da6767461f2812ae4290eac7cbc42-Paper.pdf

4.1 主要思想

  • 求原型中心:将support set中的样本全部输入编码器,每个样本对应得到一个embedding向量,将同类样本的embedding向量取平均值,得到该类的原型中心c。
  • 预测query image:将query set中的每个样本query image x 输入编码器,每个x对应得到一个embedding向量,该向量离哪个原型中心c 最近,就预测x 应为哪一类。
  • 15
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是基于匹配网络的one-shot样本分类代码示例,使用matlab的深度学习工具箱实现: ```matlab % 数据集准备 % 在这里,我们使用Omniglot数据集,其中包含来自50个不同语言的1623个字符类别(每个类别有20个样本)。 % 这里我们只使用其中的1200个类别进行训练,剩下的423个类别用于测试。 % 数据集已经预处理为一个.mat文件,包含了训练和测试数据,以及对应的标签信息。 % 加载数据集 load('omniglot.mat'); % 训练数据 train_data = train_data'; train_labels = train_labels'; % 测试数据 test_data = test_data'; test_labels = test_labels'; % 网络定义 input_dim = 105*105; % 输入维度 hidden_dim = 64; % 隐藏层维度 output_dim = 1; % 输出维度(二分类) % 定义网络结构 net = siamese_network(input_dim, hidden_dim, output_dim); % 训练网络 num_epochs = 50; % 训练轮数 batch_size = 32; % 批大小 learning_rate = 0.001; % 学习率 % 定义优化器 optimizer = adam_optimizer(learning_rate); % 训练网络 train_losses = zeros(num_epochs, 1); for epoch = 1:num_epochs epoch_loss = 0; for i = 1:batch_size:size(train_data, 1) % 获取当前批次数据及标签 batch_data = train_data(i:min(i+batch_size-1, end), :); batch_labels = train_labels(i:min(i+batch_size-1, end)); % 前向传播计算损失 [loss, grads] = compute_loss_and_grads(net, batch_data, batch_labels); epoch_loss = epoch_loss + loss; % 反向传播更新参数 net = update_parameters(net, grads, optimizer); end train_losses(epoch) = epoch_loss / ceil(size(train_data, 1)/batch_size); fprintf('Epoch %d, Train Loss: %f\n', epoch, train_losses(epoch)); end % 测试网络 num_correct = 0; for i = 1:size(test_data, 1) % 对每个测试样本,找到它的最近邻(即与它距离最近的训练样本) distances = sum((train_data - repmat(test_data(i,:), size(train_data, 1), 1)).^2, 2); [~, nearest_idx] = min(distances); % 使用最近邻与当前测试样本进行匹配,并预测其类别 input1 = test_data(i,:); input2 = train_data(nearest_idx,:); output = forward(net, input1, input2); prediction = output > 0.5; % 计算准确率 if prediction == test_labels(i) num_correct = num_correct + 1; end end accuracy = num_correct / size(test_data, 1); fprintf('Test Accuracy: %f\n', accuracy); ``` 其中,`siamese_network`函数用于定义匹配网络的结构,`adam_optimizer`函数用于定义Adam优化器,`compute_loss_and_grads`函数用于计算损失和梯度,`update_parameters`函数用于更新网络参数,`forward`函数用于前向传播计算输出。这些函数的实现可以参考深度学习工具箱的文档。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值