Pytorch学习笔记——线性层和非线性层的使用

1. 前言

在深度学习中,线性层和非线性层是构建神经网络的基本单元。本文将通过PyTorch实现一个简单的网络,详细讲解线性层与非线性层的使用和区别。

2. 导入必要的库

首先,我们需要导入PyTorch以及一些常用的模块:

import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader

3. 加载数据集

使用torchvision加载CIFAR-10数据集,并将其转换为Tensor格式。

dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)
  • root="data1":数据存储路径。
  • train=False:加载测试集。
  • transform=torchvision.transforms.ToTensor():将图像数据转换为Tensor。
  • download=True:如果数据集不存在,则下载。
  • drop_last=True:如果最后一个batch大小小于batch_size,则丢弃。

4. 定义线性层网络结构

构建一个包含线性层的简单神经网络:

class NN(nn.Module):
    def __init__(self):
        super(NN, self).__init__()
        self.linear1 = nn.Linear(196608, 10)  # 定义一个线性层

    def forward(self, input):
        output = self.linear1(input)  # 前向传播
        return output
  • nn.Linear(196608, 10):定义一个线性层,输入维度为196608,输出维度为10。

5. 实例化网络并打印输出

使用DataLoader加载数据,遍历数据并打印输出结果。

mynn = NN()  # 实例化网络

for data in dataloader:
    imgs, targets = data
    print(imgs.shape)  # 打印图像的形状
    output = torch.flatten(imgs)  # 展平图像
    print(output.shape)  # 打印展平后的形状
    output = mynn(output)  # 输入到网络中
    print(output.shape)  # 打印输出的形状
    print("------------------")
  • torch.flatten(imgs):将图像展平为一维。
  • 将展平后的图像输入到网络中,得到输出。

输出结果:

torch.Size([64, 3, 32, 32])
torch.Size([196608])
torch.Size([10])
------------------

每次遍历数据加载器,我们可以看到原始图像的形状,展平后的形状,以及通过线性层后的输出形状。

6. 定义非线性层网络结构

为了演示非线性层,我们可以在网络中加入激活函数,例如ReLU(Rectified Linear Unit):

class NNWithNonLinearity(nn.Module):
    def __init__(self):
        super(NNWithNonLinearity, self).__init__()
        self.linear1 = nn.Linear(196608, 10)
        self.relu = nn.ReLU()  # 定义ReLU激活函数

    def forward(self, input):
        output = self.linear1(input)
        output = self.relu(output)  # 应用激活函数
        return output
  • nn.ReLU():定义ReLU激活函数。
  • 将线性层的输出通过ReLU激活函数,增加非线性。

实例化非线性网络并打印输出:

mynn_nonlin = NNWithNonLinearity()

for data in dataloader:
    imgs, targets = data
    output = torch.flatten(imgs)
    output = mynn_nonlin(output)
    print(output)
    print("------------------")

7. 总结

线性层和非线性层是神经网络的基本构件。线性层执行线性变换,而非线性层(例如激活函数)引入非线性,从而使网络能够拟合复杂的函数。本文通过实例演示了如何在PyTorch中使用这些层,理解了它们的工作原理和应用。

通过这种方式,我们可以更好地理解和构建复杂的神经网络,提高模型的表现力和泛化能力。

  • 9
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Pytorch是机器学习中的一个重要框架,它与TensorFlow一起被认为是机器学习的两大框架。Pytorch学习可以从以下几个方面入手: 1. Pytorch基本语法:了解Pytorch的基本语法和操作,包括张量(Tensors)的创建、导入torch库、基本运算等\[2\]。 2. Pytorch中的autograd:了解autograd的概念和使用方法,它是Pytorch中用于自动计算梯度的工具,可以方便地进行反向传播\[2\]。 3. 使用Pytorch构建一个神经网络:学习使用torch.nn库构建神经网络的典型流程,包括定义网络结构、损失函数、反向传播和更新网络参数等\[2\]。 4. 使用Pytorch构建一个分类器:了解如何使用Pytorch构建一个分类器,包括任务和数据介绍、训练分类器的步骤以及在GPU上进行训练等\[2\]。 5. Pytorch的安装:可以通过pip命令安装Pytorch,具体命令为"pip install torch torchvision torchaudio",这样就可以在Python环境中使用Pytorch了\[3\]。 以上是一些关于Pytorch学习笔记,希望对你有帮助。如果你需要更详细的学习资料,可以参考引用\[1\]中提到的网上帖子,或者查阅Pytorch官方文档。 #### 引用[.reference_title] - *1* [pytorch自学笔记](https://blog.csdn.net/qq_41597915/article/details/123415393)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [Pytorch学习笔记](https://blog.csdn.net/pizm123/article/details/126748381)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值