GCN初步尝试

任务如下:

读取cora数据集,有2708个节点,每个节点有1433个特征,每个节点属于7类中的一类。节点之间存在边
	注:cora的节点标号不是从0开始计数,故需要进行处理

附上一份简约代码:

  • 建立一个两层的GCN
  • 这里的edges是经过节点重新处理过后的
    • 有一条边(35,1033)如第一条e[0][0]=35,e[0][1]=1033,这是原始的边
    • 因为features里的节点是乱序的,所以,需要把(35,1033)映射到features乱序里的位置,这里用了dict。现在(35,1033) -> (163,402)
    • 此时我们输送给模型的features,edges。edges里的哪两行对应的节点也就是features里哪两行对应的特征
import torch
import numpy as np
import scipy.sparse as sp
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GATConv,SAGEConv
from torch_geometric.datasets import Planetoid

class GCN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    Maps node features of shape (num_nodes, feature) to per-class
    log-probabilities of shape (num_nodes, classes): GCNConv -> ReLU ->
    dropout -> GCNConv -> log_softmax.
    """

    def __init__(self, feature=1433, hidden=16, classes=7):
        super(GCN, self).__init__()
        # First graph convolution: raw features -> hidden embedding.
        self.conv1 = GCNConv(feature, hidden)
        # Second graph convolution: hidden embedding -> class scores.
        self.conv2 = GCNConv(hidden, classes)

    def forward(self, features, edges):
        # features: (num_nodes, feature); edges: (2, num_edges) COO index.
        hidden_repr = F.relu(self.conv1(features, edges))
        hidden_repr = F.dropout(hidden_repr, training=self.training)
        logits = self.conv2(hidden_repr, edges)  # (num_nodes, classes)
        return F.log_softmax(logits, dim=1)


def encode_onehot(labels):
    """Convert an iterable of class labels to a one-hot encoded array.

    Args:
        labels: iterable of hashable, orderable class labels (e.g. strings).

    Returns:
        np.ndarray of shape (len(labels), num_classes), dtype int32, where
        row i is the one-hot vector for labels[i].

    Note: classes are sorted so the label -> column mapping is deterministic
    across runs; iterating a plain set() varies with hash randomization.
    """
    classes = sorted(set(labels))
    identity = np.identity(len(classes))
    classes_dict = {c: identity[i, :] for i, c in enumerate(classes)}
    labels_onehot = np.array([classes_dict[label] for label in labels],
                             dtype=np.int32)
    return labels_onehot

def normalize(mx):
    """Row-normalize a sparse matrix so each nonzero row sums to 1.

    Computes D^-1 @ mx where D is the diagonal matrix of row sums.
    Rows whose sum is zero are left all-zero instead of becoming inf/NaN.

    Args:
        mx: scipy sparse matrix (anything with a .sum(1) returning a column).

    Returns:
        The row-normalized sparse matrix.
    """
    rowsum = np.array(mx.sum(1))
    # Zero rows produce inf here; suppress the divide warning and zero them
    # out on the next line instead.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def load_data(path="./dataset", dataset="cora"):
    """Load the cora citation dataset from disk.

    Reads "<path>/<dataset>.content" (one row per node: raw id, feature
    columns, label) and "<path>/<dataset>.cites" (one citation edge per row,
    as a pair of raw node ids). Raw node ids are arbitrary non-contiguous
    integers, so they are re-indexed to 0..N-1 following the row order of
    the content file; edge endpoints then index directly into `features`.

    Returns:
        features: FloatTensor (num_nodes, num_features), row-normalized.
        edges: LongTensor (2, 2 * num_cites) in COO format; every citation
            is inserted in both directions to make the graph undirected.
        labels: LongTensor (num_nodes,) holding class indices.
    """
    idx_features_labels = np.genfromtxt("{}/{}.content".format(path, dataset),
                                        dtype=np.dtype(str))
    features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
    # Map raw (non-contiguous) node ids to their row position in the content
    # file. Named idx_map to avoid shadowing the builtin `dict`.
    idx_map = {int(raw_id): i
               for i, raw_id in enumerate(idx_features_labels[:, 0:1].reshape(-1))}
    labels = encode_onehot(idx_features_labels[:, -1])
    e = np.genfromtxt("{}/{}.cites".format(path, dataset), dtype=np.int32)
    edges = []
    for src, dst in e:
        # If A cites B, add both A->B and B->A so the graph is undirected.
        edges.append([idx_map[src], idx_map[dst]])
        edges.append([idx_map[dst], idx_map[src]])
    features = normalize(features)  # row-normalize feature vectors
    features = torch.tensor(np.array(features.todense()), dtype=torch.float32)
    # One-hot -> class index via the column position of each row's 1.
    labels = torch.LongTensor(np.where(labels)[1])
    edges = torch.tensor(edges, dtype=torch.int64).T
    return features, edges, labels

if __name__ == '__main__':
    features, edges, labels = load_data()
    idx_train = torch.LongTensor(range(2000))       # first 2000 nodes for training
    idx_test = torch.LongTensor(range(2000, 2700))  # 700 nodes for testing

    # Fix: the original computed `device` but never used it, so the model
    # always ran on CPU. Move the model, data, and index tensors there.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GCN(1433, 16, 7).to(device)  # each node has 1433 features
    features = features.to(device)
    edges = edges.to(device)
    labels = labels.to(device)
    idx_train = idx_train.to(device)
    idx_test = idx_test.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()
    for epoch in range(300):
        optimizer.zero_grad()
        out = model(features, edges)  # (2708, 7) log-probabilities
        loss = F.nll_loss(out[idx_train], labels[idx_train])
        loss.backward()
        optimizer.step()
        print(f"epoch:{epoch + 1}, loss:{loss.item()}")

    model.eval()
    _, pred = model(features, edges).max(dim=1)
    # Count predictions matching the ground-truth labels on the test split.
    correct = pred[idx_test].eq(labels[idx_test]).sum()
    acc = int(correct) / int(len(idx_test))
    print(acc)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值