Segmented GRAPH-BERT Code Walkthrough ②: script_graph_classification.py

This post walks through graph classification on the MUTAG dataset using the full-input strategy, with a maximum graph size of 28 and 2 target classes. The model consists of BertEmbeddings, a BertEncoder made of several BertLayers (each with an attention block and an intermediate layer), plus the classification heads. Training parameters such as the learning rate, maximum number of epochs, and hidden size are set, and cross-validation is used. During training, the prediction is computed by the BERT model and the cross-entropy loss is backpropagated; the learning record is returned at the end.

Settings

dataset_name = 'MUTAG'
strategy = 'full_input'

Therefore:

max_graph_size = 28
nclass = 2
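
A minimal sketch of that derivation (the lookup table below is hypothetical, for illustration; the actual script assigns these values per dataset by hand). MUTAG's largest graph has 28 nodes and the task has two graph labels:

# Hypothetical lookup; the real script hard-codes these per dataset.
DATASET_SETTINGS = {
    'MUTAG': {'max_graph_size': 28, 'nclass': 2},  # largest MUTAG graph: 28 nodes; 2 graph labels
}

dataset_name = 'MUTAG'
strategy = 'full_input'  # feed the whole graph in; no pruning or segmenting

max_graph_size = DATASET_SETTINGS[dataset_name]['max_graph_size']  # 28
nclass = DATASET_SETTINGS[dataset_name]['nclass']                  # 2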

1 Graph Classification

k = max_graph_size # =28
lr = 0.0005

max_epoch = 500
ngraph = nfeature = max_graph_size # =28
x_size = nfeature # =28
hidden_size = intermediate_size = 32
num_attention_heads = 2
num_hidden_layers = 2
y_size = nclass # =2
graph_size = ngraph # =28
residual_type = 'none'
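
Collected in one place, these are the values handed to the model; a sketch (the real script passes them into a GraphBertConfig-style object whose exact signature is not shown here):

config = dict(
    dataset_name='MUTAG',
    strategy='full_input',
    k=28,                     # context size; equals max_graph_size under 'full_input'
    lr=0.0005,
    max_epoch=500,
    x_size=28,                # nfeature: raw feature dims per node
    hidden_size=32,
    intermediate_size=32,
    num_attention_heads=2,    # the 32-d hidden state splits into two 16-d heads
    num_hidden_layers=2,
    y_size=2,                 # nclass
    graph_size=28,            # ngraph
    residual_type='none',
)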

The fold index is taken from [1, 11) for cross-validation, i.e. the standard 10 folds.
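
A minimal sketch of the resulting loop (the loader, factory, and record key below are hypothetical names for illustration, not the script's exact API):

# fold runs over 1..10, i.e. standard 10-fold cross-validation
fold_accuracies = []
for fold in range(1, 11):
    train_data, test_data = load_fold_split('MUTAG', fold)     # hypothetical loader
    model = build_model(config)                                # hypothetical factory
    record = train_and_evaluate(model, train_data, test_data)  # hypothetical trainer
    fold_accuracies.append(record['test_acc'])                 # key assumed

mean_acc = sum(fold_accuracies) / len(fold_accuracies)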

MethodGraphBertGraphClassification(
  (bert): MethodGraphBert(
    (embeddings): BertEmbeddings(
      (raw_feature_embeddings): Linear(in_features=28, out_features=32, bias=True)
      (tag_embeddings): Embedding(1000, 32)
      (degree_embeddings): Embedding(1000, 32)
      (wl_embeddings): Embedding(1000, 32)
      (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=32, out_features=32, bias=True)
              (key): Linear(in_features=32, out_features=32, bias=True)
              (value): Linear(in_features=32, out_features=32, bias=True)
              (dropout): Dropout(p=0.3, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=32, out_features=32, bias=True)
              (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.5, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=32, out_features=32, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.5, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=32, out_features=32, bias=True)
              (key): Linear(in_features=32, out_features=32, bias=True)
              (value): Linear(in_features=32, out_features=32, bias=True)
              (dropout): Dropout(p=0.3, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=32, out_features=32, bias=True)
              (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.5, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=32, out_features=32, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=32, out_features=32, bias=True)
            (LayerNorm): LayerNorm((32,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.5, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=32, out_features=32, bias=True)
      (activation): Tanh()
    )
  )
  (res_h): Linear(in_features=784, out_features=32, bias=True)
  (res_y): Linear(in_features=784, out_features=2, bias=True)
  (cls_y): Linear(in_features=32, out_features=2, bias=True)
)
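
Two details worth noting in this printout: res_h and res_y take in_features=784 because they would consume the flattened raw input of a whole graph, graph_size × x_size = 28 × 28 = 784; and since residual_type = 'none' here, those heads are never used, so classification runs through the pooled BERT output and cls_y alone. A hedged sketch of that forward pass (argument order and the pooling choice are assumptions, not the repo's verified code):

import torch.nn.functional as F

def forward_sketch(model, x, d, w, wl):
    # x: (graph_size, x_size) raw node features; d, w, wl: the id tensors
    # consumed by the degree/tag/WL embedding tables (argument order assumed).
    # residual_type = 'none' -> residual_h = residual_y = None, so the
    # 784-dim res_h/res_y branches are skipped entirely.
    outputs = model.bert(x, d, w, wl, residual_h=None)
    sequence_output = outputs[0]                    # (graph_size, 32) node embeddings
    graph_embedding = sequence_output.mean(dim=0)   # pool 28 nodes into one 32-d vector
    y_pred = F.log_softmax(model.cls_y(graph_embedding), dim=-1)  # (2,) log-probs
    return y_pred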

Select the training set.
Then gather the corresponding node features x, node degrees d, weights w, WL codes wl, y_true, and context_idx_list (sketched below).
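A sketch of what those per-graph inputs might look like as tensors, under the shapes implied by the settings above (names mirror the list; the shapes and the mapping of w to the tag table are assumptions):

import torch

graph_size, x_size = 28, 28   # max_graph_size / nfeature from above

x = torch.zeros(graph_size, x_size)            # raw node features, zero-padded to 28 nodes
d = torch.zeros(graph_size, dtype=torch.long)  # degree ids -> degree_embeddings
w = torch.zeros(graph_size, dtype=torch.long)  # weight/tag ids -> tag_embeddings (assumed)
wl = torch.zeros(graph_size, dtype=torch.long) # Weisfeiler-Lehman ids -> wl_embeddings
y_true = torch.tensor([0])                     # graph label in {0, 1}
context_idx_list = list(range(graph_size))     # 'full_input': every node is context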
Setting

residual_type = 'none'

means residual_h and residual_y are both None.
y_pred is then computed through the BERT model, as in the forward sketch above.
After backpropagating the cross-entropy loss, the method returns learning_record_dict.
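
A hedged sketch of the training step those two lines describe (the optimizer choice and the record layout are assumptions consistent with the settings above):

import torch
import torch.nn.functional as F

def train_sketch(model, data, max_epoch=500, lr=0.0005):
    # Adam with the lr from the settings above (optimizer choice assumed)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    learning_record_dict = {}
    for epoch in range(max_epoch):
        model.train()
        optimizer.zero_grad()
        y_pred = model(data['x'], data['d'], data['w'], data['wl'])  # (batch, 2) log-probs
        # nll_loss on log-probabilities == the cross-entropy loss in the text
        loss = F.nll_loss(y_pred, data['y_true'])
        loss.backward()
        optimizer.step()
        learning_record_dict[epoch] = {'loss': loss.item()}
    return learning_record_dict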
