1、torch.nn.Transformer()
API解释参考:《pytorch中的transformer》
2 、CNN使用Transformer代码
注意:
[1]nhead必须能被d_model整除(序列被几个头注意)
[2]CNN特征图通道512被当成序列,放到第一个维度,批次放到第二个维度
[3]Transformer必须有src和tgt两个向量,CNN是自相关性解算,都放入特征图向量
特别提示
:
Transformer计算量大,他提取的是全局信息。CNN更多偏向于局部信息,使用时可以利用深层的特征图部分通道做Transformer后,变换回去和其他未做Transformer的特征图cat继续卷积
。
这样即引入Transformer提供了全局语义信息,又不至于计算量过大
。
net_2 = nn.Transformer(d_model=81,<