Windows 平台下AMD 显卡加速pytorch训练

Windows 11已经支持使用directml加速 pytorch了。

2021,11,16更新: directml-pytorch已经推出:

pip install pytorch-directml

详细教程:(4条消息) Windows下用amd显卡训练 : Pytorch-directml 重大升级,改为pytorch插件形式,兼容更好_znsoft的博客-CSDN博客_amd显卡 pytorch

-----------------------------------------------------------------------------

以 下为旧内容,依然适用,但是不推荐了。看横线以上的。

官方训练原理解释: ONNX Runtime Training Technical Deep Dive - Microsoft Tech Community

检查 支持的设备

import onnxruntime as ort
ort.get_device()

ONNX运行时(ORT)能够通过优化的后端训练现有的PyTorch模型。为此,我们为pythorch引入了一个pythorch API,称为ORTTrainer,可用于将pythorch模型的训练后端(实例torch.nn.Module)切换到orttrainer。这需要对trainer代码进行一些更改,比如替换PyTorch优化器,还可以选择设置标志来启用其他特性,比如mixed-precisiontraining。下面是一个将ONNX运行时培训集成到PyTorchpre-training脚本中的示例代码片段:

注:目前的API是实验性的,预计在不久的将来会有重大变化。我们的目标是改进接口,以提供与Pythorch训练的无缝集成,这需要对用户的训练代码进行最小的更改。



import torch
...
import onnxruntime
from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer

# Model definition
class Net(torch.nn.Module):
  def __init__(self, D_in, H, D_out):
    ...
  def forward(self, x):
    ...

model = Net(D_in, H, H_out)
criterion = torch.nn.Functional.cross_entropy
description = ModelDescription(...)
optimizer = 'SGDOptimizer'
trainer = ORTTrainer(model, criterion, description, optimizer, ...)

# Training Loop
for t in range(1000):
  # forward + backward + weight update
  loss, y_pred = trainer.train_step(x, y, learning_rate)
  ...

  • 6
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值