在 PyTorch 中,argparse
模块通常用于解析命令行参数,使得脚本更加灵活和可配置。通过 argparse
,你可以为脚本添加参数,使得在运行时可以轻松地调整超参数、文件路径等配置项,而无需修改代码本身。
1. 安装和导入 argparse
argparse
是 Python 标准库的一部分,因此无需额外安装。你只需在脚本中导入它即可:
import argparse
2. 创建解析器
你需要创建一个 ArgumentParser
对象,它用于存储将要解析的所有信息:
parser = argparse.ArgumentParser(description="PyTorch Training Script")
3. 添加参数
使用 add_argument
方法为解析器添加参数。你可以指定参数名称、类型、默认值和帮助信息:
parser.add_argument('--batch-size', type=int, default=64, help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train (default: 10)')
parser.add_argument('--learning-rate', type=float, default=0.01, help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
4. 解析参数
在所有参数添加完毕后,调用 parse_args
方法来解析命令行参数:
args = parser.parse_args()
5. 使用参数
解析后的参数可以通过 args
对象来访问,并在脚本中使用。例如:
import torch
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if __name__ == "__main__":
# 设置随机种子
set_seed(args.seed)
# 检查是否使用 CUDA
use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# 打印配置信息
print(f"Using device: {device}")
print(f"Batch size: {args.batch_size}")
print(f"Learning rate: {args.learning_rate}")
# 示例:创建数据加载器
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True
)
# 示例:创建模型和优化器
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum)
# 示例:训练循环
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
总结
通过使用 argparse
,你可以使得 PyTorch 脚本更加灵活和可配置,便于调整训练参数、超参数和其他配置项。这样不仅提高了代码的可读性和可维护性,还方便了实验的管理和复现。