STTran 源码解读(3):STTran类

文章详细解读了STTran类的初始化和前向传播过程,涉及词向量嵌入、卷积层、Transformer的构建,包括多头注意力层、局部和全局注意力机制,以及时空信息的编码解码。模型主要用于计算对象和边的上下文,处理图像中的对象关系。
摘要由CSDN通过智能技术生成

STTran 源码解读3:STTran类


init部分

1.定义ObjectClassifier:用于计算object和edge的上下文:
init部分:
(1)词向量嵌入(embed_vecs):将obj_classes转化为dim=200的词嵌入向量,对应论文中的semantic embedding vector,即公式(3)中对应的s.
2.定义union_func1卷积层
3.定义transformer类:定义时空transformer
(1)定义encoder_layer:包含head=4的多头注意力层,线性层,正则化层,dropout层;
(2)定义local_attention:由一层encoder_layer组成;
(3)定义decoder_layer: 包含head=8的多头注意力层,线性层,正则化层,dropout层;
(4)定义global_attention:由三层decoder_layer组成;
(5)定义position_embedding:embbedding层,output_dim=1936=representation vector’s dim
of the relation(对应论文中的公式(3))


forward部分

detector类的output:entry,输入STTran类的forward函数部分:

1.计算object和edge的contexts:
三种模式:
1)predcls通过bbox,class预测对象之间的relation:
entry基本保持不变;
2)sgcls通过bbox预测class以及对象之间的relation
3)sgdet通过图片预测bbox,class,对象之间的关系
<1> 训练模式:
entry[‘distribution’] 只获得组成关系对象对的头尾实体特征——实体关系分类分布。

<2>评估模式:

2.视觉部分:
在这里插入图片描述可视化向量x_visual拼接sub_rep,obj_rep,及联合框特征图,dim=(对象对数量 ,512*3)

3.语义部分:
在这里插入图片描述
语义向量x_semantic拼接sub_emb,obj_emb(对应论文公式(3)中s),dim(维度)=(对象对数量 ,200*2)

4. 时空transformer:(重点)
rel_features对应于公式(3)中的x,由x_visual和x_semantic拼接得到, dim=(对象对数量,1536+400=1936)
transformer类forward部分:

(1)空间编码器
rel_input,masks——local_attention(由一层encoder_layer组成)——local_output,local_attention_weight
dim(rel_input)=(一帧中最多的box数量,图片数量,1936)
dim(masks)=(图片数量,一帧中最多的box数量)
dim(local_output)=(len(im_idx),1936)
dim(local_attention_weight)=(1,图片数量,一帧中最多的box数量,一帧中最多的box数量)

(2)滑动窗口部分
滑动窗口数量为2,即前后帧的信息。
global_input为前后两帧的local_output。dim(global_input)=(一帧中最多的box数量2,图片对数量,1936)
position_embed(帧编码)为position_embedding的两层权重值。其维度与global_input一致,dim(position_embed)=(一帧中最多的box数量
2,图片对数量,1936)
global_mask,dim(global_mask)=(图片对数量,一帧中最多的box数量*2)

(3)时间解码器
global_input,position_embed,global_mask——global_attention(由三层decoder_layer组成)——global_output,global_attention_weight
dim(global_output)=(一帧中最多的box数量2,图片对数量,1936)
dim(global_attention_weight)=(3,对象对数量,一帧中最多的box数量
2,一帧中最多的box数量*2)

(4)返回值:output,local_attention_weight,global_attention_weight
dim(output)=dim(feature)=(len(im_idx)即对象对数量,1936)


(如有错误,还请指正,谢谢)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值