可复现的 PyTorch

文章讨论了如何在Pytorch中保证实验结果的可复现性,主要涉及设置随机种子(包括Python,Numpy,Pytorch以及DataLoader的worker_init_fn和generator),禁用cuDNN的benchmark模式并启用deterministic算法,以确保不同环境下结果的一致性。
摘要由CSDN通过智能技术生成

这篇文章来探讨一下,如何在 Pytorch 中保证实验结果的可复现性?需要指出的是,要达到这一目的,可能会降低模型的运行速度。


设置随机种子

保证结果可复现的第一步是设置随机种子。
计算机生成的随机数是伪随机数。它通过一个初始值来产生一个随机序列,如果初始值是不变的,那么运行多次产生的随机序列也就是相同的。这个初始值一般就称为种子。

Python, Numpy, Pytorch 的随机数生成算法采用的是 Mersenne Twister 算法。这个算法的输入只有一个初始化值也不需要其他的环境信息。因此无论在任何机器、任何系统上,只要 PyTorch 的版本一致(算法部分没有改变)并且设置了随机种子,那么调用随机过程所产生的随机数就是一致的。

要注意的是:np.random.seed 只影响 NumPy 的随机过程,torch.manual_seed 也只影响 PyTorch 的随机过程。由此,程序中所有依赖 Mersenne Twister 算法产生随机数的包,都需要手动设置随机种子,才能使整个程序的随机性是可复现的。——Seed Everything - 可复现的 PyTorch(一)

设置随机种子:

np.random.seed(0)
torch.manual_seed(0) # seed the RNG for all devices (both CPU and CUDA)

多进程的随机性

DataLoader 的参数 num_workers 大于零时,Pytorch 会开启多个子进程,同时进行数据加载。多进程加载时,如何保证结果的可复现性?——可以利用 DataLoader 的两个参数 worker_init_fngenerator
这里只放上代码,具体原理可参考专门介绍 Dataloader 的文章:Pytorch Dataloader 详解

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)


cuDNN benchmark 模式

torch.backends.cudnn.benchmark = True 时,Pytorch 会根据网络结构搜索最适合它的卷积实现算法,提升卷积神经网络的运行速度。

卷积前向传播的实现有许多种实现方式。benchmark 模式下,程序会在开始时花费一点额外时间,为整个网络的每个卷积层搜索最适合它的卷积实现算法,进而实现网络的加速。
适用场景:网络结构固定,输入形状不变。具体来说,输入的 batch size,宽和高,输入通道的个数;卷积层本身的参数,包括卷积核大小,stride,padding ,输出通道的个数,这些都要固定,才能实现较好的加速效果。否则可能得不偿失。——torch.backends.cudnn.benchmark ?!

但同时,当 torch.backends.cudnn.benchmark = True 时,我们无法控制程序具体采用了哪一种卷积算法。因此要保证可复现性,需要 Disable benchmark 模式,让卷积操作采用固定的算法(默认情况):

torch.backends.cudnn.benchmark = False # force cuDNN to deterministically select an convolution algorithm

但是,选定的卷积算法可能本身是 non deterministic。所以还需要规定 Pytorch 只能使用 deterministic 的卷积算法:

torch.backends.cudnn.deterministic = True 

不仅是卷积算法,我们要规定所有运算只能使用 deterministic 算法(但如果某种运算没有相应的 deterministic 算法实现,就会报错):

torch.use_deterministic_algorithms(True)

设置上面这行代码后,就不再需要设置 torch.backends.cudnn.deterministic 了。

可以在 官方文档 查看哪些运算实现了确定性算法、哪些运算只有非确定性算法。


参考:

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值