A sparse-attention example: the Quantile Sparse Transformer (QSTransformer)

scient

scient is a Python package implementing algorithms for scientific computing, with modules for natural language processing, image processing, neural networks, optimization algorithms, machine learning, graph computation, and more.

The source code and binary installers for the latest released version are available at the Python Package Index:

https://pypi.org/project/scient

You can install scient with pip:

pip install scient

Alternatively, you can install from source: in the scient directory, execute:

python setup.py install

scient.neuralnet

Neural-network algorithm module, including attention, transformer, bert, lstm, resnet, crf, dataset, fit, and other submodules.

scient.neuralnet.transformer

The transformer module implements transformer models built on multi-head attention, including Transformer, ViTransformer, Former, Encoder, and Decoder.

Quantile Sparse Transformer (QSTransformer)

The Quantile Sparse Transformer reduces the computational cost of attention by keeping only the values in a small region and forcing most attention weights to zero. It retains only the elements of the attention matrix that are above a given quantile, so full attention degenerates into sparse attention: the entries that contribute most to attention are kept, and the remaining, largely irrelevant entries are dropped. This selective approach is effective at preserving important information while removing noise, so attention concentrates on the values that contribute most.
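
As a rough illustration of this idea (a minimal sketch, not scient's actual implementation; the function name and the masking strategy are assumptions), quantile-sparse scaled dot-product attention could be written in PyTorch as:

import torch
import torch.nn.functional as F

def quantile_sparse_attention(q, k, v, quantile=0.75):
    # q, k, v: (batch, heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5  # QK^T / sqrt(d)
    # per-row quantile threshold; entries below it are pushed to -inf so that
    # softmax drives their attention weights to zero
    threshold = torch.quantile(scores, quantile, dim=-1, keepdim=True)
    scores = scores.masked_fill(scores < threshold, float('-inf'))
    return F.softmax(scores, dim=-1) @ v

q = k = v = torch.randn(1, 2, 5, 8)
out = quantile_sparse_attention(q, k, v, quantile=0.75)  # shape (1, 2, 5, 8)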

qstransformer

Parameters of the Quantile Sparse Transformer:

  • quantile: float in the range 0-1, default None. In attention, values greater than or equal to the quantile threshold are kept, while values below it do not contribute to the computation;
  • dim: int, tuple, or list, default -1. The dimension(s) along which the quantile is taken. For example, dim=-1 sparsifies each row of the QK attention matrix by quantile; dim=-2 sparsifies each column; dim=(-1,-2) sparsifies over all elements of the QK attention matrix (see the sketch after this list);
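
To make the dim semantics concrete, the following toy snippet (a rough sketch using torch.quantile, not scient's internal code) shows where the 0.75-quantile threshold would be computed for each setting:

import torch

scores = torch.randn(4, 4)  # a toy QK attention matrix

# dim=-1: one threshold per row; each row keeps its own largest entries
row_thr = torch.quantile(scores, 0.75, dim=-1, keepdim=True)  # shape (4, 1)

# dim=-2: one threshold per column
col_thr = torch.quantile(scores, 0.75, dim=-2, keepdim=True)  # shape (1, 4)

# dim=(-1,-2): a single threshold over all elements of the matrix
all_thr = torch.quantile(scores.flatten(), 0.75)  # scalar

keep = scores >= row_thr  # boolean mask of retained attention entries (dim=-1 case)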

Here ViTransformer is used to test the Quantile Sparse Transformer. The dataset, model, and training code follow the ViTransformer example; the only change is the line that constructs the ViTransformer model, which changes from

model=transformer.ViTransformer(n_class=10,patch_size=16,in_shape=(160,160),embed_size=768)

to

model=transformer.ViTransformer(n_class=10,patch_size=16,in_shape=(160,160),embed_size=768,quantile=0.75)
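
For reference, the snippets above assume that the transformer module has been imported from scient.neuralnet, as described earlier, along the lines of:

from scient.neuralnet import transformer

With quantile=0.75 and the default dim=-1, roughly the top 25% of attention weights in each row are kept.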

The training process is as follows:

train iter 0: avg_batch_loss=2.25894 perform=0.16254: 100%|██████████| 283/283 [04:42<00:00,  1.00it/s]
eval iter 0: avg_batch_loss=2.21871 perform=0.18045: 100%|██████████| 123/123 [01:09<00:00,  1.78it/s]
train iter 1: avg_batch_loss=2.16985 perform=0.19922: 100%|██████████| 283/283 [05:02<00:00,  1.07s/it]
eval iter 1: avg_batch_loss=2.19165 perform=0.19165: 100%|██████████| 123/123 [01:16<00:00,  1.62it/s]
train iter 2: avg_batch_loss=2.05453 perform=0.23911: 100%|██████████| 283/283 [05:07<00:00,  1.09s/it]
eval iter 2: avg_batch_loss=2.27452 perform=0.19343: 100%|██████████| 123/123 [01:15<00:00,  1.62it/s]
train iter 3: avg_batch_loss=2.00221 perform=0.26947: 100%|██████████| 283/283 [05:07<00:00,  1.09s/it]
eval iter 3: avg_batch_loss=1.96465 perform=0.28556: 100%|██████████| 123/123 [01:11<00:00,  1.71it/s]
train iter 4: avg_batch_loss=1.88308 perform=0.30891: 100%|██████████| 283/283 [04:47<00:00,  1.02s/it]
eval iter 4: avg_batch_loss=1.84558 perform=0.34105: 100%|██████████| 123/123 [01:07<00:00,  1.81it/s]
train iter 5: avg_batch_loss=1.79527 perform=0.34515: 100%|██████████| 283/283 [05:01<00:00,  1.07s/it]
eval iter 5: avg_batch_loss=1.88459 perform=0.31534: 100%|██████████| 123/123 [01:15<00:00,  1.62it/s]
train iter 6: avg_batch_loss=1.71003 perform=0.38138: 100%|██████████| 283/283 [05:07<00:00,  1.09s/it]
eval iter 6: avg_batch_loss=1.77757 perform=0.35963: 100%|██████████| 123/123 [01:09<00:00,  1.78it/s]
train iter 7: avg_batch_loss=1.62553 perform=0.41473: 100%|██████████| 283/283 [04:48<00:00,  1.02s/it]
eval iter 7: avg_batch_loss=1.70259 perform=0.39602: 100%|██████████| 123/123 [01:15<00:00,  1.62it/s]
train iter 8: avg_batch_loss=1.52049 perform=0.45185: 100%|██████████| 283/283 [05:04<00:00,  1.07s/it]
eval iter 8: avg_batch_loss=1.66643 perform=0.41359: 100%|██████████| 123/123 [01:09<00:00,  1.76it/s]
train iter 9: avg_batch_loss=1.47260 perform=0.47722: 100%|██████████| 283/283 [04:55<00:00,  1.04s/it]
eval iter 9: avg_batch_loss=1.66401 perform=0.41562: 100%|██████████| 123/123 [01:13<00:00,  1.68it/s]

Compared with the ViTransformer training run, the Quantile Sparse Transformer achieves higher accuracy at every iteration, which indicates that the Quantile Sparse Transformer is an effective improvement over the vanilla Transformer.

If dim is set to (-1,-2) instead, the training process is as follows:

model=transformer.ViTransformer(n_class=10,patch_size=16,in_shape=(160,160),embed_size=768,quantile=0.75,dim=(-1,-2))
train iter 0: avg_batch_loss=2.21098 perform=0.17917: 100%|██████████| 283/283 [04:26<00:00,  1.06it/s]   
eval iter 0: avg_batch_loss=2.21995 perform=0.19623: 100%|██████████| 123/123 [01:01<00:00,  2.00it/s]   
train iter 1: avg_batch_loss=2.13941 perform=0.20842: 100%|██████████| 283/283 [04:26<00:00,  1.06it/s]   
eval iter 1: avg_batch_loss=2.24416 perform=0.16009: 100%|██████████| 123/123 [01:01<00:00,  1.99it/s]   
train iter 2: avg_batch_loss=2.02153 perform=0.24632: 100%|██████████| 283/283 [04:33<00:00,  1.03it/s]   
eval iter 2: avg_batch_loss=2.30085 perform=0.19369: 100%|██████████| 123/123 [01:01<00:00,  2.01it/s]   
train iter 3: avg_batch_loss=1.92945 perform=0.2913: 100%|██████████| 283/283 [04:27<00:00,  1.06it/s]    
eval iter 3: avg_batch_loss=2.17014 perform=0.23237: 100%|██████████| 123/123 [01:02<00:00,  1.97it/s]   
train iter 4: avg_batch_loss=1.83319 perform=0.32277: 100%|██████████| 283/283 [04:28<00:00,  1.05it/s]   
eval iter 4: avg_batch_loss=1.94234 perform=0.296: 100%|██████████| 123/123 [01:02<00:00,  1.95it/s]     
train iter 5: avg_batch_loss=1.75904 perform=0.35867: 100%|██████████| 283/283 [04:31<00:00,  1.04it/s]   
eval iter 5: avg_batch_loss=1.81859 perform=0.35378: 100%|██████████| 123/123 [01:02<00:00,  1.97it/s]   
train iter 6: avg_batch_loss=1.68330 perform=0.39402: 100%|██████████| 283/283 [04:52<00:00,  1.03s/it]   
eval iter 6: avg_batch_loss=1.80461 perform=0.35225: 100%|██████████| 123/123 [01:02<00:00,  1.96it/s]   
train iter 7: avg_batch_loss=1.59485 perform=0.42892: 100%|██████████| 283/283 [04:39<00:00,  1.01it/s]   
eval iter 7: avg_batch_loss=1.64702 perform=0.41614: 100%|██████████| 123/123 [01:06<00:00,  1.84it/s]   
train iter 8: avg_batch_loss=1.50010 perform=0.45784: 100%|██████████| 283/283 [04:41<00:00,  1.01it/s]   
eval iter 8: avg_batch_loss=1.63136 perform=0.41665: 100%|██████████| 123/123 [01:06<00:00,  1.84it/s]   
train iter 9: avg_batch_loss=1.43621 perform=0.48676: 100%|██████████| 283/283 [04:39<00:00,  1.01it/s]   
eval iter 9: avg_batch_loss=1.61981 perform=0.42606: 100%|██████████| 123/123 [01:04<00:00,  1.92it/s]

As the results show, sparsifying each row separately (dim=-1) works better than sparsifying over all elements (dim=(-1,-2)).

Readers are encouraged to experiment with other settings on their own.
