# Original source (原文链接):
# https://colab.research.google.com/drive/14OvFnAXggxB8vM4e8vSURUp1TaKnovzX?usp=sharing
import torch.nn
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.datasets import Planetoid
from torch_geometric import nn
from torch.nn import Module
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Download (on first run) and load the Cora dataset; node features are
# normalized by the NormalizeFeatures transform.
dataset = Planetoid(
    root='./dataset',
    name='Cora',
    transform=NormalizeFeatures(),
)
data = dataset[0]  # take the first graph of the dataset
print(data)
# To inspect the split sizes, uncomment:
# print(data.train_mask.sum())
# print(data.test_mask.sum())
# Visualization
def visualize(h, color):
    """Project node outputs to 2-D with t-SNE and show them as a scatter plot.

    Args:
        h: Tensor of per-node vectors; each row becomes one 2-D point
            (assumes a 2-D (num_nodes, dim) tensor -- TODO confirm).
        color: Per-node values passed to ``plt.scatter(c=...)``; points with
            the same value share a color (e.g. class labels). May be a tensor.
    """
    # detach() drops autograd history; cpu() lets .numpy() work even when the
    # tensor lives on a GPU. TSNE reduces to 2 dimensions by default.
    z = TSNE().fit_transform(h.detach().cpu().numpy())
    if isinstance(color, torch.Tensor):
        # matplotlib cannot consume GPU tensors directly.
        color = color.detach().cpu().numpy()
    plt.figure(figsize=(10, 10))  # fixed canvas size
    plt.xticks([])  # empty tick lists hide the axis ticks
    plt.yticks([])
    # Point size 70; colors come from `color` mapped through the Set2 colormap.
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap='Set2')
    plt.show()
# Model definition
class GCN(Module):
    """Two-layer GCN: GCNConv -> ReLU -> dropout -> GCNConv.

    Args:
        hidden_features: Width of the hidden GCN layer.
        in_features: Input feature size per node. Defaults to
            ``dataset.num_features`` of the module-level ``dataset``
            (the original behavior).
        out_features: Output size per node. Defaults to
            ``dataset.num_classes`` of the module-level ``dataset``.
        dropout: Dropout probability applied after the first layer while in
            training mode. Defaults to 0.5 (the original hard-coded value).
    """

    def __init__(self, hidden_features, in_features=None, out_features=None, dropout=0.5):
        super().__init__()
        # Only fall back to the global dataset when sizes are not supplied, so
        # the class can be reused for other graphs without touching a global.
        if in_features is None:
            in_features = dataset.num_features
        if out_features is None:
            out_features = dataset.num_classes
        self.dropout = dropout
        self.conv1 = nn.GCNConv(in_features, hidden_features)
        self.conv2 = nn.GCNConv(hidden_features, out_features)

    def forward(self, x, edge_index):
        """Return raw per-node scores from the second layer (no softmax applied)."""
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is applied only when the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, edge_index)
# Instantiate the model, loss function and optimizer
model = GCN(hidden_features=16)
# Cross-entropy over the model's raw class scores.
Loss_F = torch.nn.CrossEntropyLoss()
# Adam with weight_decay, which adds an L2 penalty on the parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)
# Training
def train():
    """Run one full-batch optimization step using only the training nodes.

    Reads the module-level ``model``, ``data``, ``Loss_F`` and ``optimizer``.

    Returns:
        The loss tensor computed for this step.
    """
    model.train()  # training mode: dropout in GCN.forward is active
    optimizer.zero_grad()
    mask = data.train_mask
    logits = model(data.x, data.edge_index)
    loss = Loss_F(logits[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss
@torch.no_grad()
def test():
    """Compute classification accuracy on the test nodes.

    Reads the module-level ``model`` and ``data``. Runs under
    ``torch.no_grad()`` because evaluation never calls backward(), so building
    an autograd graph would only waste memory and time.

    Returns:
        0-dim tensor: fraction of test nodes whose arg-max prediction equals
        the label.
    """
    model.eval()  # eval mode: dropout in GCN.forward is disabled
    output = model(data.x, data.edge_index)
    mask = data.test_mask
    pred = output[mask].argmax(dim=1)
    num_correct = (pred == data.y[mask]).sum()
    # Integer tensors divided with `/` yield a floating-point result.
    return num_correct / mask.sum()
# Train for `epoch` steps, logging the loss to TensorBoard, then evaluate once
# on the test set.
epoch = 200
# The context manager closes the writer even if training raises.
with SummaryWriter('./logs_train') as writer:
    for step in range(1, epoch + 1):
        # .item() gives a plain float: readable output, no autograd tensor kept.
        loss = train().item()
        print('第{}次训练的loss:{}'.format(step, loss))
        writer.add_scalar('train_loss', loss, step)
test_acc = test()
print('测试集上的准确率为:{}'.format(test_acc))
# Plot the trained model's outputs
model.eval()  # disable dropout for the plotted forward pass
with torch.no_grad():  # inference only; no autograd graph needed
    out = model(data.x, data.edge_index)
visualize(out, data.y)  # color by label: nodes with the same label share a color