Pytorch 需要 int 类型的数据,
1.astype
import numpy as np
arr = np.array([1,2,3,4,5]).astype(int)
print(arr.dtype)
## int32
输出的int32 并不是 pytorch 需要的类型
2.np.int64()
import numpy as np
arr = np.int64(np.array([1,2,3,4,5]))
print(arr.dtype)
## int64
输出的 int64 是pytorch需要的类型