数据类型numpy和tensor
当pytorch提出数据类型不对的时候,可以查看数据类型以及修改成正确的类型。废话少说:直接上代码
train_x,test_x,train_y,test_y=train_test_split(X,Y)
print(type(train_x))
print(train_x.shape)
train_x = torch.from_numpy(train_x).type(torch.float32)
train_y = torch.from_numpy(train_y).type(torch.int64)
test_x = torch.from_numpy(test_x).type(torch.float32)
test_y = torch.from_numpy(test_y).type(torch.LongTensor)
print(type(train_x))
print(train_x.shape)
这样numpy的数据类型和tensor的数据类型就一目了然了,并且两者也可以相互转换。
输出的结果如下:
<class 'numpy.ndarray'>
(112, 4)
<class 'torch.Tensor'>
torch.Size([112, 4])