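1. None (null value): to check whether an object exists, use is None / is not None rather than != None. The is operator compares identity, whereas == and != can be overridden by the object's class.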
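A minimal sketch of why the identity check is safer. The class `AlwaysEqual` is hypothetical and exists only to illustrate the point: any class may override `__eq__`, and Python 3 derives `!=` from it by default, so a comparison against None can give a misleading answer.

```python
class AlwaysEqual:
    """Hypothetical class whose __eq__ claims equality with anything."""

    def __eq__(self, other):
        return True


obj = AlwaysEqual()

# != falls back to the inverted result of __eq__, so it reports that
# obj "equals" None even though obj is a real object.
ne_says_exists = (obj != None)  # noqa: E711

# "is not" compares identity and cannot be overridden.
is_says_exists = (obj is not None)

print(ne_says_exists, is_says_exists)
```

NumPy arrays and tensors are a common real-world case of the same issue, since their comparison operators work elementwise rather than returning a single bool.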
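2. Data conversion in a custom Dataset class: after subclassing Dataset and overriding __getitem__(self, index), converting the data to tensors inside that method raised an error.
Cause: the first call to __getitem__(self, index) had already converted the objects to tensors, so every later call tried to convert a tensor (already converted from a NumPy array) into a tensor again, which is bound to fail.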
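The failure mode can be reproduced with a small sketch. The original failing code is not shown in these notes, so this is only a guess at its shape: it assumes `__getitem__` wrote the converted tensor back onto `self`. `BrokenDataset` is a hypothetical name.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class BrokenDataset(Dataset):
    """Hypothetical reconstruction of the bug: conversion inside __getitem__."""

    def __init__(self, feature):
        self.feature = feature  # still a NumPy array at this point

    def __len__(self):
        return len(self.feature)

    def __getitem__(self, index):
        # Overwrites the attribute, so it is a Tensor from the second call on.
        self.feature = torch.from_numpy(self.feature)
        return self.feature[index]


ds = BrokenDataset(np.zeros((4, 3), dtype=np.float32))
first = ds[0]  # works: self.feature was an ndarray

try:
    ds[1]  # fails: torch.from_numpy() only accepts an ndarray
    second_call_error = None
except TypeError as exc:
    second_call_error = exc

print(type(second_call_error).__name__)
```

Converting once in `__init__`, or keeping the converted value in a local variable instead of assigning it back to `self`, avoids the repeated conversion.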
Corrected code:
import torch

class myDataset(torch.utils.data.Dataset):
    def __init__(self, feature, clean_labels, noise_labels=None):
        # Convert the NumPy arrays to tensors once, at construction time
        self.feature = torch.Tensor(feature).type(torch.float32)
        self.clean_labels = torch.from_numpy(clean_labels).type(torch.LongTensor)
        if noise_labels is not None:
            self.noise_labels = torch.from_numpy(noise_labels).type(torch.LongTensor)
        else:
            self.noise_labels = None

    def __len__(self):
        return len(self.clean_labels)

    def __getitem__(self, index):
        # Index only; nothing is converted here, so repeated calls are safe
        x = self.feature[index]
        clean_y = self.clean_labels[index]
        # If noisy labels exist, return them alongside the clean label
        if self.noise_labels is not None:
            noise_y = self.noise_labels[index]
            return x, [clean_y, noise_y]
        else:
            return x, clean_y
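
A short sketch of how samples shaped like `x, [clean_y, noise_y]` come out of a `DataLoader` with the default collate function. To keep the block self-contained, it uses a plain list of made-up samples as a stand-in for the dataset (any object with `__getitem__` and `__len__` works as a map-style dataset); the values and sizes are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader

# Stand-in samples with the same structure that __getitem__ returns when
# noisy labels are present: (x, [clean_y, noise_y]). Values are made up.
samples = [
    (torch.zeros(3), [torch.tensor(i % 2), torch.tensor((i + 1) % 2)])
    for i in range(4)
]

loader = DataLoader(samples, batch_size=2, shuffle=False)

# The default collate function batches each position of the nested
# structure separately, so the label pair can be unpacked directly.
x, (clean_y, noise_y) = next(iter(loader))
shapes = (tuple(x.shape), tuple(clean_y.shape), tuple(noise_y.shape))
print(shapes)
```

In a training loop this means the clean and noisy label batches arrive as two separate tensors, each with one entry per sample.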