元学习——原型网络(Prototypical Networks)

元学习——原型网络(Prototypical Networks)

1. 基本介绍

1.1 本节引入

在之前的的文章中,我们介绍了关于连体网络的相关概念,并且给出了使用Pytorch实现的基于连体网络的人脸识别网络的小样本的学习过程。在接下来的内容中,我们来继续介绍另外一种小样本学习的神经网络结构——原型网络。这种网络的特点是拥有能够不仅仅应用在当前数据集的泛化分类能力。在接下来的内容中,我们将介绍以下几个内容:

  1. 原型网络的基本结构。
  2. 原型网络算法描述。
  3. 将原型网络应用于分类任务。
1.2 原型网络引入

相比于连体网络,原型网络是另外一种简单,高效的小样本的学习方式。与连体网络的学习目标类似。原型网络的目标也是学习到一个向量空间来实现文本分类任务。

原型网络的基本思路是对于每一个分类来创建一个原型表示(protoypicla representation)。并且对于一个需要分类的查询,采用计算分类的原型向量和查询点的距离来进行确定。

确定基本思路之后,下面从一个例子开始,对于原型网络进行具体描述。

2 原型网络

2.1 从一个例子开始

现在,我们拥有一个支持集(support set),内部包含狮子,大象,狗三个分类的图片。也就是说,对于分类任务,我们一共拥有三个分类:{狮子,大象,狗}。现在,我们需要对于每一个分类创建一个原型表示。建立的基本流程如下图所示:

  1. 首先,我们对于每一个样本使用编码的方式 f φ ( ) f_φ() fφ(),学习到每一个样本的编码表示(信息抽取)。举个例子,我们可以使用卷积操作来实现对于图片编码信息的抽取。
    在这里插入图片描述
  2. 在学习到每一个样本的编码表示之后,我们对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。因此,一个分类的原型表示使用向量求和求平均的过程过程进行表示。
    在这里插入图片描述

当一个新的数据样本被输入到网络中的时候,我们需要的是对于这个样本预测出其分类情况。
3. 第一步,我对于这个新的数据样本使用 f φ ( ) f_φ() fφ()生成其编码表示。如下图所示:
在这里插入图片描述
4. 接下来,我们需要做的就是计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。对于距离计算,并没有特殊的要求,可以使用欧式距离或者Cos相似度等等计算方式。
在这里插入图片描述

  1. 最后在计算出所有的分类之间的距离之后,我们使用softmax的方式将距离转换成概率的形式。我们有三个分类,那么对于样本在softmax之后,获取到的就是对于这三个分类的距离情况。

在本节的最后,我们回到我们的学习过程,我们希望的是网络从小样本的数据集中进行学习。所以我们在训练的时候,我们对于每一个分类随机的生成少量的样本,我们成这些少量的样本集合为支持集,在整个的训练过程,我们只需要使用到支持集即可。而不需要所有的数据集。同理,我们随机的从数据集中抽取一个样本作为查询点并且对其进行分类的预测。这样就完成了我们从小样本学习的方式。

2.2 原型网络的整体架构

首先,我们给出原型网络的整体架构图:

在这里插入图片描述

我们从整体的架构上来分析一下这种网络结构:

  1. 第一步,我们对于支持集中的每一个样本点生成一个编码表示,通过通过求和平均的方式来生成每一个分类的原型表示。同时,对于我们的查询样本,我们也对其生成一个向量表示
  2. 同时,我们需要计算每一个查询点和每一个分类原型表示的距离情况。并计算softmax概率结果。生成对于各个分类的概率分布情况。

进一步,对于原型网络而言,其应用的范围不仅仅在单样本/小样本的学习过程中,同时还可以应用在零样本的学习方式。对于这种应用的思路是:尽管我们没有当前分类的数据样本,但是如果能够在更高的层次中生成分类的原型表示(元信息)。通过这种元信息,我们也可以完成和上面类似的计算,完成我们的分类任务。

2.3 算法描述

这里我们结合网络结构和数学公式来对原型网络进行算法描述:

  1. 假设我们当前的数据集为D,其内部的样本的表示形式为{ ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . . , ( x n , y n ) (x_1,y_1),(x_2,y_2),....,(x_n,y_n) (x1,y1),(x2,y2),....,(xn,yn)},其中x表示的向量表示,y表示分类分类标签。
  2. 对于每一个分类,我们随机的从总的样本集中为其生成n个样本点,对于每一个分类,我们生成最后支持集为S。
  3. 同理,我们随机的从总的样本集中为每一个分类选择n个样本点来生成查询集Q。
  4. 对于支持集内部的样本点,使用编码公式 f φ f_φ fφ来为每一个分类生成一个原型表示,这里的编码公式 f φ f_φ fφ可以是任意的一种信息抽取的方式。例如CNN,LSTM等等。
  5. 对于每一个分类,我们生成其原型表示为 :
    i . e . C l a s s P r o t o t y p e ( c ) = 1 S ∑ ( x i , y i ) ∈ S f φ ( x i ) i.e. Class Prototype(c)=\frac{1}{S}∑_{(x_i,y_i)∈S}f_φ(x_i) i.e.ClassPrototype(c)=S1(xi,yi)Sfφ(xi)
  6. 类似的是,我们对于查询集也生成查询集的编码。
  7. 进一步,我们需要计算的是查询集和支持集的原型表示的距离情况。
  8. 最后,需要计算的是当前样本属于每一个分类的概率 p w ( y = k ∣ x ) p_w(y=k|x) pw(y=kx),这里使用softmax的计算方式:
    i . e . p φ ( y = k ∣ x ) = e x p ( − d ( f φ ( x ) , c ) ) ∑ k e x p ( − d ( f φ ( x ) , c ) ) i.e. p_φ(y=k|x)=\frac{exp(-d(f_φ(x),c))}{∑_kexp(-d(f_φ(x),c))} i.e.pφ(y=kx)=kexp(d(fφ(x),c))exp(d(fφ(x),c))
  9. 最终,我们计算损失函数为 J ( φ ) J(φ) J(φ)
    J ( φ ) = − l o g p w ( y = k ∣ x ) J(φ)=-logp_w(y=k|x) J(φ)=logpw(y=kx)
2.4 代码描述

这里,我们选择自定义了一个简单的评论数据集,一共两个分类,每一个分类下面有5个数据,每个分类我们选择3个作为支持集,3个作为查询集,其具体的实现如下:

#encoding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
import jieba
import random
import torch.optim as optim



def createData():
    text_list_pos = ["电影内容很好","电影题材很好","演员演技很好","故事很感人","电影特效很好"]
    text_list_neg = ["电影内容垃圾","电影是真的垃圾","表演太僵硬了","故事又臭又长","电影太让人失望了"]
    test_pos = ["电影","很","好"]
    test_neg = ["电影","垃圾"]
    words_pos = [[item for item in jieba.cut(text)] for text in text_list_pos]
    words_neg = [[item for item in jieba.cut(text)] for text in text_list_neg]
    words_all = []
    for item in words_pos:
        for key in item:
            words_all.append(key)
    for item in words_neg:
        for key in item:
            words_all.append(key)
    vocab = list(set(words_all))
    word2idx = {w:c for c,w in enumerate(vocab)}
    idx_words_pos = [[word2idx[item] for item in text] for text in words_pos]
    idx_words_neg = [[word2idx[item] for item in text] for text in words_neg]
    idx_test_pos = [word2idx[item] for item in test_pos]
    idx_test_neg = [word2idx[item] for item in test_neg]
    return vocab,word2idx,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg
def createOneHot(vocab,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg):
    input_dim = len(vocab)
    features_pos = torch.zeros(size=[len(idx_words_pos),input_dim])
    features_neg = torch.zeros(size=[len(idx_words_neg), input_dim])
    for i in range(len(idx_words_pos)):
        for j in idx_words_pos[i]:
            features_pos[i,j] = 1.0

    for i in range(len(idx_words_neg)):
        for j in idx_words_neg[i]:
            features_neg[i,j] = 1.0
    features = torch.cat([features_pos,features_neg],dim=0)
    labels = [1,1,1,1,1,0,0,0,0,0]
    labels = torch.LongTensor(labels)
    test_x_pos = torch.zeros(size=[1,input_dim])
    test_x_neg = torch.zeros(size=[1,input_dim])
    for item in idx_test_pos:
        test_x_pos[0,item] = 1.0
    for item in idx_test_neg:
        test_x_neg[0,item] = 1.0
    test_x = torch.cat([test_x_pos,test_x_neg],dim=0)
    test_labels = torch.LongTensor([1,0])
    return features,labels,test_x,test_labels
def randomGenerate(features):
    N = features.shape[0]
    half_n = N // 2
    support_input = torch.zeros(size=[6, features.shape[1]])
    query_input = torch.zeros(size=[4,features.shape[1]])
    postive_list = list(range(0,half_n))
    negtive_list = list(range(half_n,N))
    support_list_pos = random.sample(postive_list,3)
    support_list_neg = random.sample(negtive_list,3)
    query_list_pos = [item for item in postive_list if item not in support_list_pos]
    query_list_neg = [item for item in negtive_list if item not in support_list_neg]
    index = 0
    for item in support_list_pos:
        support_input[index,:] = features[item,:]
        index += 1
    for item in support_list_neg:
        support_input[index,:] = features[item,:]
        index += 1
    index = 0
    for item in query_list_pos:
        query_input[index,:] = features[item,:]
        index += 1
    for item in query_list_neg:
        query_input[index,:] = features[item,:]
        index += 1
    query_label = torch.LongTensor([1,1,0,0])
    return support_input,query_input,query_label




class fewModel(nn.Module):
    def __init__(self,input_dim,hidden_dim,num_class):
        super(fewModel,self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_class = num_class
        # 线性层进行编码
        self.linear = nn.Linear(input_dim,hidden_dim)


    def embedding(self,features):
        result = self.linear(features)
        return result

    def forward(self,support_input,query_input):

        support_embedding = self.embedding(support_input)
        query_embedding = self.embedding(query_input)
        support_size = support_embedding.shape[0]
        every_class_num  = support_size // self.num_class
        class_meta_dict = {}
        for i in range(0,self.num_class):
            class_meta_dict[i] = torch.sum(support_embedding[i*every_class_num:(i+1)*every_class_num,:],dim=0) / every_class_num
        class_meta_information = torch.zeros(size=[len(class_meta_dict),support_embedding.shape[1]])
        for key,item in class_meta_dict.items():
            class_meta_information[key,:] = class_meta_dict[key]
        N_query = query_embedding.shape[0]
        result = torch.zeros(size=[N_query,self.num_class])
        for i in range(0,N_query):
            temp_value = query_embedding[i].repeat(self.num_class,1)
            cosine_value = torch.cosine_similarity(class_meta_information,temp_value,dim=1)
            result[i] = cosine_value
        result = F.log_softmax(result,dim=1)
        return result

hidden_dim = 4
n_class = 2
lr = 0.01
epochs = 1000
vocab,word2idx,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg = createData()
features,labels,test_x,test_labels = createOneHot(vocab,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg)

model = fewModel(features.shape[1],hidden_dim,n_class)
optimer = optim.Adam(model.parameters(),lr=lr,weight_decay=5e-4)

def train(epoch,support_input,query_input,query_label):
    optimer.zero_grad()
    output = model(support_input,query_input)
    loss = F.nll_loss(output,query_label)
    loss.backward()
    optimer.step()
    print("Epoch: {:04d}".format(epoch),"loss:{:.4f}".format(loss))

if __name__ == '__main__':
    for i in range(epochs):
        support_input, query_input, query_label = randomGenerate(features)
        train(i,support_input,query_input,query_label)


 
### 少样本学习与元学习在文本处理中的应用 #### 定义与背景 少样本学习(few-shot learning)旨在利用少量标注数据实现高效的学习效果。对于新类别,模型仅需几个样例就能快速适应并做出准确预测[^2]。 #### 文本处理场景下的挑战 自然语言处理(NLP)任务通常面临数据稀缺问题,尤其是在低资源语言环境下。传统方法依赖大量标记语料库,在遇到新颖概念或领域迁移时表现不佳。而few-shot learning通过引入先验知识和灵活调整机制来缓解这一困境。 #### 元学习框架概述 元学习(meta-learning),也称为“学会学习”,是一种让机器具备快速获取新技能能力的技术路线。具体到NLP领域: - **Meta-filter动态对齐**:研究指出一种创新方案——即通过meta-filter实现在few-shot情境下动态调节特征表示的能力[^1]。 - **多任务联合训练**:为了增强泛化性能,可采用multi-task策略同步优化多个相关子任务,从而促进跨域知识共享。 - **原型网络(Prototypical Networks)**:由Vinyals等人提出的匹配网络(Matching Networks)[^4]及其变体广泛应用于短文本分类、关系抽取等实际案例中,其核心思想在于计算支持集和支持向量之间的相似度得分以决定最终归属。 ```python import torch.nn.functional as F from transformers import BertModel, BertTokenizer class ProtoNetTextClassifier(nn.Module): def __init__(self, pretrained_model='bert-base-uncased', n_way=5, k_shot=1): super().__init__() self.bert = BertModel.from_pretrained(pretrained_model) self.n_way = n_way self.k_shot = k_shot def forward(self, support_set_ids, query_set_ids): # 获取BERT嵌入 support_embeddings = self._get_bert_embedding(support_set_ids).mean(dim=0) query_embeddings = self._get_bert_embedding(query_set_ids) # 计算距离矩阵并与最近邻比较 dists = euclidean_dist(query_embeddings, support_embeddings) log_pred = -dists return F.log_softmax(log_pred, dim=-1) def _get_bert_embedding(self, input_ids): outputs = self.bert(input_ids=input_ids) cls_hidden_states = outputs.last_hidden_state[:, 0, :] return cls_hidden_states.unsqueeze(0) def euclidean_dist(x, y): """Compute Euclidean distance between two tensors.""" m, d = x.size() n, d_ = y.size() assert d == d_, "Feature dimensions must match." x_expanded = x.unsqueeze(1).expand(m, n, d) y_expanded = y.unsqueeze(0).expand(m, n, d) squared_diff = (x_expanded - y_expanded) ** 2 distances = squared_diff.sum(-1) return distances.sqrt() ``` 此代码片段展示了如何基于预训练的语言模型(BERT)构建一个简单的原型网络用于文本分类任务。该架构能够有效应对小规模数据集情况,并且易于扩展至其他类型的NLP应用场景。
评论 35
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值