Ensuring reproducibility on a specific platform mainly involves two things: (1) controlling sources of randomness; (2) configuring PyTorch to avoid nondeterministic algorithms for certain operations (which may make those operations slower).
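For point (2), a minimal sketch of the relevant switches (assuming PyTorch 1.8+; `torch.use_deterministic_algorithms` raises an error at runtime if an op has no deterministic implementation):

```python
import torch

# Ask PyTorch to use deterministic algorithms where available; ops that
# have no deterministic implementation will raise a RuntimeError when called.
torch.use_deterministic_algorithms(True)

# cuDNN-specific flags: force deterministic convolution algorithms and
# disable the autotuner, which may pick different algorithms run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```

Note that these flags trade speed for determinism, so enable them only when exact run-to-run reproducibility matters.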
Controlling sources of randomness
Random number generator
```python
import random

import numpy as np
import torch

def set_random_seed(seed):
    # Random number generators in other libraries: if you use any other
    # libraries with their own RNGs, refer to their documentation for how
    # to set consistent seeds for them.
    np.random.seed(seed)
    # Python built-in random module
    random.seed(seed)
    # PyTorch random number generators
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
```
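As a quick sanity check, re-seeding the PyTorch RNG reproduces the same random draws (a minimal sketch using only `torch.manual_seed`):

```python
import torch

torch.manual_seed(42)
a = torch.rand(3)

torch.manual_seed(42)  # re-seed with the same value
b = torch.rand(3)

# The two draws are element-wise identical.
assert torch.equal(a, b)
```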
In some versions of CUDA, RNNs and LSTM networks may have non-deterministic behavior. See torch.nn.RNN() and torch.nn.LSTM() for details and workarounds.
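One workaround mentioned in those docs (assumption: CUDA 10.2 or later) is to restrict cuBLAS workspace sizes via an environment variable before launching the process:

```shell
# Force deterministic cuBLAS behavior; the docs also list :16:8
# as an accepted value (smaller workspace, potentially slower).
export CUBLAS_WORKSPACE_CONFIG=:4096:8
```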
DataLoader
When num_workers > 0, the following setup is needed: DataLoader reseeds its workers following the "Randomness in multi-process data loading" algorithm. Use worker_init_fn() and a generator to preserve reproducibility (make sure your dataloader loads samples in the same order on every call):
```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # Derive a per-worker seed from PyTorch's initial seed so that the
    # NumPy and Python RNGs in each worker process are seeded deterministically.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

loader = DataLoader(
    train_dataset,          # your dataset
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)
```
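To see that the generator pins down the shuffle order, a small self-contained sketch (hypothetical toy dataset; num_workers is 0 here for brevity, but the same generator argument controls the shuffle in the multi-worker case):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(10))

def epoch_order(seed):
    # Build a fresh DataLoader with a seeded generator and record
    # the order in which samples are yielded.
    g = torch.Generator()
    g.manual_seed(seed)
    loader = DataLoader(data, batch_size=2, shuffle=True, generator=g)
    return [batch[0].tolist() for batch in loader]

# Same seed -> identical sample order across runs.
assert epoch_order(0) == epoch_order(0)
```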