设定
dataset_name = 'MUTAG'
strategy = 'full_input'
所以
max_graph_size = 28
nclass = 2
1 图分类
k = max_graph_size # =28
lr = 0.0005
max_epoch = 500
ngraph = nfeature = max_graph_size # =28
x_size = nfeature # =28
hidden_size = intermediate_size = 32
num_attention_heads = 2
num_hidden_layers = 2
y_size = nclass # =2
graph_size = ngraph # =28
residual_type = 'none'
设定fold取自[1,11)，即进行10折交叉验证
MethodGraphBertGraphClassification(
(bert): MethodGraphBert(
(embeddings): BertEmbeddings(
(raw_feature_embeddings): Linear(in_features=28, out_features=32, bias=True)
(tag_embeddings): Embedding(1000, 32)
(degree_embeddings): Embedding(1000, 32)
(wl_embeddings): Embedding(1000, 32)
(LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.5, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=32, out_features=32, bias=True)
(key): Linear(in_features=32, out_features=32, bias=True)
(value): Linear(in_features=32, out_features=32, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=32, out_features=32, bias=True)
(LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.5, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=32, out_features=32, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=32, out_features=32, bias=True)
(LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.5, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=32, out_features=32, bias=True)
(key): Linear(in_features=32, out_features=32, bias=True)
(value): Linear(in_features=32, out_features=32, bias=True)
(dropout): Dropout(p=0.3, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=32, out_features=32, bias=True)
(LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.5, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=32, out_features=32, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=32, out_features=32, bias=True)
(LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.5, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=32, out_features=32, bias=True)
(activation): Tanh()
)
)
(res_h): Linear(in_features=784, out_features=32, bias=True)
(res_y): Linear(in_features=784, out_features=2, bias=True)
(cls_y): Linear(in_features=32, out_features=2, bias=True)
)
选取训练集
将对应的节点特征x、节点度d、权重w、WL角色编码wl、标签y_true以及context_idx_list作为模型输入
设置
residual_type = 'none'
得到residual_h, residual_y都是None
通过bert计算出y_pred
经过交叉熵损失BP后返回learning_record_dict