UV-Net的复现

本论文主要复现了UV-Net: Learning from Boundary Representations论文中的UV-Net
官网代码地址:https://github.com/AutodeskAILab/UV-Net

SolidLetters数据集介绍

SolidLetters数据集由96k个3D形状组成。
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
train.txt和test.txt存的是训练和测试的标签名称,是一个字符串。
在这里插入图片描述

环境配置

conda create -n uv_net python=3.9
pip install pytorch-lightning==1.6.4
pip install scikit-learn
conda install -c dglteam dgl
pip install pandas
上面的都安装完后classification.py就可以运行了。
上面那种安装有很多没有指定其版本故若要进行GPU版本的pytorch运算可能不太行。若你只是想能跑或者自己本身的CUDA版本够高的话运行起来是没问题的。

一个批次加载的是?

import argparse#用来解析命令行参数的模块
import pathlib#用于处理文件路径的模块
import time
from pytorch_lightning import Trainer
#它封装了模型的训练过程,使得用户可以不必编写标准的训练循环代码。通过使用 Trainer,您只需要定义模型、数据加载和优化器,剩下的训练/验证/测试循环、日志记录、模型保存等都由 Trainer 自动管理。
from pytorch_lightning.callbacks import ModelCheckpoint#是一个回调函数(Callback),用于在训练过程中保存模型。
from pytorch_lightning.loggers import TensorBoardLogger#是另一个回调函数,用于将训练过程中的数据记录到 TensorBoard 中。
from pytorch_lightning.utilities.seed import seed_everything#用于设置随机种子,以确保实验的可重复性。
#PyTorch Lightning 是一个用于训练和研究深度学习模型的轻量级框架

from datasets1.solidletters import SolidLetters
from uvnet.models import Classification

parser = argparse.ArgumentParser("UV-Net solid model classification")
#创建了一个ArgumentParser对象,这是argparse模块的主体,用于处理命令行参数。括号内的字符串 "UV-Net solid model classification" 是程序的描述
parser.add_argument(
    "traintest", choices=("train", "test"), help="Whether to train or test"
)
#它是必须指定的,并且只能是 "train" 或 "test" 中的一个
parser.add_argument("--dataset", choices=("solidletters",), help="Dataset to train on")
#它指定了要在哪个数据集上进行训练。目前只有一个选项 "solidletters"
parser.add_argument("--dataset_path", type=str, help="Path to dataset")
#它需要一个字符串值,用于指定数据集的路径
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
#它接受一个整数值,指定了批处理大小,默认值是64
parser.add_argument(
    "--num_workers",
    type=int,
    default=0,
    help="Number of workers for the dataloader. NOTE: set this to 0 on Windows, any other value leads to poor performance",
)
#用于指定数据加载时使用的工作进程数量。
parser.add_argument(
    "--checkpoint",
    type=str,
    default=None,
    help="Checkpoint file to load weights from for testing",
)
#此参数允许用户指定一个检查点文件的路径,该文件包含了模型的权重。这主要用于测试阶段,可以从先前的训练阶段保存的检查点中加载模型权重,以便进行评估或继续训练。如果不提供此参数(即保持默认值None),则不会从检查点加载权重。
parser.add_argument(
    "--experiment_name",
    type=str,
    default="classification",
    help="Experiment name (used to create folder inside ./results/ to save logs and checkpoints)",
)
#,用于存储当前实验的日志文件和检查点。
parser = Trainer.add_argparse_args(parser)

#PyTorch Lightning Trainer类所支持的所有命令行参数添加到已存在的argparse.ArgumentParser对象中。
#这样做的好处是,你可以在命令行中指定Trainer相关的所有设置(比如GPUs数量、epochs数等),而无需在代码中硬编码这些设置。
#简单理解为当前的命令行参数有了Trainer的命令行参数
args = parser.parse_args()
#解析命令行提供的参数,并将解析后的参数保存在args变量中
results_path = (
    pathlib.Path(__file__).parent.joinpath("results").joinpath(args.experiment_name)
)
"""
pathlib.Path(__file__)获取当前脚本文件的路径。
.parent获取该路径的父目录,即脚本所在的目录。
.joinpath("results")在父目录下添加一个名为results的子目录。
.joinpath(args.experiment_name)在results目录下进一步添加以实验名称命名的子目录。
"""
if not results_path.exists():
    results_path.mkdir(parents=True, exist_ok=True)

# Define a path to save the results based date and time. E.g.
# results/args.experiment_name/0430/123103
month_day = time.strftime("%m%d")
hour_min_second = time.strftime("%H%M%S")
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=str(results_path.joinpath(month_day, hour_min_second)),
    filename="best",
    save_last=True,
)#通过使用ModelCheckpoint回调,你可以自动地保存训练过程中的关键模型,无需手动干预

trainer = Trainer.from_argparse_args(
    args,
    callbacks=[checkpoint_callback],
    logger=TensorBoardLogger(
        str(results_path), name=month_day, version=hour_min_second,
    ),
)

if args.dataset == "solidletters":
    Dataset = SolidLetters
else:
    raise ValueError("Unsupported dataset")

if args.traintest == "train":
    # Train/val
    seed_everything(workers=True)
    model = Classification(num_classes=Dataset.num_classes())
    train_data = Dataset(root_dir=args.dataset_path, split="train")
    val_data = Dataset(root_dir=args.dataset_path, split="val")
    train_loader = train_data.get_dataloader(
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    print("66666666666666666666666666666666666666666666666666")
    for element in train_loader:#遍历迭代器
        print(element)
        print(element['graph'])
        print("66666666666666666666666666666666666666666666666666666666")

在这里插入图片描述
可以看到迭代器返回的是一个字典序列。

  • ‘graph’: 对应一个图
  • ‘filename’:文件名称
  • ‘label’:文件所对应的标签。(每一个文件对应一个标签其值是一个0-25的整数)

每一个点的特征所对应的维度是:(10,10,7)

每一条边的特征所对应的维度是: (10,6)

当batch_size==2时候,我发现这俩小图直接弄到一块了,这俩图的编号重新编了一下。其实就是对于图Batch的处理。

datasets

base.py

"""
该文件主要定义了一个BaseDataset()类
BaseDataset()类里面的函数主要作用有
+ 加载数据
+ 缩放数据
+ 数据格式变换
+ 数据的长度
+ 将数据转化成一个个的batch
+ 返回指定索引的数据
"""
from torch.utils.data import Dataset, DataLoader#从torch.utils.data模块导入Dataset, DataLoader这俩类
from torch import FloatTensor#torch 库中导入 FloatTensor 类。FloatTensor 类是 torch 库中用来表示浮点数张量(tensor)的数据类型
import dgl#导入 DGL库
#DGL 是一个用于图神经网络的开源库,提供了丰富的图神经网络模型和工具,方便用户进行图数据的建模和分析。
from dgl.data.utils import load_graphs#DGL(Deep Graph Library)库的 data.utils 模块中导入 load_graphs 函数。load_graphs 函数通常用于加载图数据集
from datasets import util#从datasets的库中导入 util 模块
from tqdm import tqdm#从tqdm 库中导入tqdm模块
from abc import abstractmethod
#abc模块中导入abstractmethod装饰器。abstractmethod 装饰器用于声明一个抽象方法,该方法在子类中必须被实现,否则会导致子类无法实例化。通过使用 abstractmethod 装饰器,可以定义抽象基类(Abstract Base Class)和规范子类需要实现的方法。

class BaseDataset(Dataset):
    #定义了一个BaseDataset类该类继承了Dataset类
    @staticmethod
    #修饰类方法,静态方法。不传入代表实例对象的self参数,并且不强制要求传递任何参数,可以被类直接调用。静态方法是独立于类的一个单独函数,只是寄存在一个类名下。
    #静态方法就是类对外部函数的封装,有助于优化代码结构和提高程序的可读性。
    #简言之静态方法就是一个独立的函数只不过在类内罢了
    @abstractmethod
    #抽象方法。用于程序接口的控制。含有abstractmethod 方法的类不能实例化,继承了含abstractmethod方法的子类必须复写所有abstractmethod装饰的方法,未被装饰的不重写。
    #简言之有这个方法的类不能实例化,其子类必须重写该函数
    def num_classes():
        pass

    def load_graphs(self, file_paths, center_and_scale=True):
        #file_paths: 这是一个文件路径列表,即包含多个文件路径
        #center_and_scale表示是否需要对数据进行居中和缩放处理
        self.data = []
        for fn in tqdm(file_paths):#弄一个进度条
            if not fn.exists():#如果文件不存在,则会跳过当前文件的处理,继续处理下一个文件
                continue
            sample = self.load_one_graph(fn)#加载一个图
            if sample is None:
                continue
            if sample["graph"].edata["x"].size(0) == 0:
                # Catch the case of graphs with no edges
                continue
            self.data.append(sample)#放到数据列表里
        if center_and_scale:
            self.center_and_scale()
        self.convert_to_float32()
    
    def load_one_graph(self, file_path):#这里加载的实际上就是.bin文件
        #该函数的作用是加载一张图
        #输入:一个文件路径
        #输出:该路径的第一个图和名称
        graph = load_graphs(str(file_path))[0][0]
        sample = {
   "graph": graph, "filename": file_path.stem}
        #用于获取路径的基本名称(不包含扩展名),可以理解为这里加载了一张图和其对应的标签
        return sample

    def center_and_scale(self):#对加载的图数据进行居中和缩放处理,确保数据合理范围内
        for i in range(len(self.data)):
            self.data[i]["graph"].ndata["x"], center, scale = util.center_and_scale_uvgrid(
                self.data[i]["graph"].ndata["x"], return_center_scale=True
            )
            self.data[i]["graph"].edata["x"][..., :3] -= center
            self.data[i]["graph"].edata["x"][..., :3] *= scale

    def convert_to_float32(self):#将加载的图数据转换为FloatTensor类型。
        for i in range(len(self.data)):
            self.data[i]["graph"].ndata["x"] = self.data[i]["graph"].ndata["x"].type(FloatTensor)
            self.data[i]["graph"].edata["x"] = self.data[i]["graph"].edata["x"].type(FloatTensor)

    def __len__(self):#返回self.data的长度
        return len(self.data)

    def __getitem__(self, idx):#返回指定索引处的图数据样本。
        sample = self.data[idx]
        if self.random_rotate:
            rotation = util.get_random_rotation()
            sample["graph"].ndata["x"] = util.rotate_uvgrid(sample["graph"].ndata["x"], rotation)
            sample["graph"].edata["x"] = util.rotate_uvgrid(sample["graph"].edata["x"], rotation)
        return sample

    def _collate(self, batch):#_collate:将多个图数据样本合并为一个批次,并返回该批次的图数据和文件名
        batched_graph = dgl.batch([sample["graph"] for sample in batch])
        batched_filenames = [sample["filename"] for sample in batch]
        return {
   "graph": batched_graph, "filename": batched_filenames}

    def get_dataloader(self, batch_size=128, shuffle=True, num_workers=0):#返回一个数据生成器,该生成器可以迭代地生成数据批次,以便进行训练或测试。
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self._collate,
            num_workers=num_workers,  # Can be set to non-zero on Linux
            drop_last=True,
        )

util.py

"""
该文件主要定义了一些用到的函数,将其放在一个文件里作为一个工具包来使用。
"""
import random
import numpy as np
import torch
from scipy.spatial.transform import Rotation
#Rotation是一个用于表示和操作旋转的类。它提供了许多用于创建、组合和应用旋转的方法。


def bounding_box_uvgrid(inp: torch.Tensor):
    pts = inp[..., :3].reshape((-1, 3))#获得这个二维张量的前3列,并且将其重新形状为(-1, 3)的二维数组,其中-1表示自动计算该维度的大小。
    mask = inp[..., 6].reshape(-1)
    #因为下标是从0开始的,故6实际上指的就是第7列。所以这句就是获得第7列的数据,也就是掩码,并且将其重新形状为一维数组。
    point_indices_inside_faces = mask == 1#mask==1 会检查mask数组中的每个位置是不是1
    pts = pts[point_indices_inside_faces, :]#取出掩码为1的那些向量
    return bounding_box_pointcloud(pts)


def bounding_box_pointcloud(pts: torch.Tensor):#计算给定点云的边界框
    x = pts[:, 0]
    y = pts[:, 1]
    z = pts[:, 2]
    box = [[x.min(), y.min(), z.min()], [x.max(), y.max(), z.max()]]
    #取前3列取每一列的最小值和最大值。
    return torch.tensor(box)


def center_and_scale_uvgrid(inp: torch.Tensor, return_center_scale=False):
    bbox = bounding_box_uvgrid(inp)#bbox是(2,3)大小的
    diag = bbox[1] - bbox[0]#计算边界框的对角线长度 diag,并计算缩放比例 scale,使得经过缩放后的 uv 网格能够被包含在一个以原点为中心、边长为 2 的立方体内。
    scale = 2.0 / max(diag[0], diag[1], diag[2])
    center = 0.5 * (bbox[0] + bbox[1])#求出中心点
    inp[..., :3] -= center#计算 uv 网格的中心点 center,并将 uv 网格中的点坐标减去中心点、乘以缩放比例的操作,实现中心化和缩放。
    inp[..., :3] *= scale
    if return_center_scale:
        return inp, center, scale
    return inp


def get_random_rotation():#生成一个随机旋转的函数
    """Get a random rotation in 90 degree increments along the canonical axes"""
    axes = [
        np.array([1, 0, 0]),
        np.array([0, 1, 0]),
        np.array([0, 0, 1]),
    ]#定义了三个轴向量 axes,分别对应于三个坐标轴的单位向量
    angles = [0.0, 90.0, 180.0, 270.0]
    axis = random.choice(axes)#从axes随机选一个(这里因为是二维的故就是随机选一行),
    angle_radians = np.radians(random.choice(angles))#选择一个角度,并将角度值转换为弧度
    return Rotation.from_rotvec(angle_radians * axis)#函数创建一个旋转矩阵,并将其作为结果返回。


def rotate_uvgrid(inp, rotation):#它接受一个输入张量 inp 和一个旋转矩阵 rotation,然后将图中的节点特征按给定的旋转进行旋转。
    """Rotate the node features in the graph by a given rotation"""
    Rmat = torch.tensor(rotation.as_matrix()).float()
    orig_size = inp[..., :3].size()
    inp[..., :3] = torch.mm(inp[..., :3].view(-1, 3), Rmat).view(
        orig_size
    )  # Points
    inp[..., 3:6] = torch.mm(inp[..., 3:6].view(-1, 3), Rmat).view(
        orig_size
    )  # Normals/tangents
    return inp

#无效的字体列表
INVALID_FONTS = [
    "Bokor",
    "Lao Muang Khong",
    "Lao Sans Pro",
    "MS Outlook",
    "Catamaran Black",
    "Dubai",
    "HoloLens MDL2 Assets",
    "Lao Muang Don",
    "Oxanium Medium",
    "Rounded Mplus 1c",
    "Moul Pali",
    "Noto Sans Tamil",
    "Webdings",
    "Armata",
    "Koulen",
    "Yinmar",
    "Ponnala",
    "Noto Sans Tamil",
    "Chenla",
    "Lohit Devanagari"<
  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值