加入残差的模型

文章探讨了在GAT模型中加入残差连接以解决过平滑问题的方法,分析了为何第一层和最后一层不使用残差连接,并比较了不同模型配置(如GAT、GATv2和不同头数版本)在是否使用残差和邻居采样策略上的影响。还提到了《TrainingGraphNeuralNetworkswith1000Layers》论文的影响。
摘要由CSDN通过智能技术生成

目前准备向模型里加入残差连接,以缓解过平滑问题,但昨天用GPU24GB的服务器跑16层的模型,还是带不动,先不管了。先明确想要尝试的模型

在作者原代码里给了四个前向传播函数

1. 邻居采样中的前向传播

    def forward_neighbor_sampler(self, x, adjs):#定义了forward_neighbor_sampler方法,用于在邻居采样中前向传播。
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            new_x = self.layers[i]((x, x_target), edge_index)
            #将目标节点特征x_target作为输入的第一个部分,将源节点特征x作为输入的第二个部分,
            # 传递给self.layers[i]方法进行处理,得到新的节点特征new_x。
            if i != self.num_layers - 1:
                new_x = self.non_linearity(new_x)
            if 0 < i < self.num_layers - 1 and self.use_residual:
                x = new_x + x_target
            else:
                x = new_x
            if i < self.num_layers - 1:
                if self.use_layer_norm:
                    x = self.layer_norms[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

在该函数中,如果使用残差,则在除第一层和最后一层外的中间层,直接在新的节点特征上加原始节点特征。

疑问:为什么第一层不添加残差连接?

解释:在邻居采样中的前向传播过程中,第一层的输入特征由变量x表示,并且没有进行残差连接。这是因为第一层的输入特征x是原始的节点特征,还没有经过其他层的变换,因此没有前一层的特征用于进行残差连接。

疑问:为什么最后一层不添加残差连接

解释:在给定的代码中,最后一层没有使用残差连接是因为在残差连接的设计中,通常将残差连接应用于中间层而不是最后一层。残差连接的目的是通过将前一层的输出添加到当前层的输出中,以便传递更多的信息和梯度。在中间层应用残差连接可以帮助缓解梯度消失问题,并提高模型的训练效果。然而,在最后一层应用残差连接可能会引入不必要的复杂性,并且可能不会带来明显的性能提升。最后一层通常是模型的输出层,它的输出直接用于任务的预测或分类。因此,在最后一层不使用残差连接可以简化模型结构并减少计算量。总结起来,最后一层不使用残差连接是常见的做法,因为残差连接通常应用于中间层,而最后一层的输出直接用于模型的预测或分类。

2. 带有残差连接的前向传播

    def exp_forward_neighbor_sampler(self, x, adjs):#用于在邻居采样中前向传播(带有残差连接)。
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # Target nodes are always placed first.
            new_x = self.layers[i]((x, x_target), edge_index)
            if self.use_residual:
                if self.use_resdiual_linear:
                    x = new_x + self.residuals[i](x_target)
                else:
                    x = new_x + x_target
            else:
                x = new_x

            if i < self.num_layers - 1:
                x = self.non_linearity(x)
                if self.use_layer_norm:
                    x = self.layer_norms[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

如果使用残差连接,进而判断是否将残差连接线性转换,增添一个线性层,或者如上面直接相加。

3. 在层之间传播

    def forward_saint(self, x, adj_t):
        for i, layer in enumerate(self.layers[:-1]):
            new_x = layer(x, adj_t)
            new_x = self.non_linearity(new_x)
            # residual
            if i > 0 and self.use_residual:
                if self.use_resdiual_linear:
                    x = new_x + self.residuals[i](x)
                else:
                    x = new_x + x
                x = new_x + x
            else:
                x = new_x
            if self.use_layer_norm:
                x = self.layer_norms[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.layers[-1](x, adj_t)
        return x

该前向传播函数中,也是在中间层加入了残差连接。

模型中实际调用前向传播函数:

    def forward(self, x, adjs):
        if self.saint:
            return self.forward_saint(x, adjs)
        else:
            return self.forward_neighbor_sampler(x, adjs)

代码尝试:

1. 不带任何残差连接的三层GAT,GATv2

修改参数,将默认改为False

不使用任何残差相关的模型,记录测试集准确率

GAT 1head     

learnable_params: 110672

GATv2 1head

2. 不加残差,将use_saint改为true,换了个采样,不知道有啥影响

if args.use_saint:
        for key, idx in split_idx.items():
            mask = torch.zeros(data.num_nodes, dtype=torch.bool)
            mask[idx] = True
            data[f'{key}_mask'] = mask
        train_loader = GraphSAINTRandomWalkSampler(data,
                                                   batch_size=args.batch_size,
                                                   walk_length=args.walk_length,
                                                   num_steps=args.num_steps,
                                                   sample_coverage=0,
                                                   save_dir=dataset.processed_dir)
    else:
        train_loader = NeighborSampler(data.edge_index, node_idx=train_idx, sizes=args.num_layers * [10],
                                       batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

2. 带有哪种残差连接的3层GAT

想要尝试的加入残差连接的模型层

GAT1heads

GATv2_1heads

gat8heads

GATv28heads

不加入线性残差

GATv2 1heads

标题

将模型层数改为8,跑了4个模型

1. 将h0和倒数第二层的输入结合一起输入给最后一层模型

然后我就看了篇论文

《Training Graph Neural Networks with 1000 Layers》

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值