Pytorch学习——Learn the Basic

QuickStart

working with data

介绍两个pytorch的数据加载和管理工具:
torch.utils.data.DataLoader and torch.utils.data.Dataset

Dataset负责存储数据样本以及数据标签,而DataLoader负责将Dataset的数据包装成一个可迭代访问的数据.

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, 	Lambda, Compose
import matplotlib.pyplot as plt

Pytorch专门提供了一些特定领域数据集的工具,包括:TorchText, TorchVision, and TorchAudio。

因此,在这次pytorch辅导中,我们将使用TorchVision的FashionMNIST数据集。

//PS:respectively, 分别地,各自地//

每个TorchVision数据集包含两个参数:transform和target_transform,分别用于修改样本和标签。
在这里插入图片描述
一般之后就会进行下载,数据量比较大,有几个GB。

现在有了training_data和testing_data两个Dataset类变量,将Dataset作为参数传递给DataLoader

 batch_size = 64
 # Create data loaders.
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in test_dataloader:
    print("Shape of X [N, C, H, W]: ", X.shape)
	print("Shape of y: ", y.shape, y.dtype)
break

在这里插入图片描述
可以看出,DataLoader不能进可以将数据集包装成一个可迭代实例(类似range实例),也可以自动完成批处理(batching),采样,洗牌和多进程处理。

//PS:i.e. 表示:即//

有一个疑惑,这个batch到底指什么。
我依然不太明白这个batch是什么

Create Models

貌似pytrotch中是是使用继承nn.Module的方法来实现神经网络。

我们在_init__函数中定义网络的层,并在forward函数中指定数据将如何通过网络。为了加速神经网络中的操作,我们将其移动到GPU(当然,我的商务本没有GPU,所以就是CPU)。

代码:
在这里插入图片描述
结果:
在这里插入图片描述
很明显,这个是一个CNN网络。
我有点疑惑,这个pytorch是采用继承nn.Module的办法来创建一个新的神经网络类的,也就是说torch.nn中包含了很多的值得我挖掘的东西。

Optimizing the Model Parameters

为了训练一个模型,我们需要一个loss function(损失函数),和一个optimizer(优化器)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

这一部分我有点疑惑。

定义完损失函数之后就是定义训练函数

def train(dataloader, model, loss_fn, optimizer):
	 size=len(dataloader.dataset)
	 # 所以这个model真的就这么简单?
	model.train()
	for batch,(X,y) in enumerate(dataloader):
    	X,y = X.to(device),y.to(device)
    
    	# 计算本轮迭代内部的预测误差
   		pred = model(X)
    	loss = loss_fn(pred,y)
    
    	# 反向传播
    	optimizer.zero_grad()
    	loss.backward()
    	optimizer.step()
    
    if batch % 100 == 0:
        loss, current = loss.item(), batch*len(X)
        print(f"loss:{loss:>7f} [{current:>5d}/{size:>5d}]")

这里也有很多奇奇怪怪的。
我们还根据测试数据集检查模型的性能,以确保它正在学习。
语句:
def test(dataloader, model, loss_fn):

dataloader就是数据集,然后model就是自己想要的模型,loss_fn就是损失函数

size = len(dataloader.dataset)
num_batches = len(dataloader)
# model.eval()这个应该是nn.Module里的东西
model.eval()
test_loss, correct = 0,0
with torch.no_grad():
    for X,y in dataloader:
        # 这个X.to(device)是什么
        X,y = X.to(device),y.to(device)
        # ??model(X)
        pred = model(X)
        # ?? 什么是loss_fn(pred,y).item()
        test_loss += loss_fn(pred,y).item()
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
# 这个到底是什么print语句??
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

模型的训练过程将会持续好几个迭代(epochs),在每一个迭代中,模型都会做出更好的预测,我希望在每一个迭代中看见精度提升和损失下降。
正式开始训练模型:

epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n----------------------------")
    train(train_dataloader,model,loss_fn,optimizer)
    test(test_dataloader,model,loss_fn)
print("Done!")

Saving Models

接下来就是保存你的模型

torch.save(model.state_dict(),"model.pth")
print("Saved PyTorch Model State to model.pth")

Loading Models

保存完模型,就可以加载模型进行现实预测了
整个加载模型的过程包括重新建造模型结构和加载模型状态字典

# re-create the model structure
model = NeuralNetWork()
# load the state dictionary
model.load_state_dict(torch.load("model.pth"))

接下来进行预测

model.eval()
# ??test_data[0][0]
x,y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted,actual = Classes[pred[0].argmax(0)],Classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')

结果:

Predicted: "Ankle boot", Actual: "Ankle boot"
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值