环境配置与PyG中图与图数据集的使用

一、环境配置

因为没有GPU,我这里只安装了CPU版本

  1. 先安装pytorch
pip install torch torchvision torchaudio
  1. 安装PyG
pip install torch-scatter
pip install torch-sparse
pip install torch-cluster
pip install torch-spline-conv
pip install torch-geometric

二、Data类

Data类的构造函数如下:

class Data(object):

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None, **kwargs):
    r"""
    Args:
        x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`
        edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第0行为尾节点,第1行为头节点,头指向尾
        edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
        y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签)
    """
    self.x = x
    self.edge_index = edge_index
    self.edge_attr = edge_attr
    self.y = y

    for key, item in kwargs.items():
        if key == 'num_nodes':
            self.__num_nodes__ = item
        else:
            self[key] = item

Data类的使用:

import torch
from torch_geometric.data import Data

# 节点数据
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float)
# 节点类型
y = torch.tensor([0, 1, 0, 1], dtype=torch.int)
# 邻接矩阵 COO格式
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [3, 1, 0, 1, 2]], dtype=torch.long)

data = Data(x=x, y=y, edge_index=edge_index)

三、Dataset类

生成数据集

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/dataset/Cora', name='Cora')

数据集的使用

model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

作业

请通过继承Data类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。

from torch_geometric.data import Data


class MyData(Data):
    def __init__(self, author, author_edge, institution, institution_edge, paper, paper_edge, **kwargs):
        super().__init__(**kwargs)
        self.author = author
        self.author_edge = author_edge
        self.institution = institution
        self.institution_edge = institution_edge
        self.paper = paper
        self.paper_edge = paper_edge
        
    def get_author_num(self):
        return len(self.author)
    
    def get_institution_num(self):
        return len(self.institution)
    
    def get_paper_num(self):
        return len(self.paper)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值