torch-points3d (version 1.2.0) example code

https://github.com/nicolas-chaulet/torch-points3d/issues/594

Environment setup:

https://pytorch-geometric.com/whl/torch-1.7.0.html

Miniconda3-py37_4.9.2-Linux-x86_64.sh

torch-1.7.0+cu101-cp37-cp37m-linux_x86_64.whl

torchvision-0.8.1+cu101-cp37-cp37m-linux_x86_64.whl

python -mpip install pytest-runner

pip version 20.1

torch_cluster-1.5.8+cu101-cp37-cp37m-linux_x86_64.whl

torch_scatter-2.0.5+cu101-cp37-cp37m-linux_x86_64.whl

torch_sparse-0.6.8+cu101-cp37-cp37m-linux_x86_64.whl

torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl

python -mpip install torch-geometric

pip install torch-points3d

This step also installs torch-points-kernels, which can be a bit slow; please wait.

For the visualization packages, the following version combination is recommended; vtk 9.0 and above is not recommended because it causes problems:

pyvista                 0.26.1

vtk                     8.1.2
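To catch an accidental vtk upgrade at runtime, version strings can be compared as tuples. This is a minimal stdlib sketch (version_tuple is a hypothetical helper; the pins above are the only assumption):

```python
def version_tuple(v):
    """Parse a dotted version string like "8.1.2" into a comparable tuple.

    Each component is truncated at its first non-digit character,
    so "8.1.2rc1" parses the same as "8.1.2".
    """
    parts = []
    for component in v.split("."):
        digits = ""
        for ch in component:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)

# The combination recommended above: pyvista 0.26.1 with vtk below 9.0
assert version_tuple("8.1.2") < version_tuple("9.0.0")
```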

Final environment:

Package                 Version
----------------------- -------------------
absl-py                 0.12.0
appdirs                 1.4.4
argon2-cffi             20.1.0
ase                     3.21.1
async-generator         1.10
attrs                   20.3.0
backcall                0.2.0
bleach                  3.3.0
brotlipy                0.7.0
cached-property         1.5.2
cachetools              4.2.1
certifi                 2020.6.20
cffi                    1.14.3
chardet                 3.0.4
click                   7.1.2
conda                   4.9.2
conda-package-handling  1.7.2
configparser            5.0.2
cryptography            3.2.1
cycler                  0.10.0
dataclasses             0.6
decorator               4.4.2
defusedxml              0.7.1
docker-pycreds          0.4.0
entrypoints             0.3
filelock                3.0.12
future                  0.18.2
gdown                   3.12.2
gitdb                   4.0.7
GitPython               3.1.14
google-auth             1.29.0
google-auth-oauthlib    0.4.4
googledrivedownloader   0.4
gql                     0.2.0
graphql-core            1.1
grpcio                  1.37.0
h5py                    3.2.1
hydra-core              0.11.3
idna                    2.10
imageio                 2.9.0
importlib-metadata      4.0.1
ipykernel               5.5.3
ipython                 7.22.0
ipython-genutils        0.2.0
ipywidgets              7.6.3
isodate                 0.6.0
jedi                    0.18.0
Jinja2                  2.11.3
joblib                  1.0.1
jsonpatch               1.32
jsonpointer             2.1
jsonschema              3.2.0
jupyter-client          6.1.12
jupyter-core            4.7.1
jupyterlab-pygments     0.1.2
jupyterlab-widgets      1.0.0
kiwisolver              1.3.1
llvmlite                0.32.1
Markdown                3.3.4
MarkupSafe              1.1.1
matplotlib              3.4.1
meshio                  4.3.13
mistune                 0.8.4
nbclient                0.5.3
nbconvert               6.0.7
nbformat                5.1.3
nest-asyncio            1.5.1
networkx                2.5.1
notebook                6.3.0
numba                   0.49.1
numpy                   1.20.2
nvidia-ml-py3           7.352.0
oauthlib                3.1.0
omegaconf               1.4.1
open3d                  0.9.0.0
packaging               20.9
pandas                  1.2.4
pandocfilters           1.4.3
parso                   0.8.2
pexpect                 4.8.0
pickleshare             0.7.5
Pillow                  8.2.0
pip                     20.2.4
plyfile                 0.7.3
prometheus-client       0.10.1
promise                 2.3
prompt-toolkit          3.0.18
protobuf                3.15.8
psutil                  5.8.0
ptyprocess              0.7.0
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pycosat                 0.6.3
pycparser               2.20
Pygments                2.8.1
pyOpenSSL               19.1.0
pyparsing               2.4.7
pyrsistent              0.17.3
PySocks                 1.7.1
python-dateutil         2.8.1
python-louvain          0.15
pytorch-metric-learning 0.9.99.dev0
pytz                    2021.1
pyvista                 0.26.1
PyWavelets              1.1.1
PyYAML                  5.4.1
pyzmq                   22.0.3
rdflib                  5.0.0
requests                2.24.0
requests-oauthlib       1.3.0
rsa                     4.7.2
ruamel-yaml             0.15.87
scikit-image            0.16.2
scikit-learn            0.24.1
scipy                   1.6.3
scooby                  0.5.7
Send2Trash              1.5.0
sentry-sdk              1.0.0
setuptools              50.3.1.post20201107
shortuuid               1.0.1
six                     1.15.0
smmap                   4.0.0
subprocess32            3.5.4
tensorboard             2.5.0
tensorboard-data-server 0.6.0
tensorboard-plugin-wit  1.8.0
terminado               0.9.4
testpath                0.4.4
threadpoolctl           2.1.0
torch                   1.7.0+cu101
torch-cluster           1.5.8
torch-geometric         1.7.0
torch-points-kernels    0.6.10
torch-points3d          1.2.0
torch-scatter           2.0.5
torch-sparse            0.6.8
torchfile               0.1.0
torchnet                0.0.4
torchvision             0.8.1+cu101
tornado                 6.1
tqdm                    4.51.0
traitlets               5.0.5
transforms3d            0.3.1
typing-extensions       3.7.4.3
urllib3                 1.25.11
visdom                  0.1.8.9
vtk                     8.1.2
wandb                   0.8.36
watchdog                2.0.3
wcwidth                 0.2.5
webencodings            0.5.1
websocket-client        0.58.0
Werkzeug                1.0.1
wheel                   0.35.1
widgetsnbextension      3.5.1
zipp                    3.4.1

Engineering issues:

The project's original environment already has the following important libraries installed:
mmdet v2.8
python-pcl (depends on numpy==1.17.5)
open3d==0.9.0 (0.10.0 and above cannot be used directly on Ubuntu by default)
cuda 10.1
pytorch==1.6.0
torchvision==0.7.0
Before starting, compare nvcc -V with torch's CUDA version:
import torch
print(torch.version.cuda)
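A slightly fuller version of this check can be sketched so it also reports GPU availability and degrades gracefully when torch is absent (torch.version.cuda and torch.cuda.is_available() are standard PyTorch attributes; cuda_report is a hypothetical helper):

```python
# Optional fuller environment check; degrades gracefully without torch.
try:
    import torch
except ImportError:
    torch = None

def cuda_report():
    """Return a one-line summary of the torch/CUDA install, or a notice."""
    if torch is None:
        return "torch not installed"
    return "torch %s / CUDA %s / GPU available: %s" % (
        torch.__version__,
        torch.version.cuda,        # should match what nvcc -V reports
        torch.cuda.is_available(),
    )

print(cuda_report())
```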


torch-points3d requirements and steps:
If your base environment is pytorch==1.6.0,
download the following wheels from
https://pytorch-geometric.com/whl/torch-1.6.0.html
and install them:
torch_cluster-1.5.8+cu101-cp37-cp37m-linux_x86_64.whl
torch_scatter-2.0.5+cu101-cp37-cp37m-linux_x86_64.whl
torch_sparse-0.6.8+cu101-cp37-cp37m-linux_x86_64.whl
torch_spline_conv-1.2.0+cu101-cp37-cp37m-linux_x86_64.whl
python -mpip install torch-geometric==1.7.0
python -mpip install torch-points3d==1.2.0

Because this will upgrade open3d and numpy by default,
we need to restore them, in order:
python -mpip uninstall open3d
python -mpip install open3d==0.9.0

python -mpip uninstall numpy
python -mpip install numpy==1.17.5
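Because later pip operations can silently bump numpy again and break python-pcl, a small runtime guard helps. A sketch (check_pin is a hypothetical helper; the pinned version string comes from the constraint above):

```python
import warnings

def check_pin(actual, required, package="numpy"):
    """Warn instead of crashing when an installed version drifts from a pin."""
    if actual != required:
        warnings.warn(
            "%s %s does not match the pinned %s; python-pcl may misbehave"
            % (package, actual, required)
        )
        return False
    return True

# Typical use after the reinstall steps above:
# import numpy; check_pin(numpy.__version__, "1.17.5")
```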

Problems you may encounter along the way:
ERROR: Cannot uninstall 'llvmlite'. It is a distutils installed project and thus we cannot accurately determine which files belong to it which would lead to only a partial uninstall.

This is caused by the failed uninstall of the original llvmlite:

Enter python and check where the current interpreter installs packages:
from distutils.sysconfig import get_python_lib
print(get_python_lib())
/root/anaconda3/lib/python3.7/site-packages

Delete the old version's metadata:
rm -rf /root/anaconda3/lib/python3.7/site-packages/llvmlite-0.29.0-py3.7.egg-info

Install the new version:
python -mpip install llvmlite==0.32.1
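The same distutils-uninstall problem can hit other packages. A stdlib-only helper to locate the leftover metadata directories to delete by hand (find_leftover_metadata is hypothetical; sysconfig's purelib path plays the role of get_python_lib() above):

```python
import glob
import os
import sysconfig  # distutils.sysconfig also works on Python <= 3.11

def find_leftover_metadata(package, site_packages=None):
    """Return .egg-info / .dist-info directories left behind for a package.

    These are the directories to remove manually when pip refuses to
    uninstall a distutils-installed project.
    """
    site_packages = site_packages or sysconfig.get_paths()["purelib"]
    name = package.replace("-", "_")
    hits = []
    for suffix in (".egg-info", ".dist-info"):
        hits.extend(glob.glob(os.path.join(site_packages, name + "-*" + suffix)))
    return hits
```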



There is also a code problem where util cannot be found, because the path was not added:
import os
import sys
ROOT = os.path.dirname(os.path.abspath(__file__))
BASE = os.path.dirname(ROOT)
sys.path.append(BASE)
sys.path.append(os.path.dirname(BASE))
sys.path.append(os.path.dirname(os.path.dirname(BASE)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(BASE))))
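The repeated sys.path.append calls can be collapsed into a loop. A sketch (add_parents_to_path and the number of levels are assumptions chosen to mirror the four appends above):

```python
import os
import sys

def add_parents_to_path(start, levels=3):
    """Append `start` and its first `levels` parent directories to sys.path,
    skipping entries that are already present."""
    d = os.path.abspath(start)
    for _ in range(levels + 1):
        if d not in sys.path:
            sys.path.append(d)
        d = os.path.dirname(d)

# Equivalent to the four appends above when called as:
# add_parents_to_path(BASE, levels=3)
```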

1. Classification code example

RSConvCLassifier.py
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: RSConvCLassifier.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 27, 2021
# ---
import torch
from torch_points3d.applications.rsconv import RSConv

class RSConvCLassifier(torch.nn.Module):
    def __init__(self, USE_NORMAL, MODELNET_VERSION):
        super().__init__()
        self.encoder = RSConv("encoder", input_nc=3 * USE_NORMAL, output_nc=int(MODELNET_VERSION), num_layers=4)
        self.log_softmax = torch.nn.LogSoftmax(dim=-1)

    @property
    def conv_type(self):
        """ This is needed by the dataset to infer which batch collate should be used"""
        return self.encoder.conv_type

    def get_output(self):
        """ This is needed by the tracker to get access to the ouputs of the network"""
        return self.output

    def get_labels(self):
        """ Needed by the tracker in order to access ground truth labels"""
        return self.labels

    def get_current_losses(self):
        """ Entry point for the tracker to grab the loss """
        return {"loss_class": float(self.loss_class)}

    def forward(self, data):
        # Set labels for the tracker
        self.labels = data.y.squeeze()

        # Forward through the network
        data_out = self.encoder(data)
        self.output = self.log_softmax(data_out.x.squeeze())

        # Set loss for the backward pass
        self.loss_class = torch.nn.functional.nll_loss(self.output, self.labels)

    def backward(self):
        self.loss_class.backward()
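The log_softmax in __init__ plus the nll_loss in forward() together compute cross-entropy. The same math in a small NumPy sketch, using synthetic logits rather than model output:

```python
import numpy as np

def log_softmax(z):
    # Subtract the row max first for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def nll_loss(log_probs, targets):
    # Mean negative log-probability of the true class per sample.
    return -log_probs[np.arange(len(targets)), targets].mean()

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3]])
targets = np.array([0, 1])
loss = nll_loss(log_softmax(logits), targets)  # cross-entropy of the batch
```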
train_CLassifier1.py
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_CLassifier.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 27, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import datetime

import torch

from torch_points3d.datasets.classification.modelnet import SampledModelNet
import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T

from torch_points3d.datasets.batch import SimpleBatch

from torch_points3d.metrics.classification_tracker import ClassificationTracker

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)

import RSConvCLassifier as rs

def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

def main():
    DIR = ""
    MODELNET_VERSION = "10"
    USE_NORMAL = True

    model = rs.RSConvCLassifier(USE_NORMAL=USE_NORMAL, MODELNET_VERSION=MODELNET_VERSION)

    NUM_WORKERS = 0
    BATCH_SIZE = 4
    dataroot = os.path.join(DIR, "data/modelnet")
    pre_transform = T.Compose([T.NormalizeScale(), T3D.GridSampling3D(0.02)])
    #transform = T.FixedPoints(4096)
    #~/torch-points3d-1.2.0/docs/src/api/transforms.rst
    if USE_NORMAL:
        transform = T.Compose([T.FixedPoints(4096),
                    T3D.AddFeatsByKeys(list_add_to_x=[True], feat_names=["norm"], input_nc_feats=None, stricts=None, delete_feats=[True])])
    else:
        transform = T.Compose([T.FixedPoints(4096)])

    dataset = SampledModelNet(dataroot, name=MODELNET_VERSION, train=True, transform=transform,
                              pre_transform=pre_transform, pre_filter=None)
    print(dataset[0])
    collate_function = lambda datalist: SimpleBatch.from_data_list(datalist)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=collate_function
    )
    print(next(iter(train_loader)))

    logdir = "./train_CLassifier1"
    logdir = os.path.join(logdir, str(datetime.datetime.now()))
    mkdir_os(logdir)
    os.chdir(logdir)
    tracker = ClassificationTracker(dataset=dataset, stage="train", wandb_log=False, use_tensorboard=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train_epoch(device):
        model.to(device)
        model.train()
        tracker.reset("train")
        iter_data_time = time.time()
        with Ctq(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                optimizer.zero_grad()
                data.to(device)
                model.forward(data)
                model.backward()
                optimizer.step()
                if i % 10 == 0:
                    tracker.track(model)

                tq_train_loader.set_postfix(
                    **tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    EPOCHS = 10
    for i in range(EPOCHS):
        print("=========== EPOCH %i ===========" % i)
        time.sleep(0.5)
        train_epoch('cuda')
        tracker.publish(i)


if __name__ == '__main__':
    main()
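The AddFeatsByKeys transform used in the pipeline above concatenates the named attribute (here norm) onto x and deletes the original attribute. Its effect can be sketched on a bare stand-in object (MiniData and add_feat_by_key are illustrative, not the real torch_geometric API):

```python
import numpy as np

class MiniData:
    """Minimal stand-in for a point-cloud Data object."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def add_feat_by_key(data, feat_name, delete_feat=True):
    """Move data.<feat_name> into data.x, concatenating if x already exists."""
    feats = getattr(data, feat_name)
    if getattr(data, "x", None) is None:
        data.x = feats
    else:
        data.x = np.concatenate([data.x, feats], axis=-1)
    if delete_feat:
        delattr(data, feat_name)
    return data

d = MiniData(pos=np.zeros((5, 3)), norm=np.ones((5, 3)), x=None)
d = add_feat_by_key(d, "norm")  # normals become the point features
```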

train_CLassifier2.py

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_CLassifier.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 27, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import datetime

import torch

from torch_points3d.datasets.classification.modelnet import SampledModelNet
import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T

from torch_geometric.data import Batch

from torch_points3d.metrics.classification_tracker import ClassificationTracker

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)

import RSConvCLassifier as rs

def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

def main():
    DIR = ""
    MODELNET_VERSION = "10"
    USE_NORMAL = True

    model = rs.RSConvCLassifier(USE_NORMAL=USE_NORMAL, MODELNET_VERSION=MODELNET_VERSION)

    NUM_WORKERS = 0
    BATCH_SIZE = 4
    dataroot = os.path.join(DIR, "data/modelnet")
    pre_transform = T.Compose([T.NormalizeScale(), T3D.GridSampling3D(0.02)])
    if USE_NORMAL:
        transform = T.Compose([T.FixedPoints(4096),
                    T3D.AddFeatsByKeys(list_add_to_x=[True], feat_names=["norm"], input_nc_feats=None, stricts=None, delete_feats=[True])])
    else:
        transform = T.Compose([T.FixedPoints(4096)])

    dataset = SampledModelNet(dataroot, name=MODELNET_VERSION, train=True, transform=transform,
                              pre_transform=pre_transform, pre_filter=None)
    print(dataset[0])

    #https://zhuanlan.zhihu.com/p/142948273
    #https://github.com/nicolas-chaulet/torch-points3d/issues/594
    collate_function = lambda datalist: Batch.from_data_list(datalist)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=collate_function
    )
    print(next(iter(train_loader)))

    logdir = "./train_CLassifier2"
    logdir = os.path.join(logdir, str(datetime.datetime.now()))
    mkdir_os(logdir)
    os.chdir(logdir)
    tracker = ClassificationTracker(dataset=dataset, stage="train", wandb_log=False, use_tensorboard=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train_epoch(device):
        model.to(device)
        model.train()
        tracker.reset("train")
        iter_data_time = time.time()
        with Ctq(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                optimizer.zero_grad()
                data.to(device)
                model.forward(data)
                model.backward()
                optimizer.step()
                if i % 10 == 0:
                    tracker.track(model)

                tq_train_loader.set_postfix(
                    **tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    EPOCHS = 10
    for i in range(EPOCHS):
        print("=========== EPOCH %i ===========" % i)
        time.sleep(0.5)
        train_epoch('cuda')
        tracker.publish(i)


if __name__ == '__main__':
    main()
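The only substantive change from train_CLassifier1 is the collate function: SimpleBatch stacks fixed-size clouds densely, while torch_geometric's Batch packs variable-size clouds with a per-point batch index (the topic of issue #594). A NumPy sketch of the two layouts (shapes are illustrative):

```python
import numpy as np

def dense_collate(clouds):
    """SimpleBatch-style: stack equally sized clouds into (B, N, 3).

    Requires every cloud to have the same point count, which is why the
    transform pipeline applies FixedPoints first."""
    return np.stack(clouds)

def packed_collate(clouds):
    """Batch-style: concatenate variable-size clouds into (sum Ni, 3)
    plus a batch-index vector assigning each point to its sample."""
    pos = np.concatenate(clouds)
    batch = np.concatenate(
        [np.full(len(c), i, dtype=np.int64) for i, c in enumerate(clouds)]
    )
    return pos, batch

clouds = [np.random.rand(4, 3), np.random.rand(6, 3)]
pos, batch = packed_collate(clouds)
```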

train_CLassifier3.py

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_CLassifier.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 27, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import datetime

import torch

from torch_points3d.datasets.classification.modelnet import SampledModelNet
import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T

from omegaconf import OmegaConf
from torch_points3d.datasets.classification.modelnet import ModelNetDataset

from torch_geometric.data import Data
from torch_geometric.data import Batch
from torch_points3d.datasets.batch import SimpleBatch

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)

import RSConvCLassifier as rs

def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

def main():
    DIR = ""
    MODELNET_VERSION = "10"
    USE_NORMAL = True

    yaml_config = """
    task: classification
    class: modelnet.ModelNetDataset
    name: modelnet
    dataroot: %s
    number: %s
    pre_transforms:
        - transform: NormalizeScale
        - transform: GridSampling3D
          lparams: [0.02]
    train_transforms:
        - transform: FixedPoints
          lparams: [2048]
        - transform: RandomNoise
        - transform: RandomRotate
          params:
            degrees: 180
            axis: 2
        - transform: AddFeatsByKeys
          params:
            feat_names: [norm]
            list_add_to_x: [%r]
            delete_feats: [True]
    test_transforms:
        - transform: FixedPoints
          lparams: [2048]
        - transform: AddFeatsByKeys
          params:
            feat_names: [norm]
            list_add_to_x: [%r]
            delete_feats: [True]
    """ % (os.path.join(DIR, "data"), MODELNET_VERSION, USE_NORMAL, USE_NORMAL)
    params = OmegaConf.create(yaml_config)
    dataset = ModelNetDataset(params)
    print(dataset)

    model = rs.RSConvCLassifier(USE_NORMAL=USE_NORMAL, MODELNET_VERSION=MODELNET_VERSION)

    NUM_WORKERS = 0
    BATCH_SIZE = 4
    # Setup the data loaders
    dataset.create_dataloaders(
        model,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        precompute_multi_scale=False
    )
    print(next(iter(dataset.test_dataloaders[0])))

    logdir = "./train_CLassifier3"
    logdir = os.path.join(logdir, str(datetime.datetime.now()))
    mkdir_os(logdir)
    os.chdir(logdir)
    tracker = dataset.get_tracker(False, True)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train_epoch(device):
        model.to(device)
        model.train()
        tracker.reset("train")
        train_loader = dataset.train_dataloader
        iter_data_time = time.time()
        with Ctq(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):

                # # point positions
                # pos = torch.randn((4, 2048, 3))
                # # normal features
                # x = torch.randn((4, 2048, 3))
                # # labels
                # y = torch.ones((4, 1))
                # data_temp = Data(pos=pos, x=x, y=y)
                # data_temp = SimpleBatch.from_data_list([data])

                # Debugging aid: build a synthetic batch to inspect shapes.
                # Note that data_result is never fed to the model below.
                data_temp = Data(
                    pos=torch.randn((2048, 3)),
                    x=torch.randn((2048, 3)),
                    y=torch.randint(0, 10, (1,)),
                )
                datalist = [data_temp for i in range(BATCH_SIZE)]
                pre_transform = T.Compose([T.NormalizeScale(), T3D.GridSampling3D(0.02), T.FixedPoints(2048)])
                if pre_transform:
                    datalist = [pre_transform(d.clone()) for d in datalist]
                data_result = SimpleBatch.from_data_list(datalist)

                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                optimizer.zero_grad()
                data.to(device)
                model.forward(data)
                model.backward()
                optimizer.step()
                if i % 10 == 0:
                    tracker.track(model)

                tq_train_loader.set_postfix(
                    **tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    def test_epoch(device):
        model.to(device)
        model.eval()
        tracker.reset("test")
        test_loader = dataset.test_dataloaders[0]
        iter_data_time = time.time()
        with Ctq(test_loader) as tq_test_loader:
            for i, data in enumerate(tq_test_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                data.to(device)
                model.forward(data)
                tracker.track(model)

                tq_test_loader.set_postfix(
                    **tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    EPOCHS = 20
    for i in range(EPOCHS):
        print("=========== EPOCH %i ===========" % i)
        time.sleep(0.5)
        train_epoch('cuda')
        tracker.publish(i)
        test_epoch('cuda')
        tracker.publish(i)


if __name__ == '__main__':
    main()

2. Segmentation code example

PartSegKPConv.py
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: PartSegKPConv.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 28, 2021
# ---
import torch
from torch_points3d.core.common_modules import MLP, UnaryConv
from torch_points3d.applications.kpconv import KPConv

class MultiHeadClassifier(torch.nn.Module):
    """ Allows segregated segmentation in case the category of an object is known.
    This is the case in ShapeNet for example.

    Parameters
    ----------
    in_features -
        size of the input channel
    cat_to_seg
        category to segment maps for example:
        {
            'Airplane': [0,1,2],
            'Table': [3,4]
        }

    """

    def __init__(self, in_features, cat_to_seg, dropout_proba=0.5, bn_momentum=0.1):
        super().__init__()
        self._cat_to_seg = {}
        self._num_categories = len(cat_to_seg)
        self._max_seg_count = 0
        self._max_seg = 0
        self._shifts = torch.zeros((self._num_categories,), dtype=torch.long)
        for i, seg in enumerate(cat_to_seg.values()):
            self._max_seg_count = max(self._max_seg_count, len(seg))
            self._max_seg = max(self._max_seg, max(seg))
            self._shifts[i] = min(seg)
            self._cat_to_seg[i] = seg

        self.channel_rasing = MLP(
            [in_features, self._num_categories * in_features], bn_momentum=bn_momentum, bias=False
        )
        if dropout_proba:
            self.channel_rasing.add_module("Dropout", torch.nn.Dropout(p=dropout_proba))

        self.classifier = UnaryConv((self._num_categories, in_features, self._max_seg_count))
        self._bias = torch.nn.Parameter(torch.zeros(self._max_seg_count, ))

    def forward(self, features, category_labels, **kwargs):
        assert features.dim() == 2
        self._shifts = self._shifts.to(features.device)
        in_dim = features.shape[-1]
        features = self.channel_rasing(features)
        features = features.reshape((-1, self._num_categories, in_dim))
        features = features.transpose(0, 1)  # [num_categories, num_points, in_dim]
        features = self.classifier(features) + self._bias  # [num_categories, num_points, max_seg]
        ind = category_labels.unsqueeze(-1).repeat(1, 1, features.shape[-1]).long()

        logits = features.gather(0, ind).squeeze(0)
        softmax = torch.nn.functional.log_softmax(logits, dim=-1)

        output = torch.zeros(logits.shape[0], self._max_seg + 1).to(features.device)
        cats_in_batch = torch.unique(category_labels)
        for cat in cats_in_batch:
            cat_mask = category_labels == cat
            seg_indices = self._cat_to_seg[cat.item()]
            probs = softmax[cat_mask, : len(seg_indices)]
            output[cat_mask, seg_indices[0]: seg_indices[-1] + 1] = probs

        return output

class PartSegKPConv(torch.nn.Module):
    def __init__(self, cat_to_seg, USE_NORMALS):
        super().__init__()
        self.unet = KPConv(
            architecture="unet",
            input_nc=USE_NORMALS * 3,
            num_layers=4,
            in_grid_size=0.02
        )
        self.classifier = MultiHeadClassifier(self.unet.output_nc, cat_to_seg)

    @property
    def conv_type(self):
        """ This is needed by the dataset to infer which batch collate should be used"""
        return self.unet.conv_type

    def get_batch(self):
        return self.batch

    def get_output(self):
        """ This is needed by the tracker to get access to the ouputs of the network"""
        return self.output

    def get_labels(self):
        """ Needed by the tracker in order to access ground truth labels"""
        return self.labels

    def get_current_losses(self):
        """ Entry point for the tracker to grab the loss """
        return {"loss_seg": float(self.loss_seg)}

    def forward(self, data):
        self.labels = data.y
        self.batch = data.batch

        # Forward through unet and classifier
        data_features = self.unet(data)
        self.output = self.classifier(data_features.x, data.category)

        # Set loss for the backward pass
        self.loss_seg = torch.nn.functional.nll_loss(self.output, self.labels)
        return self.output

    def get_spatial_ops(self):
        return self.unet.get_spatial_ops()

    def backward(self):
        self.loss_seg.backward()
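The gather in MultiHeadClassifier.forward selects, for each point, the logits from the classification head matching its object category. The NumPy analogue, with tiny synthetic shapes:

```python
import numpy as np

# One head per object category: features has shape
# (num_categories, num_points, max_seg). Each point keeps only the row
# from the head matching its category -- the NumPy analogue of
# features.gather(0, ind).squeeze(0) in the forward pass above.
num_categories, num_points, max_seg = 2, 4, 3
features = np.arange(num_categories * num_points * max_seg, dtype=float)
features = features.reshape(num_categories, num_points, max_seg)
category = np.array([0, 1, 1, 0])  # per-point category labels

logits = features[category, np.arange(num_points)]  # (num_points, max_seg)
```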
train_Segmentation.py
# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_Segmentation.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 28, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import datetime

import torch

from torch_points3d.datasets.classification.modelnet import SampledModelNet
import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T

from omegaconf import OmegaConf
from torch_points3d.datasets.segmentation import ShapeNetDataset

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq
from tqdm.auto import tqdm

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)

import PartSegKPConv as PSKPC

def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

class Trainer:
    def __init__(self, model, dataset, num_epoch = 50, device=torch.device('cuda')):
        self.num_epoch = num_epoch
        self._model = model
        self._dataset=dataset
        self.device = device

    def fit(self):
        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)
        self.tracker = self._dataset.get_tracker(False, True)

        for i in range(self.num_epoch):
            print("=========== EPOCH %i ===========" % i)
            time.sleep(0.5)
            self.train_epoch()
            self.tracker.publish(i)
            self.test_epoch()
            self.tracker.publish(i)

    def train_epoch(self):
        self._model.to(self.device)
        self._model.train()
        self.tracker.reset("train")
        train_loader = self._dataset.train_dataloader
        iter_data_time = time.time()
        with tqdm(train_loader) as tq_train_loader:
            for i, data in enumerate(tq_train_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                self.optimizer.zero_grad()
                data.to(self.device)
                self._model.forward(data)
                self._model.backward()
                self.optimizer.step()
                if i % 10 == 0:
                    self.tracker.track(self._model)

                tq_train_loader.set_postfix(
                    **self.tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

    def test_epoch(self):
        self._model.to(self.device)
        self._model.eval()
        self.tracker.reset("test")
        test_loader = self._dataset.test_dataloaders[0]
        iter_data_time = time.time()
        with tqdm(test_loader) as tq_test_loader:
            for i, data in enumerate(tq_test_loader):
                t_data = time.time() - iter_data_time
                iter_start_time = time.time()
                data.to(self.device)
                self._model.forward(data)
                self.tracker.track(self._model)

                tq_test_loader.set_postfix(
                    **self.tracker.get_metrics(),
                    data_loading=float(t_data),
                    iteration=float(time.time() - iter_start_time),
                )
                iter_data_time = time.time()

def main():
    DIR = ""
    CATEGORY = "All"  # @param ["Airplane", "Bag", "All", "Motorbike"] {allow-input: true}
    USE_NORMALS = True

    shapenet_yaml = """
    class: shapenet.ShapeNetDataset
    task: segmentation
    dataroot: %s
    normal: %r                                    # Use normal vectors as features
    first_subsampling: 0.02                       # Grid size of the input data
    pre_transforms:                               # Offline transforms, done only once
        - transform: NormalizeScale           
        - transform: GridSampling3D
          params:
            size: ${first_subsampling}
    train_transforms:                             # Data augmentation pipeline
        - transform: RandomNoise
          params:
            sigma: 0.01
            clip: 0.05
        - transform: RandomScaleAnisotropic
          params:
            scales: [0.9,1.1]
        - transform: AddOnes
        - transform: AddFeatsByKeys
          params:
            list_add_to_x: [True]
            feat_names: ["ones"]
            delete_feats: [True]
    test_transforms:
        - transform: AddOnes
        - transform: AddFeatsByKeys
          params:
            list_add_to_x: [True]
            feat_names: ["ones"]
            delete_feats: [True]
    """ % (os.path.join(DIR, "data"), USE_NORMALS)
    params = OmegaConf.create(shapenet_yaml)
    if CATEGORY != "All":
        params.category = CATEGORY
    dataset = ShapeNetDataset(params)
    print(dataset)

    model = PSKPC.PartSegKPConv(cat_to_seg=dataset.class_to_segments, USE_NORMALS=USE_NORMALS)

    NUM_WORKERS = 0
    BATCH_SIZE = 4
    # Setup the data loaders
    dataset.create_dataloaders(
        model,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        precompute_multi_scale=True
    )
    sample = next(iter(dataset.train_dataloader))
    print(sample)
    print(sample.keys)
    print(sample.multiscale)

    logdir = "./train_Segmentation1"
    logdir = os.path.join(logdir, str(datetime.datetime.now()))
    mkdir_os(logdir)
    os.chdir(logdir)

    trainer = Trainer(model, dataset, num_epoch = 10)
    trainer.fit()


if __name__ == '__main__':
    main()

3. Inference code example

Segmentation inference

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_Segmentation.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site:
# @Time: April 28, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import datetime

import torch

import torch_points3d.core.data_transform as T3D
import torch_geometric.transforms as T

from torch_points3d.datasets.batch import SimpleBatch

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq
from tqdm.auto import tqdm

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)

from torch_points3d.metrics.model_checkpoint import ModelCheckpoint

import logging
log = logging.getLogger(__name__)

import numpy as np

from torch_geometric.io import read_txt_array

from torch_geometric.data import Data

import os.path as osp
from torch_points3d.core.data_transform import SaveOriginalPosId

HEADER = '''\
# .PCD v0.7 - Point Cloud Data file format
VERSION 0.7
FIELDS x y z label object
SIZE 4 4 4 4 4
TYPE F F F I I
COUNT 1 1 1 1 1
WIDTH {}
HEIGHT 1
VIEWPOINT 0 0 0 1 0 0 0
POINTS {}
DATA ascii
'''
def write_pcd(points, save_pcd_path):
    # Vectorized ASCII writer: one "x y z label object" row per point.
    # HEADER already ends with a newline, so the data rows follow it directly.
    with open(save_pcd_path, 'w') as f:
        f.write(HEADER.format(len(points), len(points)))
        np.savetxt(f, points, delimiter=' ', fmt='%f %f %f %d %d')

def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

category_ids = {
    "luomu1": "dangban_luomu",
}

def main():
    USE_NORMAL = False
    cuda = True
    enable_dropout = False
    torch.backends.cudnn.enabled = True
    device = torch.device("cuda" if (torch.cuda.is_available() and cuda) else "cpu")

    checkpoint_dir = "/home/boyun/deepglint/dataset/HSR工作/torch-points3d-master/outputs/benchmark/benchmark-pointnet2_charlesssg-20210426_174925"
    output_path = "/home/boyun/deepglint/dataset/HSR工作/torch-points3d-master/ynh/result"
    test_path = "/home/boyun/deepglint/dataset/HSR工作/分割数据/shapenetcore_partanno_segmentation_benchmark_v0_normal/"
    model_name = "pointnet2_charlesssg"
    # Used when resuming: selects which weights to load from [miou, macc, acc..., latest]
    weight_name = "miou"
    checkpoint = ModelCheckpoint(checkpoint_dir, model_name, weight_name, strict=True)
    # Create the model from the checkpoint's stored dataset properties
    model = checkpoint.create_model(checkpoint.dataset_properties, weight_name=weight_name)

    log.info(model)
    log.info("Model size = %i", sum(param.numel() for param in model.parameters() if param.requires_grad))

    model.eval()
    if enable_dropout:
        model.enable_dropout_in_eval()
    model = model.to(device)

    if not os.path.exists(output_path):
        os.makedirs(output_path)

    pre_filter = None
    pre_transform = T.Compose([T.NormalizeScale(), T3D.GridSampling3D(0.02)])
    if USE_NORMAL:
        transform = T.Compose([T.FixedPoints(8192),
                    T3D.AddFeatsByKeys(list_add_to_x=[True], feat_names=["x"], delete_feats=[True])])
    else:
        transform = T.Compose([T.FixedPoints(8192)])

    has_pre_transform = pre_transform is not None
    has_transform = transform is not None
    categories = "luomu1"
    if categories is None:
        categories = list(category_ids.keys())
    if isinstance(categories, str):
        categories = [categories]
    data_raw_list = []
    data_list = []
    file_name_list = []
    categories_ids = [category_ids[cat] for cat in categories]
    cat_idx = {categories_ids[i]: i for i in range(len(categories_ids))}

    id_scan = -1
    lines = os.listdir(test_path)
    for m_ind, m_val in enumerate(lines):
        cat = m_val
        if m_val not in categories_ids:
            continue
        categories_lines = os.listdir(os.path.join(test_path, m_val))
        for n_ind, n_val in enumerate(categories_lines):
            id_scan += 1
            data = read_txt_array(os.path.join(test_path, m_val, n_val))
            # RuntimeError: Given groups=1, weight of size [64, 3, 1, 1], expected input[1, 6, 512, 64] to have 3 channels, but got 6 channels instead
            if USE_NORMAL:
                pos = data[:, :3]
                x = data[:, 3:6]
                y = data[:, -1].type(torch.long)
                category = torch.ones(x.shape[0], dtype=torch.long) * cat_idx[cat]
                id_scan_tensor = torch.from_numpy(np.asarray([id_scan])).clone()
            else:
                pos = data[:, :3]
                x = None
                y = data[:, -1].type(torch.long)
                category = torch.ones(pos.shape[0], dtype=torch.long) * cat_idx[cat]
                id_scan_tensor = torch.from_numpy(np.asarray([id_scan])).clone()
            data = Data(pos=pos, x=x, y=y, category=category, id_scan=id_scan_tensor)

            data = SaveOriginalPosId()(data)
            if pre_filter is not None and not pre_filter(data):
                continue
            data_raw_list.append(data.clone() if has_pre_transform else data)

            if has_pre_transform and has_transform:
                data = pre_transform(data)
                data = transform(data)
                data_list.append(data)
            elif has_pre_transform and not has_transform:
                data = pre_transform(data)
                data_list.append(data)
            file_name_list.append(n_val)
            # data_raw_list holds the raw input; data_list holds the transformed samples
    with torch.no_grad():
        with tqdm(data_list) as tq_data_list:
            for index, data_value in enumerate(tq_data_list):
                data_value = SimpleBatch.from_data_list([data_value])
                # data_value.to(device)
                # model.forward(data_value)
                model.set_input(data_value, device)
                output = model.forward()
                pred = torch.max(output, -1)[1].reshape(data_value.pos.shape[0:2])
                pred_np = pred.cpu().numpy()
                data_temp = data_value.cpu()
                origin_id = (data_temp.origin_id).numpy()
                ori_data = data_raw_list[index]
                np_ori_data = (ori_data.pos).numpy()
                # np_ori_data N(ori)*3
                # origin_id 1*N(select)
                select_data = np_ori_data[origin_id.reshape(-1,)]

                xyz = select_data
                label = pred_np.reshape(-1, 1)
                num_pts = label.shape[0]
                obj = np.ones((num_pts, 1), dtype=np.int8) * (-1)  # no instance id available
                save_pcd_point = np.hstack((xyz, label, obj))
                file_name = file_name_list[index]
                write_pcd(save_pcd_point, os.path.join(output_path, file_name.replace(".txt", ".pcd")))


if __name__ == '__main__':
    main()

3. Registration inference code example

train_registration1.py

# python3.8
# -*- coding: utf-8 -*-
# ---
# @Software: PyCharm
# @File: train_registration1.py
# @Author: ---
# @Institution: BeiJing, China
# @E-mail: lgdyangninghua@163.com
# @Site: 
# @Time: April 29, 2021
# ---
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import sys
import time
import numpy as np
import datetime

import torch

from torch_points3d.core.data_transform import GridSampling3D, AddOnes, AddFeatByKey
from torch_geometric.transforms import Compose
from torch_geometric.data import Batch

# Model
from torch_points3d.applications.pretrained_api import PretainedRegistry

# post processing
from torch_points3d.utils.registration import get_matches, fast_global_registration

from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq
from tqdm.auto import tqdm

ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT)


def mkdir_os(path):
    if not os.path.exists(path):
        os.makedirs(path)

from plyfile import PlyData
def read_ply(path):
    with open(path, 'rb') as f:
        plydata = PlyData.read(f)
    vertex = plydata['vertex']
    return np.vstack((vertex['x'], vertex['y'], vertex['z'])).T

def main():
    # path_s = "data/3DMatch/redkitchen_000.ply"
    # path_t = "data/3DMatch/redkitchen_010.ply"
    path_s = "data/3DMatch/cloud_bin_0.ply"
    path_t = "data/3DMatch/cloud_bin_1.ply"
    pcd_s = read_ply(path_s)
    pcd_t = read_ply(path_t)

    transform = Compose([GridSampling3D(mode='last', size=0.02, quantize_coords=True), AddOnes(), AddFeatByKey(add_to_x=True, feat_name="ones")])
    data_s = transform(Batch(pos=torch.from_numpy(pcd_s).float(), batch=torch.zeros(pcd_s.shape[0]).long()))
    data_t = transform(Batch(pos=torch.from_numpy(pcd_t).float(), batch=torch.zeros(pcd_t.shape[0]).long()))

    # This will log some errors, don't worry it's all good!
    model = PretainedRegistry.from_pretrained("minkowski-registration-3dmatch").cuda()

    with torch.no_grad():
        model.set_input(data_s, "cuda")
        output_s = model.forward()
        model.set_input(data_t, "cuda")
        output_t = model.forward()

    rand_s = torch.randint(0, len(output_s), (5000,))
    rand_t = torch.randint(0, len(output_t), (5000,))
    matches = get_matches(output_s[rand_s], output_t[rand_t])
    T_est = fast_global_registration(data_s.pos[rand_s][matches[:, 0]], data_t.pos[rand_t][matches[:, 1]])
    print(T_est)  # estimated rigid transform aligning the source cloud to the target


if __name__ == '__main__':
    main()
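To actually use the estimated transform, the source cloud has to be mapped into the target frame. Assuming T_est is a 4x4 homogeneous matrix (convert a torch tensor with `T_est.cpu().numpy()` first), the warp can be sketched in numpy as follows; apply_transform is an illustrative helper, not a torch-points3d API:

```python
import numpy as np

def apply_transform(points, T):
    """Apply a 4x4 homogeneous transform T to an (N, 3) point array."""
    homo = np.hstack([points, np.ones((points.shape[0], 1))])  # (N, 4) homogeneous coords
    return (homo @ T.T)[:, :3]                                 # back to (N, 3)
```

If the registration succeeded, the warped source cloud apply_transform(pcd_s, T_est_np) should visually overlap pcd_t, e.g. when both are rendered in pyvista.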
