datawhalechina-GNN组队学习 作业:图的类

作业

请通过继承Data类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络
包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类
边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的
属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。


Data类

class Data(object):
def __init__(self, x=None, edge_index=None, edge_attr=None,
y=None, **kwargs):
r"""
Args:
x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes,
num_node_features]`
edge_index (LongTensor, optional): 边索引矩阵,大小为`[2,
num_edges]`,第0行可称为头(head)节点、源(source)节点、邻接节点,第
1行可称为尾(tail)节点、目标(target)节点、中心节点
edge_attr (Tensor, optional): 边属性矩阵,大小为
`[num_edges, num_edge_features]`
y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是
边的标签)
"""
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item

答案

from torch_geometric.data import Data
'''
【OAP,机构-作者-论文】
O-Orginazation,机构;
A-Author,作者;
P-Paper,论文
'''
class OPA_Data(Data):
    
    def __init__(self,  x_O=None, x_A=None, x_P=None, edge_index_A_O=None, edge_index_A_P=None, y=None, **kwargs):
        super().__init__(y, **kwargs)
        self.x_O = x_O  # 机构类节点
        self.x_A = x_A  # 作者类节点
        self.x_P = x_P  # 论文类节点
        self.edge_index_A_O = edge_index_A_O  # 作者-机构边的序号
        self.edge_index_A_P = edge_index_A_P  # 作者-论文边的序号
        
    def num_nodes_O(self):
        return self.x_O.shape[0]//第一维度的大小
    
    def num_nodes_A(self):
        return self.x_A.shape[0]
    
    def num_nodes_P(self):
        return self.x_P.shape[0]
    
    def num_edges_AO(self):
        return self.edge_index_A_O.shape[1]//第二维度的大小
    
    def num_edges_AP(self):
        return self.edge_index_A_P.shape[1]
    
    @classmethod
    def from_dict(cls, dictionary):
        data = cls()
        for key, item in dictionary.items()://词典不可遍历,dictionary.items()以列表形式、元素为元组 返回字典中所有键值对
            data[key] = item
            
        return data
#输入原始数据    
OPA_graph_dict = {
    'x_O': x_O,
    'x_A': x_A,
    'x_P': x_P,
    'edge_index_A_O': edge_index_A_O,
    'edge_index_A_P': edge_index_A_P,
    'y': y,
    'other_attr': other_attr
}

# 转dict对象为Data对象
OPA_graph_data = OPA_Data.from_dict(OPA_graph_dict)

# 获取OPA图上不同节点、不同边的数量
print(f'Number of orginazation nodes:{OPA_graph_data.num_nodes_O}') # 节点数量
print(f'Number of author nodes:{OPA_graph_data.num_nodes_A}') # 机构数量
print(f'Number of paper nodes:{OPA_graph_data.num_nodes_P}') # 论文数量
print(f'Number of author-orginazation edges:{OAP_graph_data.num_edges_AO}') # 作者-机构边数量
print(f'Number of author-paper edges: {OAP_graph_data.num_edges_AP}') # 作者-论文边数量
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值