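"""
PointNet part-segmentation model.

Adapted from:
https://github.com/yanx27/Pointnet_Pointnet2_pytorch/blob/master/models/pointnet_part_seg.py

Note on "global feature C + label" (originally marked as unclear):
    out_max = torch.cat([out_max, label.squeeze(1)], 1)   # append label to global feature
    expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, N)  # broadcast to every point
    concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)  # concat: [B, 4944, N]
`label` appears to be a per-shape object-category vector (presumably one-hot,
16 ShapeNet categories in the reference implementation -- confirm against the
data loader). It is appended to the 2048-d max-pooled global feature so that
the per-point segmentation head is conditioned on the object category.
Channel count of `concat`: (2048 + 16) + 64 + 128 + 128 + 512 + 2048 = 4944.

Part segmentation output: [B, N, part_num] per-point log-probabilities,
with part_num = 50 by default.
"""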
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
from pointnet_utils import STN3d, STNkd, feature_transform_reguliarzer
class get_model(nn.Module):
    """PointNet part-segmentation network.

    Args:
        part_num: number of part labels predicted for every point.
        normal_channel: if True the input has 6 channels (xyz plus 3 extra
            channels, presumably normals); otherwise 3 channels (xyz only).
        num_classes: width of the per-shape ``label`` vector that is appended
            to the global feature. Defaults to 16, the value previously
            hard-coded in this model.
    """

    def __init__(self, part_num=50, normal_channel=True, num_classes=16):
        super().__init__()
        channel = 6 if normal_channel else 3
        self.part_num = part_num
        self.num_classes = num_classes
        self.stn = STN3d(channel)
        self.conv1 = torch.nn.Conv1d(channel, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, 512, 1)
        self.conv5 = torch.nn.Conv1d(512, 2048, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(2048)
        self.fstn = STNkd(k=128)
        # Input of the segmentation head: the global feature (2048) plus the
        # label (num_classes), broadcast to every point, concatenated with the
        # per-point features out1..out5 (64 + 128 + 128 + 512 + 2048).
        # With num_classes=16 this equals the original hard-coded 4944.
        concat_channels = (2048 + num_classes) + 64 + 128 + 128 + 512 + 2048
        self.convs1 = torch.nn.Conv1d(concat_channels, 256, 1)
        self.convs2 = torch.nn.Conv1d(256, 256, 1)
        self.convs3 = torch.nn.Conv1d(256, 128, 1)
        self.convs4 = torch.nn.Conv1d(128, part_num, 1)
        self.bns1 = nn.BatchNorm1d(256)
        self.bns2 = nn.BatchNorm1d(256)
        self.bns3 = nn.BatchNorm1d(128)

    def forward(self, point_cloud, label):
        """Predict per-point part log-probabilities.

        Args:
            point_cloud: tensor of shape [B, D, N]. Channels 0..2 are
                transformed by the learned 3x3 spatial transform. Any
                remaining channels are concatenated back untransformed.
            label: per-shape vector that must become [B, num_classes] after
                ``squeeze(1)``. Presumably a one-hot object category of shape
                [B, 1, num_classes]; TODO confirm against the caller.

        Returns:
            A tuple ``(log_probs, trans_feat)``:
            - log_probs: [B, N, part_num], log-softmax over the part dimension.
            - trans_feat: [B, 128, 128] feature-transform matrix from STNkd.
              get_loss passes it to its regularizer.
        """
        _, D, N = point_cloud.size()

        # Learned 3x3 spatial transform, applied to the xyz channels only.
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(2, 1)  # [B, N, D]
        if D > 3:
            # Explicit sizes so this also works for D > 6 (split(3) would
            # yield more than two chunks and fail to unpack).
            point_cloud, feature = point_cloud.split([3, D - 3], dim=2)
        point_cloud = torch.bmm(point_cloud, trans)  # [B, N, 3] @ [B, 3, 3]
        if D > 3:
            point_cloud = torch.cat([point_cloud, feature], dim=2)
        point_cloud = point_cloud.transpose(2, 1)  # [B, D, N]

        out1 = F.relu(self.bn1(self.conv1(point_cloud)))  # [B, 64, N]
        out2 = F.relu(self.bn2(self.conv2(out1)))  # [B, 128, N]
        out3 = F.relu(self.bn3(self.conv3(out2)))  # [B, 128, N]

        # Learned 128x128 feature transform applied to out3.
        trans_feat = self.fstn(out3)
        net_transformed = torch.bmm(out3.transpose(2, 1), trans_feat)  # [B, N, 128]
        net_transformed = net_transformed.transpose(2, 1)  # [B, 128, N]

        out4 = F.relu(self.bn4(self.conv4(net_transformed)))  # [B, 512, N]
        out5 = self.bn5(self.conv5(out4))  # [B, 2048, N]; no ReLU, unlike out1-out4

        # Global feature: max-pool over the N points.
        out_max = torch.max(out5, 2)[0]  # [B, 2048]
        # Condition on the per-shape label, then broadcast to every point.
        # unsqueeze (instead of a fixed-width view) cannot silently fold the
        # batch dimension when the label width is unexpected.
        out_max = torch.cat([out_max, label.squeeze(1)], 1)  # [B, 2048 + num_classes]
        expand = out_max.unsqueeze(2).repeat(1, 1, N)  # [B, 2048 + num_classes, N]
        concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)  # [B, concat_channels, N]

        net = F.relu(self.bns1(self.convs1(concat)))
        net = F.relu(self.bns2(self.convs2(net)))
        net = F.relu(self.bns3(self.convs3(net)))
        net = self.convs4(net)  # [B, part_num, N]
        # log-softmax over the part dimension, per point.
        net = F.log_softmax(net.transpose(2, 1).contiguous(), dim=-1)  # [B, N, part_num]
        return net, trans_feat
class get_loss(torch.nn.Module):
    """Negative log-likelihood loss plus a scaled feature-transform regularizer.

    Args:
        mat_diff_loss_scale: weight multiplied onto the value returned by
            ``feature_transform_reguliarzer`` before it is added to the NLL term.
    """

    def __init__(self, mat_diff_loss_scale=0.001):
        super().__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale

    def forward(self, pred, target, trans_feat):
        """Return ``nll_loss(pred, target) + regularizer(trans_feat) * scale``.

        Args:
            pred: log-probabilities, passed straight to ``F.nll_loss``, which
                expects the class dimension at dim 1. get_model returns
                [B, N, part_num], so the caller presumably flattens it to
                [B * N, part_num] first; TODO confirm against the training loop.
            target: integer class indices matching ``pred``'s layout.
            trans_feat: feature-transform matrix returned by get_model.
        """
        nll_term = F.nll_loss(pred, target)
        reg_term = feature_transform_reguliarzer(trans_feat)
        return nll_term + reg_term * self.mat_diff_loss_scale
# Source notes: blog post "【点云网络】pointnet_part_seg.py"
# ("[Point-cloud networks] pointnet_part_seg.py"), first published 2023-02-13 15:07:52.