多任务模型SNR:Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning

Sub-Network Routing for Flexible Parameter Sharing in Multi-Task Learning

论文地址:https://ojs.aaai.org/index.php/AAAI/article/view/3788

MMoE存在的问题

MMoE示意图:
在这里插入图片描述
MMoE虽然用多个门控网络解决了多个任务之间的耦合和差异问题,使得模型能较好的处理不同相关性的任务,但是MMoE里面的多个expert相互之间没有交互,这限制了模型的进一步表达。

SNR的改进之处

SNR示意图:
在这里插入图片描述
SNR将exper进行了细粒度的拆分,拆分为多层,每层由多个子网络组成。低层的子网络和高层的子网络的连接信息(路由)是由一组二值变量编码来控制的,如果变量值为0,则表示这个子网络和上层子网络没有路由连接,如果为1,则表示有路由连接。路由越多,越接近share-bottom的多任务结构,路由越少,越接近2个单独的多任务结构。同时,各个子网络之间的交互连接,也进一步提高了多个任务的精度,这就是论文说的“灵活的参数共享”。

其中论文提出2种SNR,一种是高层和低层相互连接路由的SNR-Trans,一种是高层子网络都是由底层子网络加权求和得到的SNR-Aver。

具体做法

假设有2层,高层有2个子网络,低层有3个子网络。 u 1 \mathbf u_1 u1, u 2 \mathbf u_2 u2, u 3 \mathbf u_3 u3表示低层子网络的输出, v 1 \mathbf v_1 v1, v 2 \mathbf v_2 v2表示高层子网络的输入, z \mathbf z z表示二值编码变量 z i j ∈ { 0 , 1 } z_{ij} \in \{0, 1\} zij{0,1} W i j \mathbf W_{ij} Wij表示底层子网络和高层子网络的连接转换矩阵。

在这里插入图片描述
优化问题如下, f ( x i ; W , z ) f(\mathbf x_i; \mathbf W, \mathbf z) f(xi;W,z)是模型,这里 z i z_i zi ~ B e r n ( π i ) Bern(\pi_i) Bern(πi), π \mathbf \pi π是分布参数
在这里插入图片描述
由于 z i z_i zi是二值变量,需要转换成连续变量来优化;随机变量 s s s ~ q ( s ; ϕ ) q(s;\phi) q(s;ϕ),编码变量 z z z可以表示为
z = g ( s ) = m i n ( 1 , m a x ( 0 , s ) ) z=g(s) = min(1, max(0,s)) z=g(s)=min(1,max(0,s))
替换 z z z
在这里插入图片描述
s \mathbf s s表示成一个函数 h ( ϕ , ϵ ) h(\phi , \epsilon) h(ϕ,ϵ) ϵ \epsilon ϵ是噪音随机变量, s s s可以进一步表示为
在这里插入图片描述
加L0正则到编码变量上面,能够减少需要计算的参数量,加速计算
在这里插入图片描述
Q Q Q函数是关于 s i s_i si累积分布函数
在这里插入图片描述
所以变成
在这里插入图片描述
因此,最终加了L0正则的目标函数变为
在这里插入图片描述
线上预估时, z z z值由下面式子计算得到
在这里插入图片描述

实验部分

SNR效果显著优于各个基线,如下图
在这里插入图片描述
随模型参数量增加,SNR-Trans效果显著增加,SNR-Aver也增加,但不明显,如下:
在这里插入图片描述
同样效果下,加了L0的参数系数模型要比Dense模型size小11%,如下图

在这里插入图片描述

小结

SNR通过将共享参数层拆分为包含多个子网络的层,各个子网络之间的连接通过二值编码变量来控制连接,一方面增加各个子网络之间的交互,一方面减少总的参数量。另外通过引入随机变量来替换二值不连续编码变量等转换方式来优化模型;还通过L0来大幅减少参数量。总之,是比MMoE更加细粒度的一种多任务学习方法。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
首先,我们需要生成8-PSK调制信号。假设我们要发送1000个符号,并且每个符号占用3个采样点,则可以使用以下代码生成8-PSK信号: ``` M = 8; % 8-PSK调制 k = log2(M); % 每个符号的比特数 n = 3; % 每个符号的采样点数 numSymbols = 1000; % 发送符号的数量 data = randi([0 1], numSymbols*k, 1); % 随机生成二进制数据 dataMat = reshape(data, k, []).'; % 将数据转换为矩阵形式 dataSymbols = bi2de(dataMat, 'left-msb') + 1; % 将二进制数据转换为调制符号 modSignal = pskmod(dataSymbols, M); % 生成8-PSK调制信号 txSignal = reshape(repmat(modSignal, 1, n).', [], 1); % 将信号扩展为每个符号占用n个采样点的形式 ``` 接下来,我们需要添加高斯白噪声(AWGN)信道。假设信道信噪比为Eb/No=10 dB,则可以使用以下代码添加AWGN信道: ``` EbNo = 10; % 信道信噪比(dB) snr = EbNo + 10*log10(k) - 10*log10(n); % 计算信噪比(dB) rxSignal = awgn(txSignal, snr, 'measured'); % 添加AWGN信道 ``` 最后,我们可以使用相关函数对信号进行解调和比特解码: ``` demodSignal = pskdemod(rxSignal, M); % 8-PSK解调 demodDataMat = de2bi(demodSignal - 1, k, 'left-msb'); % 将解调符号转换为二进制数据矩阵 demodData = reshape(demodDataMat.', [], 1); % 将二进制数据矩阵转换为列向量 numErrors = sum(data ~= demodData); % 统计比特错误的数量 ber = numErrors/length(data); % 计算比特错误率 ``` 完整的MATLAB代码如下: ``` M = 8; % 8-PSK调制 k = log2(M); % 每个符号的比特数 n = 3; % 每个符号的采样点数 numSymbols = 1000; % 发送符号的数量 data = randi([0 1], numSymbols*k, 1); % 随机生成二进制数据 dataMat = reshape(data, k, []).'; % 将数据转换为矩阵形式 dataSymbols = bi2de(dataMat, 'left-msb') + 1; % 将二进制数据转换为调制符号 modSignal = pskmod(dataSymbols, M); % 生成8-PSK调制信号 txSignal = reshape(repmat(modSignal, 1, n).', [], 1); % 将信号扩展为每个符号占用n个采样点的形式 EbNo = 10; % 信道信噪比(dB) snr = EbNo + 10*log10(k) - 10*log10(n); % 计算信噪比(dB) rxSignal = awgn(txSignal, snr, 'measured'); % 添加AWGN信道 demodSignal = pskdemod(rxSignal, M); % 8-PSK解调 demodDataMat = de2bi(demodSignal - 1, k, 'left-msb'); % 将解调符号转换为二进制数据矩阵 demodData = reshape(demodDataMat.', [], 1); % 将二进制数据矩阵转换为列向量 numErrors = sum(data ~= demodData); % 统计比特错误的数量 ber = numErrors/length(data); % 计算比特错误率 ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值