import torch
import torch_geometric
import gc
class Data(object):
def __init__(self, lunwen=None, jigou=None, zuozhe=None, edge_index=None, edge_a2p=None, edge_a2i=None, y=None, **kwargs):
#**kwargs 传入的参数是 dict 类型
# Args:
# x (Tensor, optional): 节点属性矩阵,大小为`[num_nodes, num_node_features]`这里是多少行,多少列的意思!!!!
# edge_index (LongTensor, optional): 边索引矩阵,大小为`[2, num_edges]`,第1行为尾节点,第0行为头节点,头指向尾
# edge_attr (Tensor, optional): 边属性矩阵,大小为`[num_edges, num_edge_features]`
# y (Tensor, optional): 节点或图的标签,任意大小(,其实也可以是边的标签)
self.edge_index = edge_index
self.edge_a2p = edge_a2p
self.edge_a2i = edge_a2i
self.y = y
self.pap = lunwen
self.aut = zuozhe
self.ins = jigou
if torch_geometric.is_debug_enabled():
self.debug()
def get_node_num(self, k):
if k == 'author':
return self.aut.shape[0]
if k == 'institude':
return self.ins.shape[0]
if k == 'paper':
return self.pap.shape[0]
graph_data = Data(lunwen= torch.randn(4,3), jigou=torch.randn(4,3),zuozhe=torch.randn(4,3),edge_a2i=torch.tensor([4]),
edge_a2p = torch.tensor([4]),edge_index = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3],[4, 5, 6, 7, 8, 9, 10, 11]], dtype=torch.long))
print("author的节点数量:", graph_data.get_node_num('author'))
print("paper的节点数量:", graph_data.get_node_num('paper'))
print("institude的节点数量:", graph_data.get_node_num('institude'))
gc.collect()
感觉变成了数据游戏=-=
最后的最后才明白x对应的行为num_node,列为feature。。。
问题是,并不能简简单单的把源函数里的函数复制过来。