pytorch|图卷积神经网络(GCN)在Karate数据集的应用

博主又重新学习GCN,本文用的数据集是空手道数据集,最后呈现出可视化。

首先附上全部代码:

import matplotlib.pyplot as plt
import torch
from torch.nn import Linear
from torch_geometric.datasets import KarateClub
from torch_geometric.nn import GCNConv


def show_embedding(h, color, epoch=None, loss=None):

    plt.figure(figsize=(14,9))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], s=140, c=color, cmap='Set2')
    if epoch is not None and loss is not None:
        plt.xlabel('Epoch: {}, Loss: {:.4f}'.format(epoch, loss), fontsize=16)
    plt.show()

def show_loss(epoch,loss):

    plt.figure(figsize=(14,9))
    plt.plot(epoch,loss,marker='<')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid()
    plt.show()

class GCN(torch.nn.Module):

    def __init__(self):
        super().__init__()
        torch.manual_seed(12234)
        self.conv1 = GCNConv(dataset.num_features, 4) # 定义好输入特征和输出特征
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2) # 输出 2 维向量
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index) # 输入特征 与 邻接矩阵
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # 此时的 h 是 2 维向量

        # 分类层
        out = self.classifier(h)
    
        return out,h

if __name__ == "__main__":

    dataset = KarateClub()
    data = dataset[0]
    Model = GCN()
    print(Model) # 模型结构

    criterion = torch.nn.CrossEntropyLoss() # 定义损失函数
    optimizer = torch.optim.Adam(Model.parameters(), lr=0.01) # 定义优化器

    def train(data):
        optimizer.zero_grad()
        out,h = Model(data.x, data.edge_index) # 两维向量
        loss = criterion(out[data.train_mask], data.y[data.train_mask]) # 半监督
        loss.backward() # 反向传播
        optimizer.step() # 迭代更新

        return loss, h

    loss_data = []

    for epoch in range(401):
        loss, h = train(data)
        loss_data.append(loss)
        if epoch % 50 == 0:
            show_embedding(h, color=data.y, epoch=epoch, loss=loss)

    show_loss(epoch=range(401), loss=loss_data)

然后看下输出的模型:
在这里插入图片描述
下面是h的可视化,只展示了最开始和最后的情况
在这里插入图片描述
在这里插入图片描述
最后是Loss值随着Epoch的变化
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

xiao黄

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值