【文献阅读】Transfer Learning For Text Classification Via Model Risk Analysis

前言:本文是以文本分类的迁移学习任务为例,对风险分析模型的整体框架流程做梳理。

1. LearnRisk

1.1 motivatio

  • 传统的DNN结果有一定错误的风险
  • 迁移学习目标域的标签数据难以获得,通常只有少量有标签样本

1.2 overall

在这里插入图片描述

风险分析整体分为三步:

  • 构造风险特征
  • 构建风险模型
  • 训练风险模型

2. LearnRisk-TC

在这里插入图片描述

  • 核心思路:在源域上训练好一个base model后,用目标域的少量有标签样本(如valid dataset)去训练风险模型,最后用无标签的test dataset重新微调base model。
  • 主要流程
    (1)源数据集训练base model;
    (2)有标签的目标域的验证数据集构建一批风险特征(决策树规则);
    (3)构建每个类别的正态分布:对每个风险特征构建一个正态分布(u是先验, σ {\sigma} σ后验),风险特征加权和作为每个类别的正态分布;
    (4)训练风险模型:损失函数的目标是实现正确的风险排序(风险由高到低);
    (5)利用无标签的目标域的测试数据集进行base model的微调。

2.1 构造风险特征

2.1.1 risk metric

文章中将risk metric主要分为两类,statistics-based risk metricsDNN-based risk metrics。对于每一个risk metric,都会生成一个长度为N的一维向量,N为总的类别数。假设目标域的测试数据集大小为Q,每一个文本都会有X个risk metric,最终共生成了Q*X个risk metric。

  • statistics-based risk metrics
    文章中构建了一种新的统计特征,计算公式如下:在这里插入图片描述
    其中,p为超参
    (1) C H I n e w = C H I ∗ α {CHI_{new} = CHI * {\alpha}} CHInew=CHIα, 各项解释如下:
    在这里插入图片描述

    (2) T F − I D F n e w = T F n e w ∗ I D F n e w ∗ β ∗ λ {TF-IDF_{new} = TF_{new} * IDF_{new} * {\beta} * {\lambda}} TFIDFnew=TFnewIDFnewβλ,各项解释如下:
    在这里插入图片描述
    在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/58671c4563754de7a1ddcefc4b8d15e2.png)
  • DNN-based risk metrics
    文章采用了两种模型,bert和textcnn用于提取文档特征,然后使用knn和ccd两种方法计算。

2.1.2 risk feature

文章使用单边决策树来生成risk features,决策树生成的每条规则对应了一个risk feature,如下:
在这里插入图片描述
在这里插入图片描述
最终会生成一批决策树,得到一批规则,即一批risk feature。
注意:针对每个类别生成一批决策树,也就是说,每个类别的风险特征都不一样。

2.2 构建风险模型

在这里插入图片描述
对于多分类问题,假设共N个类别,对于每个类别分别构建一个风险模型。
类别i的风险模型构建的主要流程
(1)对每个风险特征分别建立一个正态分布 N ( u , σ 2 ) {N(u, \sigma^2)} N(u,σ2)
u是先验知识: u = n / m {u=n/m} u=n/m,n是风险分析的训练数据集(即目标域的验证数据集)中成功匹配该风险特征的文档数,m是训练数据集中属于该类别的总文档数。
σ {\sigma} σ是后验知识,待模型训练得到。
注意:不同类别对应的各个风险特征的正态分布并不一样。

(2)求所有风险特征的加权和作为类别i的正态分布。
所有的风险特征都是一条条规则,指向的是匹配某个类别,假设类别i共5个风险特征,某文档匹配风险特征2,3,5,则特征向量为(0,1,1,0,1)。类别i的权重向量为 w i w_i wi则i的正态分布计算如下:
u i = x i ( w i ∗ u f ) {u_i = x_i (w_i * u_f)} ui=xi(wiuf)
σ i 2 = x i ( w i ∗ σ f 2 ) {\sigma_i^2 = x_i (w_i * \sigma_f^2)} σi2=xi(wiσf2)
其中 u f u_f uf代表的是一个长度为m的一维向量,即每个风险特征的u, σ f 2 \sigma_f^2 σf2同理。

2.3 训练风险模型

风险模型的训练目标是排序,即能够让高风险的文档正确的排在低风险文档的前面,或者说能让分类错误的文档排在分类正确的文档前面。
损失构建如下:
在这里插入图片描述
在这里插入图片描述

2.4 微调base model

核心思想:用base model对目标域的测试数据集做预测,求每个文本的预测类别,然后用训练好的风险模型去计算该类别的风险值,对base model设计一个新的损失函数进行微调。
损失函数如下:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值