TNT : transformer in transformer

1. TNT

Vit只关注了patch级别的信息,忽略了patch 内部的像素级别的局部信息

在这里插入图片描述
patch可以理解为将图像改为一个一个的grid,然后在每个grid中有很多像素点。
在这里插入图片描述
第一层的in transformer将像素通过上面的w和b线性融合到patch中,实现第一层的transformer
vit是直接将patch Z加入到transformer,所以上面的将像素融合到patch,然后将patch再加入到transformer中
可以看到TNT模块那里,就是进行不断的叠加,不断的迭代。
对于图像,也需要加入位置的编码信息,在patch级别和pixel级别分别要进行位置的编码

在这里插入图片描述
对于一个3X3像素的patch而言,每个patch有个位置编码I ,II。然后每个patch的像素级别的编码一共是有3X3=9个,注意的是,每个patch的像素位置编码是共享的,都是1~9,即patch中像素m^2。
在这里插入图片描述
所以总体的框架流程如上图所示:图像分为patch,每一个patch生成的向量放入到TNT block中,嵌套L层,最后输出分类的结果。最前面加了一个class token,同patch生成的向量一块放入到TNT block中,就是用来分类用的。将输出加一个全连接层,达到分类的效果。全连接层就是将向量映射到标签空间上,实现分类的效果
在这里插入图片描述
总体的框架就是通过TNT对图像进行编码,然后通过transformer进行解码。
在这里插入图片描述
可以看到进行了pixel和patch的位置编码是可以提升准确率的

3. 损失函数

focal loss可以解决样本分布不均衡的问题
anti focal loss是和focal相反的,但是在seq2seq上效果会更好
在这里插入图片描述

4. beamsearch

在这里插入图片描述
在进行预测的时候实际上就是通过上N个的输入,预测一个输出
本质上是找到一个序列,是这个序列的概率最大,实现的过程就是每次找到下一个的最大的概率,然后将一个一个的串在一起,就相当于最大概率的输出序列
beamsearch本质上是找一个序列,就是这个序列存在的概率是最大的,是在寻找的过程中,一次性走两步 ,而非上面的每次只找一个,走一步的方式。
如图所示,可以一次性走1,5,10,50,100步,随着步数的增加,最有可能找到全局的最优解。
在这里插入图片描述

5. trick

在这里插入图片描述
data-leak数据泄漏,或者说就是个小的bug,针对这个bug没准训练的效果会更好。
seq长度就是label长度,长度实际上是跟分辨率挂钩的,比较长的序列上,分辨率大一些比较好,因为有很多噪声点,也容易学进去,
在这里插入图片描述
在这里插入图片描述

5.4 trick4

上面的后处理程序的核心思想就是将你的inchi的转成化学式,然后通过三方接口将化学式再转成标准的inchi格式,然后和你的inchi进行对比校验,如果是一样的说明表达式没有问题

5.5 trick5

测试集是没有标签的,可以通过我们已经预测的结果对测试集进行打伪标签,然后进行喂给模型进行微调fine-tuning
存在的问题就是可能过拟合

5.6 trick6

标签平滑

5.7 trick7

进行norm,提分0.1左右
参考:Normalize your predictions

from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from pathlib import Path

def normalize_inchi(inchi):
    try:
        mol = Chem.MolFromInchi(inchi)
        return inchi if (mol is None) else Chem.MolToInchi(mol)
    except: return inchi


# Segfault in rdkit taken care of, run it with:
# while [ 1 ]; do python normalize_inchis.py && break; done
if __name__=='__main__':
    # Input & Output
    orig_path = Path('submission.csv')
    norm_path = orig_path.with_name(orig_path.stem+'_norm.csv')
    
    # Do the job
    N = norm_path.read_text().count('\n') if norm_path.exists() else 0
    print(N, 'number of predictions already normalized')

    r = open(str(orig_path), 'r')
    w = open(str(norm_path), 'a', buffering=1)

    for _ in range(N):
        r.readline()
    line = r.readline()  # this line is the header or is where it segfaulted last time
    w.write(line)

    for line in tqdm(r):
        splits = line[:-1].split(',')
        image_id = splits[0]
        inchi = ','.join(splits[1:]).replace('"','')
        inchi_norm = normalize_inchi(inchi)
        w.write(f'{image_id},"{inchi_norm}"\n')
    r.close()
    w.close()

How much difference it made (optional)

import pandas as pd
import edlib
from tqdm import tqdm

sub_df = pd.read_csv('submission.csv')
sub_norm_df = pd.read_csv('submission_norm.csv')

lev = 0
N = len(sub_df)
for i in tqdm(range(N)):
    inchi, inchi_norm = sub_df.iloc[i,1], sub_norm_df.iloc[i,1]
    lev += edlib.align(inchi, inchi_norm)['editDistance']

print(lev/N)

6. 比赛结果

在这里插入图片描述

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值