【常用代码记录】

python

生成001的数

for i in range(1, 10):
    print(str(i).zfill(3))
    ## 001 002 ...

时间

生成当前时间

import time
time.strftime('%Y%m%d-%H%M%S')
# 20240330-194914


log_file = os.path.join(cfg.log_dir, 'train-{}.log'.format(time.strftime('%Y%m%d-%H%M%S')))

log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
import time
timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
experiment_dir = os.path.join(cfg.experiment_name, timestamp)
# 2024-06-30-14:09:27
from pathlib import Path

timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
exp_dir = Path('./log/')
exp_dir.mkdir(exist_ok=True)

if args.log_dir is None:
    exp_dir = exp_dir.joinpath(timestr)
else:
    exp_dir = exp_dir.joinpath(args.log_dir)
exp_dir.mkdir(exist_ok=True)

h5文件

读取

all_data = []
all_label = []
for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
    f = h5py.File(h5_name, 'r')
    data = f['data'][:].astype('float32')
    label = f['label'][:].astype('int64')
    f.close()
    all_data.append(data)
    all_label.append(label)
all_data = np.concatenate(all_data, axis=0)
all_label = np.concatenate(all_label, axis=0)
return all_data, all_label

写入

# 创建 HDF5 文件
with h5py.File('/home/wangyongqiang/test_data_v2.h5', 'w') as file:
   # 创建数据集并将数组写入其中
   file.create_dataset('test_src_all', data=test_src_all)
   file.create_dataset('test_target_all', data=test_target_all)
print("h5文件写入完成")

添加导包路径和文件查找路径

sys路径

BASE_dir = os.path.dirname(os.path.abspath(__file__))#获取当前文件的绝对路径
sys.path.append(BASE_dir)
sys.path.append(os.path.join(BASE_dir,'models'))
sys.path.append(os.path.join(BASE_dir,'utils'))

OS模块

os.getcwd() :获取当前工作目录的路径(输出运行当前 Python脚本或程序时所处的目录路径)

os.path.dirname(os.path.abspath(__file__)) :获取当前脚本文件的绝对路径,并返回该路径的父目录。

  • __file__ 表示当前脚本文件的路径。
  • os.path.abspath(__file__) 将相对路径转换为绝对路径。
  • os.path.dirname(...) 获取绝对路径的父目录。

os.path.basename(path) 获取指定路径中的文件名部分或目录名部分。

# 这段代码的目的是确保在指定的路径 DATA_DIR 下创建一个目录
# 如果该目录已经存在,则不做任何操作。如果不存在,则创建该目录。
#这样的做法通常用于确保某个目录存在,以便在后续的文件操作中使用。
if not os.path.exists(DATA_DIR):
	os.mkdir(DATA_DIR)

os.path.exists(DATA_DIR): 检查指定路径 DATA_DIR 是否存在。返回 True 表示存在,返回 False 表示不存在。
os.mkdir(DATA_DIR): 如果 DATA_DIR 不存在,使用 os.mkdir 创建该目录。

文件夹创建

from pathlib import Path

'''CREATE DIR'''
timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
exp_dir = Path('./log/')
exp_dir.mkdir(exist_ok=True)
exp_dir = exp_dir.joinpath('classification')
exp_dir.mkdir(exist_ok=True)
if args.log_dir is None:
    exp_dir = exp_dir.joinpath(timestr)
else:
    exp_dir = exp_dir.joinpath(args.log_dir)
exp_dir.mkdir(exist_ok=True)
checkpoints_dir = exp_dir.joinpath('checkpoints/')
checkpoints_dir.mkdir(exist_ok=True)
log_dir = exp_dir.joinpath('logs/')
log_dir.mkdir(exist_ok=True)

动态导包

import importlib
model = importlib.import_module(args.model)

这行代码使用 importlib.import_module 函数来动态地加载一个模块。这里的 args.model 应该是一个字符串,表示要加载的模块的名字。例如,如果 args.model 是 ‘resnet’,那么这行代码效果等同于 import resnet as model。
这种动态加载模块的方式使得程序能够在运行时根据不同的输入参数加载不同的模块,增加了程序的灵活性。

复制py文件到指定目录

这些操作主要目的是设置一个独立的实验环境,其中包含了所有必要的模块和脚本,以便进行一致和可重复的实验运行。这种做法在进行机器学习和数据科学实验时非常常见,有助于实验的管理和版本控制。

import shutil
#复制模块文件到实验目录
shutil.copy('./models/%s.py' % args.model, str(exp_dir))

这行代码使用 shutil.copy 函数将位于 models 目录下的指定模块文件复制到 exp_dir 指定的目录中。args.model 决定了哪个文件会被复制,例如如果 args.model 是 ‘resnet’,那么 ‘./models/resnet.py’ 将被复制。
exp_dir 应该是一个路径对象(通过 pathlib.Path 创建),str(exp_dir) 将其转换为字符串形式的路径。

#复制训练脚本到实验目录
shutil.copy('./train_classification.py', str(exp_dir))

这行代码将训练脚本 train_classification.py 从当前目录复制到 exp_dir 指定的目录。这样做通常是为了保持实验环境的完整性,确保所有相关的脚本和模块都在一个地方,便于运行和追踪。

版本2

os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup')

这段代码使用了 Python 的 os.system 函数来执行系统命令,具体是使用命令行的 cp(copy)命令来复制文件。这里的目的是备份关键的 Python 脚本到特定的目录中。下面是对每行代码的详细解释:

  1. 备份主脚本(main.py

    os.system('cp main.py checkpoints' + '/' + args.exp_name + '/' + 'main.py.backup')
    
    • 这行代码执行一个系统命令,将当前目录下的 main.py 文件复制到 checkpoints/ 目录下的以 args.exp_name 命名的子目录中,并将复制的文件重命名为 main.py.backup
    • args.exp_name 是通过命令行参数传入的,通常代表实验的名称或编号,这样可以确保每个实验的备份文件都保存在各自的目录中。
  2. 备份模型定义脚本(model.py

    os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
    
    • 类似地,这行代码将 model.py 复制到相应的实验目录下,并重命名为 model.py.backup。这个文件通常包含了模型的定义,是机器学习或深度学习实验中非常关键的一部分。
  3. 备份数据处理脚本(data.py

    os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup')
    
    • 最后,data.py 文件也被复制到相同的目录中,并命名为 data.py.backupdata.py 通常包含数据加载和预处理的逻辑,对于复现实验结果来说同样重要。

使用 os.system 的注意点

  • 跨平台兼容性cp 命令是 Unix/Linux 系统的命令,不适用于 Windows 系统。在 Windows 系统上,你需要使用 copy 命令。
  • 安全性和效率os.system 调用会创建新的进程,并且对输入参数的处理比较原始,容易受到注入攻击。对于复杂的文件操作,使用 Python 的内置模块如 shutilsubprocess(对参数更好的控制)是更安全、更高效的选择。

这种文件备份方法在实验的不同阶段保留代码的快照,有助于跟踪代码变更、调试和版本控制,是实验管理的一种简单有效的手段。

异常处理

try/except
异常捕捉可以使用 try/except 语句。
在这里插入图片描述
try 语句按照如下方式工作;

首先,执行 try 子句(在关键字 try 和关键字 except 之间的语句)。

如果没有异常发生,忽略 except 子句,try 子句执行后结束。

如果在执行 try 子句的过程中发生了异常,那么 try 子句余下的部分将被忽略。如果异常的类型和 except 之后的名称相符,那么对应的 except 子句将被执行。

如果一个异常没有与任何的 except 匹配,那么这个异常将会传递给上层的 try 中。

try/except…else
try/except 语句还有一个可选的 else 子句,如果使用这个子句,那么必须放在所有的 except 子句之后。

else 子句将在 try 子句没有发生任何异常的时候执行。
在这里插入图片描述

抛出异常

Python 使用 raise 语句抛出一个指定的异常。

raise语法格式如下:

raise [Exception [, args [, traceback]]]

以下实例如果 x 大于 5 就触发异常:

x = 10
if x > 5:
    raise Exception('x 不能大于 5。x 的值为: {}'.format(x))

执行以上代码会触发异常:

Traceback (most recent call last):
  File "test.py", line 3, in <module>
    raise Exception('x 不能大于 5。x 的值为: {}'.format(x))
Exception: x 不能大于 5。x 的值为: 10

aise 唯一的一个参数指定了要被抛出的异常。它必须是一个异常的实例或者是异常的类(也就是 Exception 的子类)。

如果你只想知道这是否抛出了一个异常,并不想去处理它,那么一个简单的 raise 语句就可以再次把它抛出。

从命令行 传递给 Python脚本参数

import sys


sys.argv[1]

sys.argv

sys.argv 是 Python 标准库 sys 模块的一个属性,它包含从命令行传递给 Python 脚本的参数。

  • sys.argv[0] 是脚本本身的名称。
  • sys.argv[1] 是第一个命令行参数,也就是执行脚本时传递的参数。

解析参数,从yaml文件中解析

import yaml
import sys

def parse_args_from_yaml(yaml_path):
    with open(yaml_path, 'r',encoding='utf-8') as fd:
        args = yaml.safe_load(fd)
        args = EasyDict(d=args)
    return args

args = parse_args_from_yaml(sys.argv[1])

这个代码段的功能是从一个 YAML 文件中解析配置参数,并将这些参数转换为一个 EasyDict 对象,以便更方便地访问和使用参数。我们逐步分析它的各个部分:

1. import yaml

yaml 是一个用于处理 YAML 文件的 Python 库。它能够将 YAML 文件加载为 Python 字典,或者将 Python 字典序列化为 YAML 格式。

2. import sys

sys 模块允许我们访问传递给 Python 脚本的命令行参数。在这个程序中,sys.argv[1] 表示从命令行传递的第一个参数(通常是 YAML 文件的路径)。

3. parse_args_from_yaml(yaml_path)

这是一个自定义函数,用于从指定的 YAML 文件路径 yaml_path 解析参数。

函数步骤:
  • with open(yaml_path, 'r', encoding='utf-8') as fd:: 打开 yaml_path 指定的 YAML 文件,文件对象被赋值给 fd,并以 UTF-8 编码进行读取。
  • args = yaml.safe_load(fd): 使用 yaml.safe_load() 将 YAML 文件中的内容加载为一个 Python 字典。safe_load 方法会确保加载的 YAML 文件安全且不执行任何不受信任的代码。
  • args = EasyDict(d=args): 这里使用了 EasyDict 将加载的参数转换为 EasyDict 对象。EasyDict 是一种方便的数据结构,允许你像访问属性一样访问字典的键。比如 args["key"] 可以被简化为 args.key

4. args = parse_args_from_yaml(sys.argv[1])

这一行调用了 parse_args_from_yaml 函数,并将第一个命令行参数(sys.argv[1])作为 YAML 文件的路径传递给该函数。结果是 args 包含了 YAML 文件中解析的参数,并且可以通过 args.key 的方式进行访问。

EasyDict 的功能

EasyDict 是一个 Python 库,它允许我们以点号属性访问字典中的键。通常情况下,字典是通过键值对的形式访问的,比如:

my_dict = {'learning_rate': 0.001}
print(my_dict['learning_rate'])  # 0.001

使用 EasyDict 后,你可以这样访问键:

from easydict import EasyDict
my_dict = EasyDict({'learning_rate': 0.001})
print(my_dict.learning_rate)  # 0.001

这让代码更加简洁和清晰。

完整工作流程

假设你的命令行参数是:

python your_script.py config.yaml
  • sys.argv[1] 将是 "config.yaml"
  • parse_args_from_yaml("config.yaml") 读取该 YAML 文件并返回一个 EasyDict 对象,包含了 YAML 文件中的配置参数。
  • 你可以使用 args 访问这些参数,像 args.learning_rateargs.batch_size 这样的方式。

示例 YAML 文件

如果 config.yaml 文件内容如下:

learning_rate: 0.001
batch_size: 64
epochs: 10

那么,在 Python 中可以这样访问参数:

print(args.learning_rate)  # 0.001
print(args.batch_size)     # 64
print(args.epochs)         # 10

代码改进建议

在你的代码中没有 EasyDict 的导入。为了使它正常工作,你需要添加:

from easydict import EasyDict

最终代码示例

import yaml
import sys
from easydict import EasyDict

def parse_args_from_yaml(yaml_path):
    with open(yaml_path, 'r', encoding='utf-8') as fd:
        args = yaml.safe_load(fd)
        args = EasyDict(args)
    return args

args = parse_args_from_yaml(sys.argv[1])

希望这段解释对你有帮助!

GPU

指定可见显卡

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

终端(命令行)中设置环境变量:

在命令行中启动脚本之前设置环境变量:

CUDA_VISIBLE_DEVICES=1 python your_script.py

超参数指定显卡可见

parser.add_argument('--gpu', type=str, default='', metavar='N', required=True, help='Name of gpu')

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

pytorch

检查PyTorch是否正确安装了 CUDA 支持:

python -c "import torch; print(torch.cuda.is_available())"

tensor 输入到模型中的维度

tensor 输入到模型中的维度为(B,C,N)。
pytorch规定channel First,即中间为特征通道

PyTorch使用的是 "channel first" 格式,即通道维C在高度 H和宽度W之前。图像数据的维度通常是 (N, C, H, W),其中 N是批量大小。
  
TensorFlow通常使用 "channel last" 格式,即通道维C在高度H和宽度W之后。因此,图像数据的维度通常是(N, H, W, C) 。

tensor 维度交换

.permute()

# 创建一个示例张量
points= torch.randn(4, 3, 64, 64)

# ===使用permute交换第2和第3维度===
points_swapped = points.permute(0, 1, 3, 2)

torch.transpose

# 创建一个示例张量
tensor = torch.randn(4, 3, 64, 64)

# 使用 torch.transpose 交换第2和第3维度
tensor_swapped = torch.transpose(tensor, 2, 3)

# torch.transpose 并没有改变原始张量,而是返回了一个新的张量,这个新的张量是按照指定维度交换后的结果。

# 如果你想要直接修改原始张量而不返回一个新的张量,你可以使用 tensor.transpose_ 方法:
tensor.transpose_(1, 2)
# 这会在原地修改张量。

torch.transpose(tensor, 2, 3)tensor_swapped = tensor.permute(0, 1, 3, 2) 都用于在 PyTorch 中交换张量的维度,但它们有一些区别。

  1. 参数形式:

    • torch.transpose 的形式是 torch.transpose(input, dim0, dim1),其中 input 是输入张量,dim0dim1 是要交换的维度的索引。
    • tensor.permute 的形式是 tensor.permute(dims),其中 dims 是维度的新排列顺序的元组。
  2. 灵活性:

    • torch.transpose 允许你直接指定要交换的维度的索引,更直接。在上述例子中,torch.transpose(tensor, 2, 3) 表示交换第2和第3维度。
    • tensor.permute 则更灵活,因为你可以通过提供新的维度排列顺序的元组来实现更复杂的维度交换。在 tensor.permute(0, 1, 3, 2) 中,你可以调整元组中的数字来满足不同的需求。
  3. 内存操作:

    • 在内部实现上,两者的性能和内存操作可能略有不同。tensor.permute 通常更灵活,但在某些情况下可能需要进行数据的重排列,这可能会导致一些内存操作。torch.transpose 更直接,可能在某些情况下更高效。

总的来说,选择使用哪个取决于你的需求。如果只是简单地交换两个维度,torch.transpose 可能更直观。如果需要更灵活的维度排列,或者要进行复杂的维度交换,那么 tensor.permute 是一个更通用的选择。

ndarray 维度交换

numpy.swapaxes

arr_swapped = np.swapaxes(arr, 1, 2)

numpy.transpose

arr_transposed = np.transpose(arr, (0, 2, 1))

注意,这两种方法都不会改变原始数组,而是返回一个具有交换维度后形状的新数组。

2维

# 创建一个形状为(1024,3)的示例数组
array = np.random.rand(1024, 3)

# 使用reshape进行形状变换
reshaped_array = array.T  # 或者 array.transpose()

tensor增加 / 较少一个维度

# 创建一个NumPy数组
numpy_array = np.array([[1, 2, 3], [4, 5, 6]])

# 将NumPy数组转换为PyTorch张量
tensor = torch.tensor(numpy_array)

# 使用unsqueeze方法在第0维度上增加一个维度
tensor = tensor.unsqueeze(0)

# 使用squeeze()方法去掉所有大小为1的维度
tensor = tensor.squeeze()

numpy

列表拼接

沿着现有的轴拼接数组

np.concatenate 函数默认情况下沿着现有的轴拼接数组。

all_data = []
all_data.append(data)
all_data = np.concatenate(all_data, axis=0)

在新的维度上拼接数组

使用 np.stack

np.stack 允许你在新的轴上拼接数组。假设 all_data 是一个包含多个数组的列表:

import numpy as np

# 示例数据
all_data = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])]

# 在新的维度上拼接
result = np.stack(all_data, axis=0)
print(result)

这将输出:

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

在这个例子中,新的维度被添加为第一个维度(axis=0)。
np.expand_dims 再使用 np.concatenate

使用 np.expand_dimsnp.concatenate

你也可以先使用 np.expand_dims 在每个数组上添加一个新的轴,然后使用 np.concatenate 沿着新的轴拼接:

import numpy as np

# 示例数据
all_data = [np.array([1, 2, 3]), np.array([4, 5, 6]), np.array([7, 8, 9])]

# 在新的维度上拼接
expanded_data = [np.expand_dims(arr, axis=0) for arr in all_data]
result = np.concatenate(expanded_data, axis=0)
print(result)

这将输出相同的结果:

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

Linux

查看文件占用空间

要查看Linux系统中文件夹的大小,你可以使用du命令(磁盘使用量)。

使用以下命令来查看/home/user/documents文件夹的大小:

du -sh /home/user/documents

在上述命令中,-s标志用于汇总文件夹的大小,-h标志用于以人类可读的格式显示大小。

du(Disk Usage)命令是查看磁盘使用情况的工具。

查看当前目录下的所有子目录和它们的大小(以人类可读方式):
du -h --max-depth=1
查看指定目录(例如 /var/www)下的所有子目录和它们的大小(以人类可读方式):
du -h --max-depth=1 /var/www

--max-depth=1 选项用于指定要查看的目录层级深度。你可以根据需要调整深度值。 du 命令默认以千字节为单位显示大小,使用 -h 选项以人类可读的方式显示大小。

使用 du 命令查看每个用户的家目录占用空间
如果每个用户的文件都存放在 /home 目录下,可以用以下命令来查看每个用户在 /home 目录中的存储占用:

du -sh /home/*

du:显示目录或文件的磁盘使用情况。
-s:只显示总计而不列出每个子目录的详细信息。
-h:以人类可读的格式显示磁盘使用情况(如 GB、MB)。
这将显示每个用户在 /home 目录下使用的总磁盘空间。

查看linux还剩多少储存空间

你可以使用以下命令来查看 Linux 系统中剩余的存储空间:

df -h
  • df 命令显示文件系统的磁盘使用情况。
  • -h 参数表示以人类可读格式显示结果(例如以 GB、MB 为单位,而不是字节)。

执行该命令后,终端会列出各个挂载点的总空间、已用空间、剩余空间以及使用百分比。

如果你只想查看某个特定目录或文件系统的剩余空间,例如根目录,可以使用:

df -h /

旋转

# Generate rigid transform

anglex = np.random.uniform() * np.pi * self.rot_mag / 180.0
angley = np.random.uniform() * np.pi * self.rot_mag / 180.0
anglez = np.random.uniform() * np.pi * self.rot_mag / 180.0

cosx = np.cos(anglex)
cosy = np.cos(angley)
cosz = np.cos(anglez)
sinx = np.sin(anglex)
siny = np.sin(angley)
sinz = np.sin(anglez)
Rx = np.array([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]])
Ry = np.array([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]])
Rz = np.array([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]])

if not self.only_z:
    R_ab = Rx @ Ry @ Rz
else:
    R_ab = Rz
    
t_ab = np.random.uniform(-self.trans_mag, self.trans_mag, 3)

rand_SE3 = np.concatenate((R_ab, t_ab[:, None]), axis=1).astype(np.float32)

ref, transform_s_r = self.apply_transform(ref, rand_SE3)
# Apply to source to get reference
sample["transform_gt"] = transform_s_r
sample["pose_gt"] = se3.np_mat2quat(transform_s_r)
def np_transform(g: np.ndarray, pts: np.ndarray):
    """ Applies the SE3 transform

    Args:
        g: SE3 transformation matrix of size ([B,] 3/4, 4)
        pts: Points to be transformed ([B,] N, 3)

    Returns:
        transformed points of size (N, 3)
    """
    rot = g[..., :3, :3]  # (3, 3)
    trans = g[..., :3, 3]  # (3)

    transformed = pts[..., :3] @ np.swapaxes(rot, -1, -2) + trans[..., None, :]
    return transformed



def apply_transform(self, p0, transform_mat):
    p1 = np_transform(transform_mat, p0[:, :3])
    if p0.shape[1] == 6:  # Need to rotate normals also
        n1 = so3.transform(transform_mat[:3, :3], p0[:, 3:6])
        p1 = np.concatenate((p1, n1), axis=-1)

    gt = transform_mat

    return p1, gt
def torch_transform(g, a, normals=None):
    """ Applies the SE3 transform

    Args:
        g: SE3 transformation matrix of size ([1,] 3/4, 4) or (B, 3/4, 4)
        a: Points to be transformed (N, 3) or (B, N, 3)
        normals: (Optional). If provided, normals will be transformed

    Returns:
        transformed points of size (N, 3) or (B, N, 3)

    """
    R = g[..., :3, :3]  # (B, 3, 3)
    p = g[..., :3, 3]  # (B, 3)

    if len(g.size()) == len(a.size()):
        b = torch.matmul(a, R.transpose(-1, -2)) + p[..., None, :]
    else:
        raise NotImplementedError
        b = R.matmul(a.unsqueeze(-1)).squeeze(-1) + p  # No batch. Not checked

    if normals is not None:
        rotated_normals = normals @ R.transpose(-1, -2)
        return b, rotated_normals

    else:
        return b

空间计算

overlap_src_mask, overlap_ref_mask = self.generate_overlap_mask(xyz_src.clone(), xyz_ref.clone(), src_pred_mask, ref_pred_mask, transform_gt)


def generate_overlap_mask(self, points_src: torch.Tensor, points_ref: torch.Tensor, mask_src: torch.Tensor, mask_ref: torch.Tensor, transform_gt: torch.Tensor):
'''
计算重叠掩码
'''
    points_src[torch.logical_not(mask_src), :] = 50.0
    points_ref[torch.logical_not(mask_ref), :] = 100.0
    points_src = torch_transform(transform_gt, points_src)
    dist_matrix = torch.sqrt(torch.sum(torch.square(points_src[:, :, None, :] - points_ref[:, None, :, :]), dim=-1))  # (B, N, N)
    dist_s2r = torch.min(dist_matrix, dim=2)[0]
    dist_r2s = torch.min(dist_matrix, dim=1)[0]
    overlap_src_mask = dist_s2r < self.overlap_dist  # (B, N)
    overlap_ref_mask = dist_r2s < self.overlap_dist  # (B, N)

    return overlap_src_mask, overlap_ref_mask
dist_matrix = torch.sqrt(torch.sum(torch.square(points_src[:, :, None, :] - points_ref[:, None, :, :]), dim=-1))  # (B, N, N)

这行代码计算了两个点云之间的欧氏距离矩阵 dist_matrix。让我们逐步解释:

  • points_src[:, :, None, :]:通过在第二个维度上插入一个新的维度(Noneunsqueeze),将源点云的形状从 (B, N, 3) 变为 (B, N, 1, 3)

  • points_ref[:, None, :, :]:同样,将参考点云的形状从 (B, N, 3) 变为 (B, 1, N, 3)

  • points_src[:, :, None, :] - points_ref[:, None, :, :]:计算两个点云之间的坐标差,得到一个形状为 (B, N, N, 3) 的张量,其中每个元素表示两个点之间在每个坐标轴上的差值。

  • torch.square(...):将差值张量的每个元素进行平方操作。

  • torch.sum(..., dim=-1):在最后一个维度上求和,得到一个形状为 (B, N, N) 的张量,每个元素表示两个点之间的平方距离。

  • torch.sqrt(...):对上述平方距离取平方根,最终得到欧氏距离矩阵 dist_matrix,形状为 (B, N, N)

这个矩阵中的每个元素 (i, j) 表示源点云中第 i 个点与参考点云中第 j 个点之间的欧氏距离。

初始位姿

init_quat = t3d.euler2quat(0., 0., 0., "sxyz")
init_quat = torch.from_numpy(init_quat).expand(B, 4)
init_translate = torch.from_numpy(np.array([[0., 0., 0.]])).expand(B, 3)
pose_pred = torch.cat((init_quat, init_translate), dim=1).float().cuda()  # (B, 7)
def euler2quat(ai, aj, ak, axes='sxyz'):
    """Return `quaternion` from Euler angles and axis sequence `axes`
	从欧拉角转换四元数
    Parameters
    ----------
    ai : float
        First rotation angle (according to `axes`).
    aj : float
        Second rotation angle (according to `axes`).
    ak : float
        Third rotation angle (according to `axes`).
    axes : str, optional
        Axis specification; one of 24 axis sequences as string or encoded
        tuple - e.g. ``sxyz`` (the default).

    Returns
    -------
    quat : array shape (4,)
       Quaternion in w, x, y z (real, then vector) format

    Examples
    --------
    >>> q = euler2quat(1, 2, 3, 'ryxz')
    >>> np.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
    True
    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # validation
        firstaxis, parity, repetition, frame = axes

    i = firstaxis + 1
    j = _NEXT_AXIS[i+parity-1] + 1
    k = _NEXT_AXIS[i-parity] + 1

    if frame:
        ai, ak = ak, ai
    if parity:
        aj = -aj

    ai /= 2.0
    aj /= 2.0
    ak /= 2.0
    ci = math.cos(ai)
    si = math.sin(ai)
    cj = math.cos(aj)
    sj = math.sin(aj)
    ck = math.cos(ak)
    sk = math.sin(ak)
    cc = ci*ck
    cs = ci*sk
    sc = si*ck
    ss = si*sk

    q = np.empty((4, ))
    if repetition:
        q[0] = cj*(cc - ss)
        q[i] = cj*(cs + sc)
        q[j] = sj*(cc + ss)
        q[k] = sj*(cs - sc)
    else:
        q[0] = cj*cc + sj*ss
        q[i] = cj*sc - sj*cs
        q[j] = cj*ss + sj*cc
        q[k] = cj*cs - sj*sc
    if parity:
        q[j] *= -1.0

    return q
transform_pred = torch_quat2mat(pose_pred)
# input	(B, 7)
# return (B, 3, 4)

def torch_quat2mat(pose):
	'''
	四元数转换旋转矩阵
	'''
    # Separate each quaternion value.
    q0, q1, q2, q3 = pose[:, 0], pose[:, 1], pose[:, 2], pose[:, 3]
    # Convert quaternion to rotation matrix.
    # Ref: 	http://www-evasion.inrialpes.fr/people/Franck.Hetroy/Teaching/ProjetsImage/2007/Bib/besl_mckay-pami1992.pdf
    # A method for Registration of 3D shapes paper by Paul J. Besl and Neil D McKay.
    R11 = q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3
    R12 = 2 * (q1 * q2 - q0 * q3)
    R13 = 2 * (q1 * q3 + q0 * q2)
    R21 = 2 * (q1 * q2 + q0 * q3)
    R22 = q0 * q0 + q2 * q2 - q1 * q1 - q3 * q3
    R23 = 2 * (q2 * q3 - q0 * q1)
    R31 = 2 * (q1 * q3 - q0 * q2)
    R32 = 2 * (q2 * q3 + q0 * q1)
    R33 = q0 * q0 + q3 * q3 - q1 * q1 - q2 * q2
    R = torch.stack((torch.stack((R11, R12, R13), dim=0), torch.stack((R21, R22, R23), dim=0), torch.stack((R31, R32, R33), dim=0)), dim=0)

    rot_mat = R.permute((2, 0, 1))  # (B, 3, 3)
    translation = pose[:, 4:].unsqueeze(2)  # (B, 3, 1)
    transform = torch.cat((rot_mat, translation), dim=2)
    return transform  # (B, 3, 4)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值