图模型训练

一、依赖安装

网址:pyg-team/pytorch_geometric: Graph Neural Network Library for PyTorch (github.com)

找到此处,点击here进入依赖安装界面

找到自己安装的torch版本并点击,,进入安装依赖

二、用库自带的数据集

代码:

import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx


def visualiize_graph(G,color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G,seed = 42), node_color=color, cmap = "Set2", with_labels=False)

    plt.show()

def visualize_embedding(h,color,epoch = None,loss = None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], c=color, cmap = "Set2", s = 140)
    if epoch is not None and loss is not None:
        plt.xlabel("Epoch: {}, Loss: {:.4f}".format(epoch, loss))
    plt.show()


from torch_geometric.datasets import KarateClub

dataset = KarateClub()

data = dataset[0]
print(data)
#展示点之间的关系
edge_index = data.edge_index
print(edge_index.t())

#可视化
G = to_networkx(data, to_undirected=True)
visualiize_graph(G,data.y)

查看数据集格式

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

数据可视化

三、模型网络搭建

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub

from test import visualize_embedding

dataset = KarateClub()
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self,x,edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()

        #分类层
        out = self.classifier(h)
        return out,h

model = GCN()
# print(model)

_,h = model(data.x, data.edge_index)
print(f'Hidden state size: {h.size()}')

visualize_embedding(h,color = data.y)

四、模型训练

代码:

import time
import torch.nn
from model import GCN
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub

model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def visualiize_graph(G,color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G,seed = 42), node_color=color, cmap = "Set2", with_labels=False)

    plt.show()

def visualize_embedding(h,color,epoch = None,loss = None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0], h[:,1], c=color, cmap = "Set2", s = 140)
    if epoch is not None and loss is not None:
        plt.xlabel("Epoch: {}, Loss: {:.4f}".format(epoch, loss))
    plt.show()

def train(data):
    optimizer.zero_grad()
    out,h = model(data.x,data.edge_index)
    loss = criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss,h

dataset = KarateClub()
data = dataset[0]

for epoch in range(101):
    loss,h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h,color = data.y, epoch = epoch, loss = loss)
        time.sleep(0.3)

结果:

 

 

最后会发现不同颜色的点逐渐分散开,loss越来越小 

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

几两春秋梦_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值