init部分
1.定义ObjectClassifier:用于计算object和edge的上下文:
init部分:
(1)词向量嵌入(embed_vecs):将obj_classes转化为dim=200的词嵌入向量,对应论文中的semantic embedding vector,即公式(3)中对应的s.
2.定义union_func1卷积层
3.定义transformer类:定义时空transformer
(1)定义encoder_layer:包含head=4的多头注意力层,线性层,正则化层,dropout层;
(2)定义local_attention:由一层encoder_layer组成;
(3)定义decoder_layer: 包含head=8的多头注意力层,线性层,正则化层,dropout层;
(4)定义global_attention:由三层decoder_layer组成;
(5)定义position_embedding:embbedding层,output_dim=1936=representation vector’s dim
of the relation(对应论文中的公式(3))
forward部分
detector类的output:entry,输入STTran类的forward函数部分:
1.计算object和edge的contexts:
三种模式:
1)predcls通过bbox,class预测对象之间的relation:
entry基本保持不变;
2)sgcls通过bbox预测class以及对象之间的relation
3)sgdet通过图片预测bbox,class,对象之间的关系
<1> 训练模式:
entry[‘distribution’] 只获得组成关系对象对的头尾实体特征——实体关系分类分布。
<2>评估模式:
2.视觉部分:
可视化向量x_visual拼接sub_rep,obj_rep,及联合框特征图,dim=(对象对数量 ,512*3)
3.语义部分:
语义向量x_semantic拼接sub_emb,obj_emb(对应论文公式(3)中s),dim(维度)=(对象对数量 ,200*2)
4. 时空transformer:(重点)
rel_features对应于公式(3)中的x,由x_visual和x_semantic拼接得到, dim=(对象对数量,1536+400=1936)
transformer类forward部分:
(1)空间编码器
rel_input,masks——local_attention(由一层encoder_layer组成)——local_output,local_attention_weight
dim(rel_input)=(一帧中最多的box数量,图片数量,1936)
dim(masks)=(图片数量,一帧中最多的box数量)
dim(local_output)=(len(im_idx),1936)
dim(local_attention_weight)=(1,图片数量,一帧中最多的box数量,一帧中最多的box数量)
(2)滑动窗口部分
滑动窗口数量为2,即前后帧的信息。
global_input为前后两帧的local_output。dim(global_input)=(一帧中最多的box数量2,图片对数量,1936)
position_embed(帧编码)为position_embedding的两层权重值。其维度与global_input一致,dim(position_embed)=(一帧中最多的box数量2,图片对数量,1936)
global_mask,dim(global_mask)=(图片对数量,一帧中最多的box数量*2)
(3)时间解码器
global_input,position_embed,global_mask——global_attention(由三层decoder_layer组成)——global_output,global_attention_weight
dim(global_output)=(一帧中最多的box数量2,图片对数量,1936)
dim(global_attention_weight)=(3,对象对数量,一帧中最多的box数量2,一帧中最多的box数量*2)
(4)返回值:output,local_attention_weight,global_attention_weight
dim(output)=dim(feature)=(len(im_idx)即对象对数量,1936)
(如有错误,还请指正,谢谢)