小土堆:Pytorch深度学习:神经网络-线性层及其他层介绍

今天学习了神经网络-线性层及其他层。这段代码的主要目的是为了实现一个简单的神经网络模型,并在 CIFAR-10 数据集上进行预测。

import torchimport torchvisionfrom torch import nnfrom torch.nn import Linearfrom torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),                                       download=True)
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)
class Yang(nn.Module):    def __init__(self):        super(Yang, self).__init__()        self.linear1 = Linear(196608, 10)
    def forward(self, input):        output = self.linear1(input)        return output
yang = Yang()
for data in dataloader:    imgs, targets = data    print(imgs.shape)    output = torch.flatten(imgs)    print(output.shape)    output = yang(output)    print(output.shape)

1. `import torch`:导入 PyTorch 库,这是一个用于机器学习和深度学习的开源库。

 2. `import torchvision`:导入 torchvision 库,这个库包含了许多常用的数据集和预训练模型。 

3. `from torch import nn`:从 PyTorch 库中导入神经网络模块。

 4. `from torch.nn import Linear`:从神经网络模块中导入线性层,这是神经网络中常用的一种层。 

5. `from torch.utils.data import DataLoader`:从 PyTorch 库中导入 DataLoader,这是一个可以使数据加载更加容易和快速的工具。

6. `dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)`:这行代码从 torchvision 数据集中加载 CIFAR-10 数据集。这个数据集包含了 60000 张 32x32 的彩色图像,共有 10 个类别。`train=False` 表示加载的是测试集,`transform=torchvision.transforms.ToTensor()` 表示将加载的图像转换为 PyTorch Tensor,`download=True` 表示如果数据集未在指定路径下找到,就下载数据集。

 7. `dataloader = DataLoader(dataset, batch_size=64)`:创建一个 DataLoader 对象,这个对象会生成一个迭代器,每次从数据集中提取一批数据(批大小为64)。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值