NumPy array type
# -*- coding: utf-8 -*-
import torch
import numpy as np
a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])
print(len(a))
print(a.size)
print(a.shape)
Output:
2
8
(2, 4)
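For a NumPy array, shape and size are both attributes of the array: shape is a tuple of the dimension lengths, size is the total number of elements, and len(a) returns the length of the first dimension.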
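As a minimal sketch of how these three values relate, the snippet below reuses the same array and derives the element count from the shape. The variable names (rows, cols, total, first_axis) are illustrative only and do not come from the original code.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [1, 2, 3, 4]])

# shape and size are attributes, so they are read without parentheses.
rows, cols = a.shape      # unpack the two dimension lengths
total = a.size            # total number of elements in the array
first_axis = len(a)       # length of the first dimension only

# The element count equals the product of the dimension lengths.
assert total == rows * cols
assert first_axis == a.shape[0]
```

Because size is a plain integer attribute, writing a.size() on a NumPy array fails with a TypeError.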
PyTorch tensor type
tensor = torch.rand(3, 4)
print(f'shape of tensor: {tensor.shape}')
print(f'size of tensor: {tensor.size()}')
print(f'Datatype of tensor: {tensor.dtype}')
print(f'len of tensor: {len(tensor)}')
Output:
shape of tensor: torch.Size([3, 4])
size of tensor: torch.Size([3, 4])
Datatype of tensor: torch.float32
len of tensor: 3
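For a tensor, shape is an attribute, whereas size() is a method; both return the same torch.Size value.
Note that, unlike NumPy's size, tensor.size() returns the dimensions rather than the total number of elements, and len(tensor) returns the length of the first dimension.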
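The following sketch shows a few related calls that may be useful alongside shape and size(): passing a dimension index to size(), and using numel() to get the total element count (the closest counterpart to NumPy's size attribute). The variable names are illustrative and not part of the original code.

```python
import torch

tensor = torch.rand(3, 4)

# The shape attribute and the size() method report the same torch.Size value.
same_shape = tensor.size() == tensor.shape

# size() also accepts a dimension index and then returns a plain int.
first_dim = tensor.size(0)    # length of dimension 0
last_dim = tensor.size(-1)    # negative indices count from the end

# numel() returns the total number of elements in the tensor.
total = tensor.numel()

print(same_shape, first_dim, last_dim, total)
```

Since torch.Size is a subclass of tuple, the result of tensor.shape can be indexed or unpacked like an ordinary tuple, for example rows, cols = tensor.shape.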