PyTorch基础知识讲解(一)完整训练流程示例

Tutorial

大多数机器学习工作流程涉及处理数据、创建模型、优化模型参数和保存训练好的模型。本教程向你介绍一个用PyTorch实现的完整的ML工作流程,并提供链接来了解这些概念中的每一个。

我们将使用FashionMNIST数据集来训练一个神经网络,预测输入图像是否属于以下类别之一。T恤/上衣、长裤、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包或踝靴。

1. 数据处理

PyTorch有两个处理数据的方法:Torch.utils.data.DataLoaderTorch.utils.data.Dataset。Dataset存储了样本及其相应的标签,而DataLoader则围绕Dataset包装了一个可迭代的数据。

torchvision.datasets模块包含了许多真实世界的视觉数据的数据集对象,如CIFARCOCO。在本教程中,我们使用FashionMNIST数据集。每个TorchVision数据集都包括两个参数:transformtarget_transform,分别用来修改样本和标签。

import torch
from torchvision import datasets
from torchvision import transforms

train_data = datasets.FashionMNIST(root="D:/datasets/DL/", 
                                   train=True,
                                   download=True,
                                   transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root="D:/datasets/DL/", 
                                train=False,
                                download=True,
                                transform=transforms.ToTensor())

我们将数据集作为参数传递给DataLoader。这在我们的数据集上包裹了一个可迭代的数据集,并支持自动批处理、采样、洗牌和多进程数据加载。在这里,我们定义了一个64的批处理大小,即dataloader可迭代的每个元素将返回一个批次,包括64个元素的特征和标签。

from torch.utils.data import DataLoader

batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(train_data, batch_size=batch_size)

for X,y in train_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break
Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])
Shape of y:  torch.Size([64]) torch.int64

2. 网络模型定义

为了在PyTorch中定义一个神经网络,我们创建一个继承自nn.Module的类。我们在__init__函数中定义网络的层,并在forward函数中指定数据将如何通过网络。为了加速神经网络的操作,如果有GPU的话,我们把它移到GPU上。

输入是28*28, 输出包含10个类

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

from torch import nn
class MyRSNN(nn.Module):
    def __init__(self, n_in=28*28, n_out=10):
        super(MyRSNN, self).__init__()
        self.flat_layer = nn.Flatten()
        self.n_hidden = 64
        self.network = nn.Sequential(
            nn.Linear(n_in, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, n_out)
        )
    def forward(self, X):
        return self.network(self.flat_layer(X))
Using device cuda
model = MyRSNN().to(device)
model
MyRSNN(
  (flat_layer): Flatten(start_dim=1, end_dim=-1)
  (network): Sequential(
    (0): Linear(in_features=784, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=10, bias=True)
  )
)

3. 损失函数、模型优化、模型训练、模型评价

损失函数就是一个函数,用来评价模型预测的结果和真实结果之间的差距,优化器提供了一个算法,通过不断的调节参数使得预测结果和真实结果的差距越来越小

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(model, dataloader, lf, opt):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        # compute loss
        y_pred = model(X)
        loss = lf(y_pred, y)
        # back propagation
        opt.zero_grad()
        loss.backward()
        opt.step()
        if batch % 200 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
def test(model, dataloader, lf):
    size = len(dataloader.dataset)
    batch_num = len(dataloader)
    model.eval()
    loss, acc = 0., 0.

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            y_pred = model(X)
            loss += lf(y_pred, y).item()
            acc += (y_pred.argmax(1) == y).type(torch.float).sum().item()
    avg_loss = loss/batch_num
    avg_acc = acc/size
    print(f"Test Error: \n Accuracy: {100*avg_acc:>0.2f}%, Avg loss: {avg_loss:>8f} \n")

# 训练多个 epoch
epochs = 10
for e in range(epochs):
    print(f"------- Epoch {e} --------")
    train(model, train_dataloader, loss_fn, optimizer)
    test(model, test_dataloader, loss_fn)
print("Done!")
------- Epoch 0 --------
loss: 2.302572  [    0/60000]
loss: 2.278226  [12800/60000]
loss: 2.263920  [25600/60000]
loss: 2.279993  [38400/60000]
loss: 2.265715  [51200/60000]
Test Error: 
 Accuracy: 31.01%, Avg loss: 2.240282 

------- Epoch 1 --------
loss: 2.253209  [    0/60000]
loss: 2.209677  [12800/60000]
loss: 2.195682  [25600/60000]
loss: 2.215378  [38400/60000]
loss: 2.191957  [51200/60000]
Test Error: 
 Accuracy: 37.75%, Avg loss: 2.135047 

------- Epoch 2 --------
loss: 2.170743  [    0/60000]
loss: 2.075798  [12800/60000]
loss: 2.044513  [25600/60000]
loss: 2.065165  [38400/60000]
loss: 2.013370  [51200/60000]
Test Error: 
 Accuracy: 41.41%, Avg loss: 1.910970 

------- Epoch 3 --------
loss: 1.980009  [    0/60000]
loss: 1.803920  [12800/60000]
loss: 1.756796  [25600/60000]
loss: 1.774534  [38400/60000]
loss: 1.720578  [51200/60000]
Test Error: 
 Accuracy: 46.69%, Avg loss: 1.598630 

------- Epoch 4 --------
loss: 1.691669  [    0/60000]
loss: 1.478196  [12800/60000]
loss: 1.456628  [25600/60000]
loss: 1.474672  [38400/60000]
loss: 1.443763  [51200/60000]
Test Error: 
 Accuracy: 60.15%, Avg loss: 1.341288 

------- Epoch 5 --------
loss: 1.435102  [    0/60000]
loss: 1.234250  [12800/60000]
loss: 1.234853  [25600/60000]
loss: 1.256342  [38400/60000]
loss: 1.251099  [51200/60000]
Test Error: 
 Accuracy: 63.44%, Avg loss: 1.166800 

------- Epoch 6 --------
loss: 1.251537  [    0/60000]
loss: 1.064981  [12800/60000]
loss: 1.084116  [25600/60000]
loss: 1.118475  [38400/60000]
loss: 1.128412  [51200/60000]
Test Error: 
 Accuracy: 65.01%, Avg loss: 1.051667 

------- Epoch 7 --------
loss: 1.125720  [    0/60000]
loss: 0.946272  [12800/60000]
loss: 0.984997  [25600/60000]
loss: 1.029155  [38400/60000]
loss: 1.044023  [51200/60000]
Test Error: 
 Accuracy: 66.17%, Avg loss: 0.971739 

------- Epoch 8 --------
loss: 1.034441  [    0/60000]
loss: 0.859551  [12800/60000]
loss: 0.916106  [25600/60000]
loss: 0.967752  [38400/60000]
loss: 0.982405  [51200/60000]
Test Error: 
 Accuracy: 67.27%, Avg loss: 0.913043 

------- Epoch 9 --------
loss: 0.964701  [    0/60000]
loss: 0.793666  [12800/60000]
loss: 0.865933  [25600/60000]
loss: 0.922777  [38400/60000]
loss: 0.935715  [51200/60000]
Test Error: 
 Accuracy: 68.13%, Avg loss: 0.867866 

Done!

4. 模型保存、模型加载、模型推理

通常,我们会把模型参数保存下来,随后可以加载并用于推理预测

model_path = "C01_10.pth"
torch.save(model.state_dict(), model_path)
print("Model parameters saved!")
Model parameters saved!
model_2 = MyRSNN()
model_2.load_state_dict(torch.load("C01_10.pth"))
<All keys matched successfully>
def predict(sample_idx=0):
    classes = [
        "T-shirt/top",
        "Trouser",
        "Pullover",
        "Dress",
        "Coat",
        "Sandal",
        "Shirt",
        "Sneaker",
        "Bag",
        "Ankle boot",
    ]
    model_2.eval()
    X, y = test_data[sample_idx][0], test_data[sample_idx][1]
    print("Ground Truth: ", classes[y])
    with torch.no_grad():
        y_pred = model_2(X).flatten()
    print("Model prediction: ", classes[y_pred.argmax()])
predict(0)
Ground Truth:  Ankle boot
Model prediction:  Ankle boot
predict(190)
Ground Truth:  Trouser
Model prediction:  Dress

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
龙良曲是一位知名的人工智能领域专家,他在PyTorch课程资源方面有着丰富的经验和贡献。 首先,龙良曲所提供的PyTorch课程资源非常丰富。他通过多年的学术研究和实践经验,深入理解了PyTorch框架的原理和应用,在课程资源的准备上投入了大量的心血。无论是针对初学者的入门教程,还是针对专业人士的深入讲解,都能够帮助学习者系统地掌握PyTorch的使用方法和技巧。 其次,龙良曲的PyTorch课程资源涵盖了各个层次和方向的内容。无论您是想要学习PyTorch基础知识,还是想要专门研究深度学习领域的前沿技术,都能够在龙良曲的课程资源中找到合适的内容。他提供的课程涵盖了PyTorch的基本概念、张量操作、网络模型的构建、迁移学习、生成对抗网络等各个方面,让学习者能够全面地了解和应用PyTorch。 此外,龙良曲的PyTorch课程资源还充分结合了实际应用场景和案例分析。他通过丰富的示例代码和实验,帮助学习者将理论知识应用到实际问题中去。这种结合实践的方式,能够加深学习者对PyTorch的理解和运用能力,提高其在工作和实验中的效率和准确性。 总的来说,龙良曲的PyTorch课程资源是非常有价值和实用性的。无论您是初学者还是专业人士,都可以通过他的课程资源快速、系统地学习和应用PyTorch,进一步拓展人工智能领域的知识和技能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值