GCN节点分类
1 目标任务
- 获取公共数据集Cora
- 设计双层GCN网络
- 实现7分类(节点分类)
- 利用t-SNE对模型输出进行降维
- 可视化降维后的节点
2 代码实现
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch
from torch_geometric.transforms import normalize_features
from sklearn.manifold import TSNE
import os
import matplotlib.pyplot as plt
# Load (downloading if necessary) the Cora citation dataset.
def get_data(root_data):
    """Load the Cora Planetoid dataset and print basic graph statistics.

    Args:
        root_data: directory in which Planetoid stores / downloads Cora.
            It (and any missing parent directories) is created if absent.

    Returns:
        (dataset, num_features): the Planetoid dataset object and the
        per-node feature dimension.
    """
    # NormalizeFeatures is the callable transform class. The
    # ``normalize_features`` name imported at file level is the *module*,
    # which is not callable and makes ``dataset[i]`` raise TypeError.
    from torch_geometric.transforms import NormalizeFeatures

    # makedirs creates missing parents (e.g. ./data) and tolerates an
    # existing directory, unlike a bare os.mkdir.
    os.makedirs(root_data, exist_ok=True)
    dataset = Planetoid(root=root_data, name='Cora', transform=NormalizeFeatures())

    graph = dataset[0]  # indexing applies the transform
    print('节点数:', graph.num_nodes)
    # num_edges counts edge_index columns; if edges are stored in both
    # directions each undirected edge is counted twice (confirm for Cora).
    print('边数:', graph.num_edges)
    print('特征数:', dataset.num_features)
    print('特征维度:', graph.x.shape)
    print('邻接矩阵维度:', graph.edge_index.shape)
    return dataset, dataset.num_features
# Reduce features to 2-D with t-SNE and scatter-plot them.
def visual(features, color, random_state=None):
    """Embed ``features`` in 2-D with t-SNE and show a scatter plot.

    Args:
        features: array-like of shape (n_samples, n_features) passed to
            ``TSNE.fit_transform``.
        color: per-sample values used to colour the points (e.g. labels);
            forwarded to ``plt.scatter(c=...)``.
        random_state: optional seed for t-SNE, for reproducible layouts.
    """
    embedded = TSNE(n_components=2, random_state=random_state).fit_transform(features)
    # Create the figure *before* setting the title: calling plt.title()
    # first attaches the title to another (possibly new, empty) figure,
    # leaving the displayed scatter plot untitled.
    plt.figure(figsize=(7, 7))
    plt.title('visualization')
    plt.scatter(embedded[:, 0], embedded[:, 1], c=color)
    plt.show()
# Two-layer graph convolutional network (GCN).
class GCN(nn.Module):
    """Two-layer GCN: GCNConv -> ReLU -> Dropout -> GCNConv.

    ``forward`` returns the second layer's output as raw, unnormalized
    per-node class scores (logits), which is what ``nn.CrossEntropyLoss``
    expects.
    """

    def __init__(self, num_features, out_features, hidden_features=16, dropout=0.5):
        """
        Args:
            num_features: input feature dimension per node.
            out_features: number of output classes.
            hidden_features: width of the hidden GCN layer (default 16).
            dropout: dropout probability after the hidden layer (default 0.5).
        """
        super().__init__()
        # Attribute names are kept so existing state_dicts still load.
        self.GCNConv1 = GCNConv(num_features, hidden_features)
        self.GCNConv2 = GCNConv(hidden_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj_matrix):
        """Compute per-node logits.

        Args:
            x: node feature matrix.
            adj_matrix: graph connectivity; train_model passes
                ``data.edge_index`` here.
        """
        hidden = self.dropout(self.relu(self.GCNConv1(x, adj_matrix)))
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0 (with zero gradient), crippling the
        # cross-entropy loss. Return logits directly.
        return self.GCNConv2(hidden, adj_matrix)
# Train the model on the training nodes, then evaluate and visualize.
def _masked_accuracy(pred_labels, labels, mask):
    """Return the fraction of mask-selected nodes whose prediction is correct."""
    total = int(mask.sum())
    if total == 0:
        return 0.0
    return (pred_labels[mask] == labels[mask]).sum().item() / total


def train_model(data, epochs, model, device=None):
    """Train ``model`` full-batch with Adam, report accuracy, plot t-SNE.

    Args:
        data: graph object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask`` and ``val_mask`` (``test_mask`` is used if present).
        epochs: number of full-batch optimisation steps.
        model: module called as ``model(x, edge_index)`` returning per-node
            class scores.
        device: where to train; defaults to the first CUDA device when
            available, otherwise the CPU.
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Tensor.to() returns a moved copy, so the results must be re-bound;
    # calling data.x.to(device) without assignment has no effect.
    model = model.to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    model.train()  # make sure dropout is active while training
    for epoch in range(epochs):
        optimizer.zero_grad()
        logits = model(data.x, data.edge_index)
        # Only labelled training nodes contribute to the loss.
        loss = criterion(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        print(f'epoch: {epoch} loss:{loss.item()}')

    model.eval()  # disable dropout for evaluation
    with torch.no_grad():
        preds = model(data.x, data.edge_index)
    pred_labels = preds.argmax(dim=1)

    print('分类准确率:', _masked_accuracy(pred_labels, data.y, data.val_mask))
    test_mask = getattr(data, 'test_mask', None)
    if test_mask is not None:
        print('测试准确率:', _masked_accuracy(pred_labels, data.y, test_mask))

    # matplotlib needs host-side arrays, so move labels off the GPU too.
    visual(preds.cpu().numpy(), color=data.y.cpu().numpy())
if __name__ == "__main__":
    # Kept at module scope: train_model may look `device` up as a global.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    root_data = r"./data/planetoid"
    data, num_features = get_data(root_data)
    epochs = 500
    # Derive the class count from the labels instead of hard-coding 7
    # (the tutorial targets 7-class classification on Cora).
    num_classes = int(data.data.y.max()) + 1
    model = GCN(num_features, num_classes)
    train_model(data.data, epochs, model)
3 结果展示