torch 中,张量有一个方法是 size,通过查看声明可以看出 size 的参数和返回值:
@overload
def size(self) -> Size: ...
@overload
def size(self, dim: _int) -> _int: ...
通过构造一个张量,并对该张量使用 size 方法查看返回的结果:
torch.ones(1, 2, 3, 4).size() # torch.Size([1, 2, 3, 4])
torch.ones(1, 2, 3, 4).size(-4) # 1
torch.ones(1, 2, 3, 4).size(-1) # 4
torch.ones(1, 2, 3, 4).size(1) # 2
torch.ones(1, 2, 3, 4).size(0) # 1
可以看出 size 在无参数的情况下,输出的是一个 Size 对象,内容为对应的张量的维度。当加入参数 dim 的时候,会直接输出一个 int 表示第 dim 个维度的值。
我的理解是,将 Size 对象视作一个元组,则 dim 与 Size[dim] 的效果是一样的:
@overload
def size(self) -> Size:
return self.dimension # 假设 self.dimension 是以 Size 对象形式存放的维度信息
@overload
def size(self, dim: _int) -> _int:
return self.dimension[dim] # 则 dim 的效果等同于直接返回 Size 对象的第 dim 个元素