target_onehot=torch.zeros(target.shape[0],10)如何理解?
首先来康两行代码:
target_onehot=torch.zeros(target.shape[0],10)
target_onehot.scatter_(1,target.unsqueeze(1),1.0),target_onehot.shape
这两行代码实现的功能是离散量one-hot编码
第一行:
参数分别是target.shape[0],10。下文知道target.shape[0]=4898。所以target_onehot的大小是4898X10.
target=wineq[:,-1]#选择所有行的最后一列
target,target.shape
#target的信息
(tensor(