pytorch 复现训练结果

最近工作需要,查了一些资料,如何复现训练结果,现做个记录,如有不完整,还请指教

1、设置随机种子

torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

2、关于cudnn

2.1
torch.backends.cudnn.enabled = False
#禁用cudnn,运算速度会明显下降

2.2
torch.backends.cudnn.benchmark = True
#大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法, 
#来达到优化运行效率的问题

#前提条件是输入的数据维度比较一致,如果网络的输入数据在每次 iteration 都变化的话,会导致 cnDNN 每
#次都会去寻找一遍最优配置,这样反而会降低运行效率

2.3
torch.backends.cudnn.deterministic = True
#可以将cudnn中具有随机性的操作固定

3、some issues about reproducing on pytorch below the url:

https://discuss.pytorch.org/t/how-to-get-deterministic-behavior/18177

https://discuss.pytorch.org/t/non-reproducible-result-with-gpu/1831

好的,我明白了。您想了解如何使用 PyTorch 复现一个深度学习模型,是吗? 首先,您需要确保已经安装了 PyTorch。然后,您需要编写模型的代码,并将其保存在一个 Python 文件中。接下来,您可以使用 PyTorch 提供的数据加载器来加载训练数据和测试数据。最后,您可以使用 PyTorch 提供的优化器和损失函数来训练模型。 以下是一个简单的示例,展示了如何使用 PyTorch 复现一个简单的线性回归模型: ```python import torch import torch.nn as nn import torch.optim as optim # 定义模型 class LinearRegression(nn.Module): def __init__(self): super(LinearRegression, self).__init__() self.linear = nn.Linear(1, 1) def forward(self, x): out = self.linear(x) return out # 准备数据 x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]]) y_train = torch.tensor([[2.0], [4.0], [6.0], [8.0]]) # 初始化模型、损失函数和优化器 model = LinearRegression() criterion = nn.MSELoss() optimizer = optim.SGD(model.parameters(), lr=0.01) # 训练模型 for epoch in range(1000): optimizer.zero_grad() outputs = model(x_train) loss = criterion(outputs, y_train) loss.backward() optimizer.step() # 测试模型 x_test = torch.tensor([[5.0], [6.0], [7.0], [8.0]]) y_test = model(x_test) print(y_test) ``` 这个示例中,我们定义了一个简单的线性回归模型,使用均方误差作为损失函数,使用随机梯度下降作为优化器。我们使用 PyTorch 提供的数据类型 `torch.tensor` 来准备数据,并使用 `model.parameters()` 来获取模型的参数。在训练过程中,我们使用 `optimizer.zero_grad()` 来清除梯度,使用 `loss.backward()` 来计算梯度,使用 `optimizer.step()` 来更新参数。在测试过程中,我们使用训练好的模型来预测新的数据。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值