## API
We developed a capsule network toolkit based on PyTorch. Its API is listed below.
### Capsulation layer
`class Capsulation2D(nn.Module)`

- Input shape: `(batch_size, channels, height, width)`
- Output shape: `(batch_size, out_channels, out_dim_capsule, height, width)`
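The toolkit's implementation is not shown here, so the following is only a minimal sketch of how a layer with these input and output shapes could work. It assumes that capsulation is a shape-preserving convolution whose output channels are split into `out_channels` capsules of length `out_dim_capsule`, followed by the squash non-linearity commonly used in capsule networks. The class name `Capsulation2DSketch`, the helper `squash`, and the constructor arguments `in_channels` and `kernel_size` are hypothetical.

```python
import torch
import torch.nn as nn


def squash(s: torch.Tensor, dim: int = 2, eps: float = 1e-8) -> torch.Tensor:
    """Squash non-linearity: keeps each vector's direction, maps its length into [0, 1)."""
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / torch.sqrt(sq_norm + eps)


class Capsulation2DSketch(nn.Module):
    """Hypothetical stand-in for Capsulation2D (not the toolkit's code).

    (B, C, H, W) -> (B, out_channels, out_dim_capsule, H, W)
    """

    def __init__(self, in_channels: int, out_channels: int, out_dim_capsule: int,
                 kernel_size: int = 3):
        super().__init__()
        self.out_channels = out_channels
        self.out_dim_capsule = out_dim_capsule
        # padding = kernel_size // 2 keeps H and W unchanged for odd kernel sizes
        self.conv = nn.Conv2d(in_channels, out_channels * out_dim_capsule,
                              kernel_size, padding=kernel_size // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        y = self.conv(x)  # (B, out_channels * out_dim_capsule, H, W)
        # Split the channel axis into (capsule index, capsule component)
        y = y.view(b, self.out_channels, self.out_dim_capsule, h, w)
        return squash(y, dim=2)  # squash each capsule vector along its component axis


x = torch.randn(4, 3, 32, 32)
caps = Capsulation2DSketch(in_channels=3, out_channels=8, out_dim_capsule=16)
out = caps(x)
print(out.shape)  # torch.Size([4, 8, 16, 32, 32])
```

Squashing along axis 2 means every capsule's length stays below 1, so it can be read as the probability that the entity the capsule represents is present.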
### Decapsulation layer
`class DeCapsulation2D(nn.Module)`

- Input shape: `(batch_size, channels, dim_capsule, height, width)`
- Output shape: `(batch_size, out_channels, height, width)`
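As a hedged illustration of the reverse direction, the sketch below assumes decapsulation folds the capsule dimension back into the channel axis and then mixes the components with a convolution to produce ordinary feature maps. The class name `DeCapsulation2DSketch` and the arguments `in_channels`, `dim_capsule`, and `kernel_size` are hypothetical; the toolkit may instead use, for example, capsule lengths.

```python
import torch
import torch.nn as nn


class DeCapsulation2DSketch(nn.Module):
    """Hypothetical stand-in for DeCapsulation2D (not the toolkit's code).

    (B, C, D, H, W) -> (B, out_channels, H, W)
    """

    def __init__(self, in_channels: int, dim_capsule: int, out_channels: int,
                 kernel_size: int = 1):
        super().__init__()
        # padding = kernel_size // 2 keeps H and W unchanged for odd kernel sizes
        self.conv = nn.Conv2d(in_channels * dim_capsule, out_channels,
                              kernel_size, padding=kernel_size // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, d, h, w = x.shape
        # Fold the capsule dimension into the channel axis: (B, C * D, H, W)
        x = x.reshape(b, c * d, h, w)
        # Mix capsule components back into ordinary feature maps
        return self.conv(x)


x = torch.randn(4, 8, 16, 32, 32)
decaps = DeCapsulation2DSketch(in_channels=8, dim_capsule=16, out_channels=3)
out = decaps(x)
print(out.shape)  # torch.Size([4, 3, 32, 32])
```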
### Flatten layer
`class CapFlatten(nn.Module)`

- Input shape: `(batch_size, channels, dim_capsule, height, width)`
- Output shape: `(batch_size, channels * height * width, dim_capsule)`, i.e. `(batch_size, num_capsules, dim_capsule)` with `num_capsules = channels * height * width`
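Flattening needs no parameters, only a permute and a reshape. The sketch below assumes capsules are ordered channel-first, so capsule `(c, h, w)` lands at index `c * H * W + h * W + w`; the toolkit's actual ordering is not documented here, and `CapFlattenSketch` is a hypothetical name. The last lines show that, under the same assumption, a reshape followed by the inverse permute recovers the original tensor.

```python
import torch
import torch.nn as nn


class CapFlattenSketch(nn.Module):
    """Hypothetical stand-in for CapFlatten (not the toolkit's code).

    (B, C, D, H, W) -> (B, C * H * W, D), i.e. (B, num_capsules, dim_capsule)
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, d, h, w = x.shape
        # Move the capsule dimension last: (B, C, H, W, D)
        x = x.permute(0, 1, 3, 4, 2)
        # Merge channel and spatial axes into one capsule axis; reshape copies
        # when needed because permute leaves the tensor non-contiguous
        return x.reshape(b, c * h * w, d)


x = torch.randn(2, 8, 16, 6, 6)
flat = CapFlattenSketch()(x)
print(flat.shape)  # torch.Size([2, 288, 16])

# Inverse under the same ordering: (B, C*H*W, D) -> (B, C, H, W, D) -> (B, C, D, H, W)
restored = flat.reshape(2, 8, 6, 6, 16).permute(0, 1, 4, 2, 3)
```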
### Unflatten layer