在用pytorch写Dataset类时,以下写法有问题:
from torch.utils.data import Dataset
class H5Dataset(Dataset):
def __len__(self):
return self.raw.shape[0] * 10
def __getitem__(self, index):
index = index % self.raw.shape[0]
return self.raw[index]
以上代码省略了初始化部分,假设batch_size=4,self.raw.shape[0]=155,即训练样本有155个,上述代码是想在一个epoch里对155个训练样本遍历10次,当一个batch中随机产生的index都是155的整数倍时(如0,155,310,465),index % 155都为0,此时,一个batch中的4个样本都一样,BN会出现问题,loss出现NAN。
所以,最好这样写:
from torch.utils.data import Dataset
class H5Dataset(Dataset):
def __len__(self):
return self.raw.shape[0]
def __getitem__(self, index):
return self.raw[index]
再把epoch的个数设置得多点就可以了。