图神经网络GNN

本文介绍了如何使用PyTorch和相关库如networkx和torch_geometric对KarateClub社交网络数据集进行分析,包括构建图模型、GCN层实现、以及节点嵌入的可视化。
摘要由CSDN通过智能技术生成
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
from torch.nn import Linear
from torch_geometric.nn import GCNConv


def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G,seed=42),with_labels=False,node_color=color,cmap='Set2')

    plt.show()

def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h=h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap='Set2')
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch:{epoch},Loss:{loss.item():.4f}',fontsize=16)
    plt.show()



class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(123)
        self.conv1 = GCNConv(dataset.num_features,4)
        self.conv2 = GCNConv(4,4)
        self.conv3 = GCNConv(4,2)
        self.classifier = Linear(2,dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)  #输入特征和邻接矩阵
        h = h.tanh()
        h = self.conv2(h,edge_index)
        h = h.tanh()
        h = self.conv3(h,edge_index)
        h = h.tanh()

        out = self.classifier(h)

        return out,h






dataset = KarateClub()
print(f'Dataset:{dataset}:')
print('===================')
print(f'Number of graphs:{len(dataset)}')
print(f'Number of features:{dataset.num_features}')
print(f'Number of classes:{dataset.num_classes}')

data = dataset[0]
print(data)


edge_index = data.edge_index
print(edge_index.t())

G = to_networkx(data,to_undirected=True)
visualize_graph(G,color=data.y)


model = GCN()
print(model)

_,h = model(data.x,data.edge_index)
print(f'Embedding shape:{list(h.shape)}')

visualize_embedding(h,color=data.y)

b46e5dc8e02a4a2482a5c24d0678c054.png

e1d7311d8e594ce3ae3e4ecfab04b250.png

511bce563a0b43f59dd6fbdd090f56ca.png

5c54780df11e4366b01b313253ec7ffb.png

 

 

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值