MethodSegmentedGraphBertGraphClassification(
  (bert): MethodGraphBert(
    (embeddings): BertEmbeddings(
      (raw_feature_embeddings): Linear(in_features=40, out_features=32, bias=True)
      (tag_embeddings): Embedding(1000, 32)
      (degree_embeddings): Embedding(1000, 32)
      (wl_embeddings): Embedding(1000, 32)
      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=32, out_features=32, bias=True)
              (key): Linear(in_features=32, out_features=32, bias=True)
              (value): Linear(in_features=32, out_features=32, bias=True)
              (dropout): Dropout(p=0.3, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=32, out_features=32, bias=True)
              (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.5, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=32, out_features=32, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.5, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=32, out_features=32, bias=True)
              (key): Linear(in_features=32, out_features=32, bias=True)
              (value): Linear(in_features=32, out_features=32, bias=True)
              (dropout): Dropout(p=0.3, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=32, out_features=32, bias=True)
              (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.5, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=32, out_features=32, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.5, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (res_h): Linear(in_features=1600, out_features=32, bias=True)
  (res_y): Linear(in_features=1600, out_features=2, bias=True)
  (cls_y): Linear(in_features=32, out_features=2, bias=True)
)
数据准备上的差别在于：context_idx_list 换成了 segment_count_list（每个图对应的 segment 数量）。下面的代码据此把属于同一个图的各个 segment 表示取平均，得到图级表示。
# Segment fusion: average the consecutive per-segment rows of `sequence_output`
# that belong to each graph into one graph-level row.
#   segment_count_list: tensor holding one integer count per graph (read with
#       .item()); count g is how many consecutive rows of sequence_output belong
#       to graph g, in graph order.
#   sequence_output: indexed row-wise and sized by .size()[1], so it is treated
#       as 2-D (total_segments, hidden) — TODO confirm against the model's forward().
# Result: segment_fusion_output with shape (num_graphs, sequence_output.size()[1]).
num_graphs = segment_count_list.size()[0]
total_segments = int(segment_count_list.sum().item())
if sequence_output.size()[0] < total_segments:
    # Checked up front because the slicing below would otherwise silently
    # truncate instead of failing on a missing segment row.
    raise IndexError(
        "sequence_output has %d rows but segment_count_list sums to %d"
        % (sequence_output.size()[0], total_segments)
    )
# Allocate with sequence_output's device and dtype so the fused output lives
# where the model runs; a bare torch.zeros(...) would use the default
# dtype/device (normally float32 on CPU) regardless.
segment_fusion_output = sequence_output.new_zeros(num_graphs, sequence_output.size()[1])
current_global_seg_index = 0
# NOTE(review): in the author's run graph_index covered [0, 169) — presumably
# the number of graphs in the batch; confirm.
for graph_index in range(num_graphs):
    # Number of segments belonging to this graph.
    graph_seg_number = int(segment_count_list[graph_index].item())
    if graph_seg_number > 0:
        # A graph's segments are contiguous rows, so one slice + mean replaces a
        # per-segment Python loop (mean == sum / graph_seg_number).
        segment_end = current_global_seg_index + graph_seg_number
        segment_fusion_output[graph_index] = sequence_output[current_global_seg_index:segment_end].mean(dim=0)
    # A graph with zero segments keeps an all-zero row instead of 0/0 = NaN.
    current_global_seg_index += graph_seg_number