传入参数解析的工具
更规范地写代码,提高代码的复用性。举例说明:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
args = parser.parse_args()
# set random seeds
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
- Python标准库:通过
random.seed
确保使用标准库生成的随机数一致。 - NumPy:通过
np.random.seed
确保使用NumPy生成的随机数一致。 - PyTorch (CPU):通过
torch.manual_seed
确保使用PyTorch在CPU上生成的随机数一致。 - PyTorch (GPU):通过
torch.cuda.manual_seed_all
确保使用PyTorch在所有GPU上生成的随机数一致。
但seed一般不用传入,没有设置required=True,即不是必须要传入的参数,此处知识举个例子。