pytorch 获取 tensor 的方法有两种:shape 和 size()
tensor 是类 Tensor() 的实例, 其中shape是其属性,而 size() 是其继承的方法,两者均可以获得 tensor 的维度。
import torch
X=torch.tensor([[1,2,3,4],[2,3,4,5],[4,6,7,8]])
print(X.shape)#torch.Size([3, 4])
print(X.size()) #torch.Size([3, 4])
print(X.shape[0]) #3 行数
print(X.shape[1]) #4 列数
print(X.size(0)) #3 行数
print(X.size(1)) #4 列数
shape是属性,使用中括号[]
size是函数,使用()