这里写目录标题
前言
https://github.com/skanti/Scan2CAD
文章项目地址在这里
可以下载文章和代码
因为要写网络方面的注释,所以干脆发到这里好了。
代码运行顺序
可视化
后补
编译
后补
转换网络输入量
后补
网络预测输出heatmap
figure4和figure10是论文中给出的网络结构图
下面展示一些 关键代码注释
。
1.运行run.sh
里面衔接main.py中的命令行argparse模块提供了一些初始参数。
2.main.py
gkern3d_dim32 = torch.FloatTensor(torch.from_numpy(kernels.gaussian3d(7, 1.5))).view(1, 1, 7, 7, 7).cuda()
其中gaussian3d是kernels.py中的一个函数,生成一个7x7x7的高斯分布的数组,sigma=1.5,再转换成tensor,用于后续网络输出时用来计算。
model = modelnn.Model3d().cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=0.0005)
n_model_params = int(sum([np.prod(p.size()) for p in filter(lambda p: p.requires_grad, model.parameters())]))
这里导入模型,并设置了优化器、学习速率等等,后续再详写。模型文件存放在model.py中。
3. model.py
第一个看一下总的类Model3d
class Model3d(nn.Module):
def __init__(self):
super(Model3d, self).__init__()
self.encode0 = Encode0()
self.encode1 = Encode1()
self.decode = Decode()
self.bottleneck = Bottleneck()
self.feature2heatmap = Feature2Heatmap()
self.feature2scale = Feature2Scale()
self.feature2float = Feature2Float()
# -> freeze
for param in self.encode1.block0.parameters():
param.requires_grad = False
for param in self.encode1.block1.parameters():
param.requires_grad = False
for param in self.encode1.block2.parameters():
param.requires_grad = False
# <-
def forward(self, x, y):
assert len(x.shape) == 5
assert len(y.shape) == 5
batch_size = y.shape[0]
# -> encode0 and encode1
x = self.encode0(x)
y = self.encode1(y)
# <-
# -> concat
z_bottleneck = torch.cat((x, y), 1)
z_bottleneck = self.bottleneck(z_bottleneck)
# <-
# -> heatmap
z = self.decode(z_bottleneck)
z = self.feature2heatmap(z)
# <-
# -> classify
match = self.feature2float(