作业
请通过继承Data类实现一个类,专门用于表示“机构-作者-论文”的网络。该网络
包含“机构“、”作者“和”论文”三类节点,以及“作者-机构“和“作者-论文“两类
边。对要实现的类的要求:1)用不同的属性存储不同节点的属性;2)用不同的
属性存储不同的边(边没有属性);3)逐一实现获取不同节点数量的方法。
Data类
class Data(object):
def __init__(self, x=None, edge_index=None, edge_attr=None,
y=None, **kwargs):
r"""
Args:
x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes,
num_node_features]`
edge_index (LongTensor, optional): 边索引矩阵,大小为`[2,
num_edges]`,第0行可称为头(head)节点、源(source)节点、邻接节点,第
1行可称为尾(tail)节点、目标(target)节点、中心节点
edge_attr (Tensor, optional): 边属性矩阵,大小为
`[num_edges, num_edge_features]`
y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是
边的标签)
"""
self.x = x
self.edge_index = edge_index
self.edge_attr = edge_attr
self.y = y
for key, item in kwargs.items():
if key == 'num_nodes':
self.__num_nodes__ = item
else:
self[key] = item
答案
from torch_geometric.data import Data
'''
【OAP,机构-作者-论文】
O-Orginazation,机构;
A-Author,作者;
P-Paper,论文
'''
class OPA_Data(Data):
def __init__(self, x_O=None, x_A=None, x_P=None, edge_index_A_O=None, edge_index_A_P=None, y=None, **kwargs):
super().__init__(y, **kwargs)
self.x_O = x_O # 机构类节点
self.x_A = x_A # 作者类节点
self.x_P = x_P # 论文类节点
self.edge_index_A_O = edge_index_A_O # 作者-机构边的序号
self.edge_index_A_P = edge_index_A_P # 作者-论文边的序号
def num_nodes_O(self):
return self.x_O.shape[0]//第一维度的大小
def num_nodes_A(self):
return self.x_A.shape[0]
def num_nodes_P(self):
return self.x_P.shape[0]
def num_edges_AO(self):
return self.edge_index_A_O.shape[1]//第二维度的大小
def num_edges_AP(self):
return self.edge_index_A_P.shape[1]
@classmethod
def from_dict(cls, dictionary):
data = cls()
for key, item in dictionary.items()://词典不可遍历,dictionary.items()以列表形式、元素为元组 返回字典中所有键值对
data[key] = item
return data
#输入原始数据
OPA_graph_dict = {
'x_O': x_O,
'x_A': x_A,
'x_P': x_P,
'edge_index_A_O': edge_index_A_O,
'edge_index_A_P': edge_index_A_P,
'y': y,
'other_attr': other_attr
}
# 转dict对象为Data对象
OPA_graph_data = OPA_Data.from_dict(OPA_graph_dict)
# 获取OPA图上不同节点、不同边的数量
print(f'Number of orginazation nodes:{OPA_graph_data.num_nodes_O}') # 节点数量
print(f'Number of author nodes:{OPA_graph_data.num_nodes_A}') # 机构数量
print(f'Number of paper nodes:{OPA_graph_data.num_nodes_P}') # 论文数量
print(f'Number of author-orginazation edges:{OAP_graph_data.num_edges_AO}') # 作者-机构边数量
print(f'Number of author-paper edges: {OAP_graph_data.num_edges_AP}') # 作者-论文边数量