Training Models on an AMD GPU

I. Background

Training a model on an NVIDIA GPU is easy these days, and plenty of open-source projects provide acceleration. My card, however, is an AMD Radeon RX 5700 with 8 GB of VRAM, which has no CUDA support, so I searched for a way to train on an AMD GPU instead.

II. DirectML

DirectML is a low-level machine-learning inference API from Microsoft, built on DirectX 12 and styled after the DirectX 12 interfaces. It supports any DirectX 12-compatible GPU, including Intel, AMD, and NVIDIA hardware.
Link: DirectML
With DirectML I can pair my GPU with PyTorch: computation runs on the GPU and model parameters live in VRAM. TensorFlow is supported as well.

III. Usage

1. Environment

Any Python version from 3.8 to 3.10 works.

pip install torch-directml

My environment:

	python                        3.9.17
	torch                         2.0.1
	torch-directml                0.2.0.dev230426
	torchdata                     0.6.1
	torchtext                     0.15.2
	torchvision                   0.15.1

2. Demo

Getting the device

import torch_directml
dml = torch_directml.device()
device = dml

ResNet-101 test

# resnet101
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch_directml
# import apex.amp as amp

dml = torch_directml.device()
# Set device to GPU
device = dml
# Define batch size
batch_size = 64
# Enable mixed precision training if desired
use_amp = False
# Define transforms for the data
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Define ResNet101 model
class ResNet101(nn.Module):
    def __init__(self):
        super(ResNet101, self).__init__()
        self.resnet101 = torchvision.models.resnet101(weights=None)  # 'pretrained' is deprecated since torchvision 0.13
        self.fc = nn.Linear(1000, 10)

    def forward(self, x):
        x = self.resnet101(x)
        x = self.fc(x)
        return x

# Initialize model and move it to the GPU
model = ResNet101().to(device)

# Define optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Define loss function
criterion = nn.CrossEntropyLoss()

# if use_amp:
#     model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Train the model
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        if use_amp:
            # The apex.amp import above is commented out; fall back to a
            # plain backward pass so the loop still trains with use_amp=True
            # with amp.scale_loss(loss, optimizer) as scaled_loss:
            #     scaled_loss.backward()
            loss.backward()
        else:
            loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

# Test the model
model.eval()  # put BatchNorm layers into eval mode before evaluating
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

If you run out of VRAM, reduce batch_size further.
Training output:
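If you want to automate that, one common pattern (a sketch of my own, not part of torch-directml) is to catch the out-of-memory RuntimeError and retry with a halved batch size; `halve_batch_size` and `run_training` below are hypothetical helpers:

```python
def halve_batch_size(batch_size, minimum=1):
    """Pick the next batch size to try after an out-of-memory error."""
    if batch_size <= minimum:
        raise RuntimeError("batch size already at minimum, cannot shrink further")
    return max(batch_size // 2, minimum)

# Hypothetical use around the training loop above:
#
# while True:
#     try:
#         train_loader = torch.utils.data.DataLoader(
#             trainset, batch_size=batch_size, shuffle=True, num_workers=2)
#         run_training(model, train_loader)   # the epoch loop from the demo
#         break
#     except RuntimeError as e:
#         if "memory" not in str(e).lower():
#             raise
#         batch_size = halve_batch_size(batch_size)
```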

Files already downloaded and verified
Files already downloaded and verified
[1,   100] loss: 2.358
[1,   200] loss: 2.233
[1,   300] loss: 2.153
[1,   400] loss: 2.086
[1,   500] loss: 2.074
[1,   600] loss: 2.035
[1,   700] loss: 1.991