Source code: https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m2_base.py
Parameter settings:
enc_depths=(2, 2, 6, 2)
enc_channels=(96, 192, 384, 512)
enc_groups=(12, 24, 48, 64)
enc_neighbours=(16, 16, 16, 16)
Definition of Block:
class Block(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
A main function that prints the parameter count of each encoder stage; only a single Block per stage is measured, and the transition layers are excluded:
if __name__ == '__main__':
    enc_channels = (96, 192, 384, 512)
    enc_groups = (12, 24, 48, 64)
    for i in range(4):
        model = Block(enc_channels[i], enc_groups[i])
        pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # note: dividing by 1024*1024, so "M" here means 2**20 parameters, not 10**6
        print(str(i) + ": " + str(pytorch_total_params / 1024.0 / 1024.0) + "M")
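The counting idiom used above, `sum(p.numel() for p in model.parameters())`, simply multiplies out each parameter tensor's shape and sums the results. A minimal sketch of the same idea without PyTorch, assuming a model is represented just as a list of tensor shapes (the `Linear(96, 96)` example below is hypothetical, not taken from the repo):

```python
from math import prod

def count_params(shapes):
    """Total element count over a list of tensor shapes,
    mirroring sum(p.numel() for p in model.parameters())."""
    return sum(prod(s) for s in shapes)

# Hypothetical example: a single Linear(96, 96) layer with bias
# has a 96x96 weight matrix plus a 96-element bias vector.
linear_96 = [(96, 96), (96,)]
print(count_params(linear_96))  # 96*96 + 96 = 9312
```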
Results:
0: 0.054828643798828125M
1: 0.21540069580078125M
2: 0.8537750244140625M
3: 1.51434326171875M
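Multiplying the printed values back by 1024*1024 recovers the raw counts, which grow roughly quadratically with embed_channels, as expected if C x C linear projections dominate each Block (the quadratic-dominance reading is an inference from these numbers, not stated by the repo):

```python
# Printed values use 1024*1024 per "M", so convert back to raw counts.
reported_m = {96: 0.054828643798828125,
              192: 0.21540069580078125,
              384: 0.8537750244140625,
              512: 1.51434326171875}

raw = {c: round(m * 1024 * 1024) for c, m in reported_m.items()}
print(raw)  # {96: 57492, 192: 225864, 384: 895248, 512: 1587904}

# Compare the growth ratio of parameter counts against the
# squared channel ratio: they track closely (~4x per doubling).
channels = sorted(raw)
for c1, c2 in zip(channels, channels[1:]):
    print(c1, "->", c2, raw[c2] / raw[c1], (c2 / c1) ** 2)
```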