图神经网络、GCN实现图内点云分类任务(物体的部件分类)


本项目是一个简单的使用图中点分类代码,内涵完整的网络搭建、模型训练、模型保存、模型调用、可视化、的全过程。可以帮助初学者快速熟悉流程。帮助入门。

数据集下载

关键代码

项目使用了shapenet数据集中的飞机类数据集,在使用图神经网络飞机上进行部件分割,本项目写了一个自动下载数据集的方法,直接运行项目会自动下载数据集。关键部分代码如下。

def load_data():
    path = './data'
    if not os.path.exists(path):
        # 如果目录不存在,则创建它
        os.makedirs(path)
        print(f"目录 '{path}' 已创建。")
    else:
        print(f"目录 '{path}' 已存在。")
    train_data = ShapeNet(root=path, categories=['Airplane'], split='trainval', pre_transform=T.KNNGraph(k=6))
    test_data = ShapeNet(root=path, categories=['Airplane'], split='test', pre_transform=T.KNNGraph(k=6))
    return train_data, test_data
// A code block
这个代码先检查了存储路径是否存在,不存在创建一个,之后在直接下载数据集。
我们下载的时候选定了预处理pre_transform=T.KNNGraph(k=6),会预先使用
k临近算法给图数据生成边。这个方法下载了训练集和测试集。

数据集结构

 t, t1 = load_data()
 print(t)
 print(t[0])
 print(len(t))
// 输出
ShapeNet(2349, categories=['Airplane'])
# 这句的意思一共2349个图,都是飞机类
Data(x=[2518, 3], y=[2518], pos=[2518, 3], category=[1], 
edge_index=[2, 15108])
# 这里是选中了一张图看看里面的结构,x是点特征每个点是三维的,意思是一个
# 点用三个数表示,可以联想成空间里面点的x,y,z吧。pos是空间的点坐标也是
# 2518个。y是各个点的标签,用来表示这个点是属于那个部件的。category
# 表示整个图是什么类,这里只有飞机类,所以说只有一个数。edge_index是
# 邻接矩阵,指明那个点之间有连接
2349

网络模型

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, data, edge_index):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = data.to(device)
        edge_index = edge_index.to(device)
        x = data.x
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        x = self.lin(x)

        return x

1 __init__ 方法:
in_channels:输入特征的维度,即每个节点的特征向量的维度。
hidden_channels:隐藏层的维度,代表在每一层中节点特征被投影到的空间维度。
num_classes:输出类别的数量,表示模型要分类的类别数。
该方法初始化了网络的各个层,包括四个图卷积层(GCNConv)和一个线性层(nn.Linear):
conv1: 将输入特征从 in_channels 映射到 hidden_channels。
conv2, conv3, conv4: 每一层将特征从 hidden_channels 映射到相同的 hidden_channels。
lin: 线性层,将最后一个隐藏层的输出特征映射到 num_classes,用于最终的分类任务。

2 forward 方法: 该方法定义了前向传播的过程,即输入数据通过网络时的运算流程。
device:自动检测并选择使用 GPU(如果可用)或 CPU 作为计算设备。
data.to(device) 和edge_index.to(device): 将输入数据和边信息移动到指定的设备(CPUGPU)。
x = data.x: 提取输入节点的特征矩阵 x。
conv1 -> conv4: 执行四次图卷积操作,每次卷积后使用 ReLU 激活函数。ReLU 引入非线性,使模型能够学习复杂的模式。
lin: 最后一层是线性层,将最终的图卷积输出特征映射到类别空间,用于分类任务。

模型训练

训练关键代码如下:

def train_model(epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_, test_ = load_data()
    num_class = len(np.unique(train_[0].y))
    net = Net(in_channels=3, hidden_channels=32, num_classes=num_class).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    ok = None
    for epoch in range(epochs):
        train_loss, train_acc = 0, 0
        max_acc = 0
        for i in tqdm(range(len(train_)), desc='{}/{}'.format(epoch + 1, epochs)):
            data = train_[i].to(device)
            optimizer.zero_grad()
            output = net(data, data.edge_index)
            loss = criterion(output, data.y)
            # print(loss)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, preds = torch.max(output, 1)
            train_acc += (preds == data.y).sum().item() / len(data.y)
        avg_loss = train_loss / len(data.y)
        avg_acc = train_acc / len(data.y)
        if avg_acc > max_acc:
            max_acc = avg_acc
            torch.save(net.state_dict(), './models/best_{}.pt'.format(max_acc))
            print('找到更高准确率的模型,准确率为{}'.format(max_acc))
        print('平均损失{}   平均准确率{}'.format(avg_loss, avg_acc))

这段代码定义了一个训练方法,接受一个epochs,指明要训练多少次使用了CrossEntropyLoss做损失函数。训练过程中每个epoch检查一遍模型准确率若准确率比前者高就保存本次训练模型。
训练过程截图:
在这里插入图片描述

测试模型+可视化结果

关键代码:

# 使用open3D进行可视化
def visualize(data_x, data_y):
    # 创建点云对象
    data_x = data_x.cpu().detach().numpy()
    data_y = data_y.cpu().detach().numpy()
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(data_x)

    # 创建颜色映射字典
    color_map = {
        0: [1, 0, 0],  # 类别0:红色
        1: [0, 1, 0],  # 类别1:绿色
        2: [0, 0, 1],  # 类别2:蓝色
        3: [0, 1, 1],  # 类别3:青色
        # 添加更多类别及其颜色
    }

    # 将标签转换为颜色
    colors = np.array([color_map.get(label, [0, 0, 0]) for label in data_y])
    pcd.colors = o3d.utility.Vector3dVector(colors)

    # 可视化点云
    o3d.visualization.draw_geometries([pcd])

//进行测试并且可视化,这里就拿了其中一张图进行可视化
import numpy as np
from Train import load_data
from Net import Net
import os
import torch
from visulizd import visualize
from tqdm import tqdm
if __name__ == '__main__':
    # 取第120张图进行分类
    i = 120
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train, test = load_data()
    num_c = len(np.unique(test[0].y))
    # test = test.to(device)
    model = Net(in_channels=3, hidden_channels=32, num_classes=num_c)
    model.load_state_dict(torch.load("./models/best_0.7097560044925633.pt"))
    model.eval()
    model.to(device)
    out = model(test[i], test[i].edge_index)
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(out, test[i].y.to(device))
    acc = out.argmax(dim=1).eq(test[i].y.to(device)).sum().item() / len(test[i].y)
    print('损失{}  准确率{}'.format(loss, acc))
    visualize(test[i].pos, out.argmax(dim=1))


这里我只用了训练20次的模型,若想追求更高准确率,可以增加训练次数和优化网络模型结构
在这里插入图片描述

在这里插入图片描述

可能会出现的问题

pyg没有配置好

可以查看我这个博客,能够完美配置好pyg
链接: 完美配置pyg

懒人专属(代码链接)

链接: 代码链接

  • 13
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码飞速跑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值