其中主要通过设置随机数种子实现神经网络训练可多次复现。
一、关键代码:
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available():
torch.cuda.manual_seed(SEED)
self.train_loader = DataLoader(self.train_dataset, cfg['data']['batch_size'],shuffle=False, num_workers=self.num_workers)
二、关键位置
在代码开头就加入所有SEED相关代码。
除此之外检查
1、所有random函数出现处
2、数据加载项
3、网络初始化处