【点云网络】pointnet_part_seg.py

"""
from: https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_part_seg.py


没看懂? C + label ?
    out_max = torch.cat([out_max,label.squeeze(1)],1) # label ?
    expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N)
    concat = torch.cat([expand, out1, out2, out3, out4, out5], 1) # concat:[B, 4944, N]

part_segmentation => output:[B, N, C], C = part_num=50
"""

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
from pointnet_utils import STN3d, STNkd, feature_transform_reguliarzer


class get_model(nn.Module):
    def __init__(self, part_num=50, normal_channel=True):
        super(get_model, self).__init__()
        if normal_channel:
            channel = 6
        else:
            channel = 3
        self.part_num = part_num
        self.stn = STN3d(channel)
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, 512, 1)
        self.conv5 = torch.nn.Conv1d(512, 2048, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(2048)
        self.fstn = STNkd(k=128)
        self.convs1 = torch.nn.Conv1d(4944, 256, 1)
        self.convs2 = torch.nn.Conv1d(256, 256, 1)
        self.convs3 = torch.nn.Conv1d(256, 128, 1)
        self.convs4 = torch.nn.Conv1d(128, part_num, 1)
        self.bns1 = nn.BatchNorm1d(256)
        self.bns2 = nn.BatchNorm1d(256)
        self.bns3 = nn.BatchNorm1d(128)

    def forward(self, point_cloud, label):
        B, D, N = point_cloud.size()
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(2, 1) # point_cloud:[B,N,D]
        if D > 3:
            point_cloud, feature = point_cloud.split(3, dim=2)
        point_cloud = torch.bmm(point_cloud, trans)
        if D > 3:
            point_cloud = torch.cat([point_cloud, feature], dim=2)

        point_cloud = point_cloud.transpose(2, 1) # point_cloud:[B,D,N]

        out1 = F.relu(self.bn1(self.conv1(point_cloud)))
        out2 = F.relu(self.bn2(self.conv2(out1)))
        out3 = F.relu(self.bn3(self.conv3(out2))) # out3:[B, 128, N]

        trans_feat = self.fstn(out3)
        x = out3.transpose(2, 1) # x:[B, N, 128]
        net_transformed = torch.bmm(x, trans_feat)
        net_transformed = net_transformed.transpose(2, 1)

        out4 = F.relu(self.bn4(self.conv4(net_transformed)))
        out5 = self.bn5(self.conv5(out4))  # out3:[B, 2048, N]
        out_max = torch.max(out5, 2, keepdim=True)[0]
        out_max = out_max.view(-1, 2048) # out_max:[B, 2048]

        out_max = torch.cat([out_max,label.squeeze(1)],1) # label ?
        expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N)
        concat = torch.cat([expand, out1, out2, out3, out4, out5], 1) # concat:[B, 4944, N]
        net = F.relu(self.bns1(self.convs1(concat)))
        net = F.relu(self.bns2(self.convs2(net)))
        net = F.relu(self.bns3(self.convs3(net)))
        net = self.convs4(net) # net:[B, part_num=50, N]
        net = net.transpose(2, 1).contiguous() # net:[B, N, part_num=50]
        net = F.log_softmax(net.view(-1, self.part_num), dim=-1) # net:[B * N, part_num=50]
        net = net.view(B, N, self.part_num) # [B, N, 50]

        return net, trans_feat


class get_loss(torch.nn.Module):
    def __init__(self, mat_diff_loss_scale=0.001):
        super(get_loss, self).__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

    def forward(self, pred, target, trans_feat):
        loss = F.nll_loss(pred, target)
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)
        total_loss = loss + mat_diff_loss * self.mat_diff_loss_scale
        return total_loss
PointNet2是一种针对点云分类和分割任务的深度学习框架。PointNet2_Part_Seg_SSG是基于PointNet2框架的一个应用,用于点云部分分割任务。 PointNet2使用了一种层级的神经网络结构,能够有效地处理无序的点云数据。它将点云分为多个局部区域,对每个区域进行特征提取,最后整合局部特征得到全局特征表达。这种设计能够提取点云的局部和全局特征,从而实现对点云数据的分类和分割。 PointNet2_Part_Seg_SSG是PointNet2框架的一种改进,主要针对点云的部分分割任务。它使用了SSG(Single-Scale Grouping)模块,通过分组聚合点的特征,从而对点云进行细分。SSG模块首先选择每个局部区域中的中心点,并将其他点分配给最近的中心点。然后,SSG模块对每个中心点的邻域进行特征提取和聚合,得到该局部区域的特征表示。最后,通过进一步的卷积和池化操作,得到点云的全局特征表示。 在训练过程中,PointNet2_Part_Seg_SSG使用交叉熵损失函数来度量预测的分割结果与真实标签之间的差异。通过反向传播算法,可以优化网络的参数,使得网络能够更好地学习点云的特征表示和分割任务。 总的来说,PointNet2_Part_Seg_SSG是基于PointNet2框架的一个改进版本,专门用于点云的部分分割任务。它通过采用SSG模块,能够对点云进行更精细的细分和特征提取,从而提高了点云分割任务的准确性和效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值