在 PyTorch 中,size()
是用于获取张量的大小(维度)的方法。具体用法如下
tensor.size()
返回一个包含张量各个维度大小的元组。例如
import torch
# 创建一个二维张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 获取张量的大小
size = tensor.size()
print(size)
输出:
torch.Size([2, 3])
在这个例子中,size()
返回一个包含两个元素的元组,表示该张量有两个维度,第一个维度大小是 2,第二个维度大小是 3。