Pytorch_入门概念

一、神经网络的组成部分

在第一部分中我们了解了 Pytorch 的相关基础知识,在这一篇文章中我将使用 Pytorch 进入深度学习的学习,学习如果使用 Pytorch 搭建神经网络中的一些基础代码,具体讲包含如下内容:

  • 神经网络的组成部分
    • 神经元
    • 神经网络层
  • 如何使用 Pytorch 完成数据加载工作以及相应的数据预处理
  • 训练神经网络模型并验证

神经网络是一种由多个神经元以一定的方式联结形成的网络结构,是一种仿照生物神经系统结构和功能的人工智能技术。神经网络通常由输入层输出层和若干个隐藏层组成,每个层包含若干个神经元。

神经网络的基本组成单位是神经元,它模拟了生物神经元的行为特征,包括输入信号的接收、加权求和、非线性激活等过程。神经元接收来自前一层神经元的输入信号,将输入信号进行加权求和,并通过激活函数将结果转换为输出信号,并将输出信号传递给下一层神经元。如图1所示为一个经典的以全连接(Full Connected, FC)方式形成的神经网络,每个圆圈代表一个神经元,圆圈间的连线代表神经元之间的联结。

Image Name

1. 神经元


神经生物学家Warren MeCulloch和数学家Walter Pitts于1943年提出了一种基于早期的神经元理论学说的人工神经网络模型,称为MP模型(McCulloch-Pitts模型)。该模型是一种具有生物神经元特征的人工神经网络模型,被认为是神经网络研究的开端。

MP模型的基本思想是将神经元视为一个二进制变量,其输出值只能为0或1。神经元接收来自其他神经元的输入信号,通过一个阈值函数对输入信号进行加权和处理,并产生一个二进制输出值。MP模型中的神经元只有两种状态,即兴奋态(输出值为1)和抑制态(输出值为0),通过神经元之间的连接,可以实现复杂的计算功能。MP模型的主要贡献在于将生物神经元的工作原理转化为数学模型,为后续的神经网络研究奠定了基础。虽然MP模型非常简单,但它的基本思想和理论对于神经网络的发展和应用具有重要意义,为人工智能和机器学习的发展奠定了基础。如图2所示为MP神经元模型。

Image Name

如上图所示, u 1 , . . . . . . u j , . . . . . . u n u_1, ......u_j, ......u_n u1,......uj,......un是一个 n n n维向量,代表与第 i i i个神经元相连接的其他神经元传递的信号; w 1 i , . . . . . , w j i , . . . . . . w n i w_{1i}, ....., w_{ji}, ......w_{ni} w1i,.....,wji,......wni分别代表其他神经元和第 i i i个神经元之间连接的权重值;代表第 i i i个神经元的阈值; x i x_i xi则称为第 i i i个神经元的输入,可表示为如式1所示; f ( x i ) f(x_i) f(xi)是非线性函数,如式2所示。

KaTeX parse error: \tag works only in display equations
KaTeX parse error: \tag works only in display equations

神经元通常由以下几个部分组成:

  • **输入(Inputs):**神经元接收来自其他神经元或外部环境的输入数据。
  • **权重(Weights):**每个输入都与一个权重相关联,用于调整输入的重要性。
  • **激活函数:**激活函数将加权输入映射到输出。常用的激活函数包括Sigmoid、ReLU和Tanh等。
  • **偏差(Bias):**偏差是一个可学习的参数,用于调整神经元输出的阈值。
import torch

class Neuron(torch.nn.Module):
    def __init__(self, input_size):
        super(Neuron, self).__init__()
        # 定义可学习的权重参数,形状为 (input_size,),与输入特征数量相对应
        self.weights = torch.nn.Parameter(torch.randn(input_size))
        # 定义可学习的偏置参数,初始化为随机值,标量
        self.bias = torch.nn.Parameter(torch.randn(1))

    def forward(self, inputs):
        # 计算加权和,点乘输入和权重,然后加上偏置
        weighted_sum = torch.sum(inputs * self.weights) + self.bias
        # 应用 sigmoid 激活函数,将结果压缩到 [0, 1] 范围内
        output = torch.sigmoid(weighted_sum)
        return output

# 创建一个具有3个输入的神经元
neuron = Neuron(3)

# 输入数据
inputs = torch.tensor([0.5, -0.3, 0.1])

# 计算输出
output = neuron(inputs)
print(output)
tensor([0.7111], grad_fn=<SigmoidBackward>)

这是一个简单的神经元模型,用 PyTorch 构建。让我解释一下这个模型的结构和功能:

  • Neuron 类:是一个继承自 torch.nn.Module 的自定义神经元模型。继承自 torch.nn.Module 的基类允许你定义具有可学习参数的自定义神经网络模型。
  • init 方法:这是模型的构造函数,它接受一个参数 input_size,表示输入特征的数量。在这个方法中,模型初始化了两个可学习的参数:weights 和 bias,这两个参数都被包装成 torch.nn.Parameter 对象,以便在模型的训练过程中进行优化。
    • weights 是一个形状为 (input_size,) 的可学习权重向量,它与输入特征进行点乘。
    • bias 是一个标量值,它用于调整模型的输出。
  • forward 方法:这是模型的前向传播方法。在前向传播过程中,输入 inputs 与权重 weights 进行点乘,然后将点乘结果与 bias 相加,得到加权和 weighted_sum。然后,通过 sigmoid 激活函数对加权和进行激活,将结果作为模型的输出返回。

这个神经元模型可以用于二分类问题,其中 input_size 表示输入特征的数量,模型通过学习适当的权重和偏置来进行二元分类。在训练过程中,你可以使用标准的 PyTorch 优化器和损失函数来训练这个模型,以便它能够适应你的分类任务。

2. 神经网络层


神经网络由多个神经元层组成。每一层都由许多神经元组成,并且通常具有相同的结构和激活函数。以下是一些常见的神经网络层类型:

  • 全连接层(Fully Connected Layer):每个神经元都与前一层的所有神经元相连接。
  • 卷积层(Convolutional Layer):应用卷积操作来提取输入数据中的空间特征。
  • 池化层(Pooling Layer):通过减少特征图的大小来降低计算量,并保留重要的特征。
  • 循环层(Recurrent Layer):通过在神经网络中引入时间维度来处理序列数据。

以下是一个包含两个全连接层的神经网络示例代码:

import torch

# 定义神经网络类,继承自 torch.nn.Module
class NeuralNetwork(torch.nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(NeuralNetwork, self).__init__()
        # 定义第一个全连接层,输入大小为 input_size,输出大小为 hidden_size
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        # 定义第二个全连接层,输入大小为 hidden_size,输出大小为 output_size
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, inputs):
        # 使用 ReLU 激活函数计算第一个全连接层的输出
        hidden = torch.relu(self.fc1(inputs))
        # 使用 sigmoid 激活函数计算第二个全连接层的输出,最终的模型输出
        output = torch.sigmoid(self.fc2(hidden))
        return output

# 创建一个具有2个输入、3个隐藏神经元和1个输出的神经网络
net = NeuralNetwork(2, 3, 1)

# 输入数据
inputs = torch.tensor([0.5, -0.3])

# 计算输出
output = net(inputs)
print(output)
tensor([0.4567], grad_fn=<SigmoidBackward>)

3. 损失函数


神经网络的目标是最小化预测输出与真实标签之间的差异。损失函数衡量了这种差异,并提供一个可优化的目标。常见的损失函数包括均方误差(Mean Squared Error)、交叉熵损失(Cross-Entropy Loss)等。

  • 均方误差(Mean Squared Error,MSE):计算预测值与目标值之间的平方差的平均值。
  • 交叉熵损失(Cross-Entropy Loss):在分类问题中,计算预测概率分布与真实标签之间的交叉熵。

以下是一个使用均方误差作为损失函数的示例:

import torch

# 随机生成一些示例数据
predictions = torch.tensor([0.9, 0.2, 0.1])
labels = torch.tensor([1.0, 0.0, 0.0])

# 计算均方误差损失
loss_function = torch.nn.MSELoss()
loss = loss_function(predictions, labels)
print(loss)
tensor(0.0200)

4. 优化器


优化器用于更新神经网络的参数以最小化损失函数。它使用梯度下降算法来调整参数的值。常见的优化器包括随机梯度下降(Stochastic Gradient Descent,SGD)、Adam等。

以下是一个使用Adam优化器进行参数更新的示例:

import torch

# 创建一个神经网络和损失函数
net = NeuralNetwork(2, 3, 1)
loss_function = torch.nn.MSELoss()

# 创建一个Adam优化器
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)

# 输入数据和真实标签
inputs = torch.tensor([0.5, -0.3])
labels = torch.tensor([1.0])

# 前向传播
output = net(inputs)
loss = loss_function(output, labels)

# 反向传播和参数更新
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(loss)
tensor(0.1231, grad_fn=<MseLossBackward>)

通过理解和应用这些神经网络的组成部分,您将能够构建和训练自己的深度学习模型。

二、PyTorch中的层

在PyTorch中,神经网络层(Layers)是神经网络的基本组成部分,用于对输入数据进行转换和提取特征。PyTorch提供了丰富的层类型和功能,使得构建和训练深度学习模型变得更加便捷和灵活。这里将介绍PyTorch中的一些常用层,并提供示例代码来帮助读者理解和学习。

目录

  1. 全连接层(Fully Connected Layer)
  2. 卷积层(Convolutional Layer)
  3. 池化层(Pooling Layer)
  4. 循环神经网络层(Recurrent Neural Network Layer)
  5. 转置卷积层(Transpose Convolutional Layer)
  6. 归一化层(Normalization Layer)
  7. 激活函数层(Activation Function Layer)
  8. 损失函数层(Loss Function Layer)
  9. 优化器层(Optimizer Layer)

1. 全连接层


全连接层,也被称为线性层或密集层,是最简单的神经网络层之一。它将输入的每个元素与权重相乘,并加上偏置项,然后通过激活函数进行非线性变换。全连接层的输出形状由其输入形状和输出维度确定。

下面是一个创建全连接层的示例代码:

import torch
import torch.nn as nn

# 定义输入和输出维度
input_size = 784
output_size = 10

# 创建全连接层
fc_layer = nn.Linear(input_size, output_size)

# 打印全连接层的权重和偏置项
print("权重:", fc_layer.weight)
print("偏置项:", fc_layer.bias)
权重: Parameter containing:
tensor([[-0.0292,  0.0176, -0.0219,  ..., -0.0264,  0.0119, -0.0114],
        [-0.0172, -0.0216, -0.0100,  ...,  0.0320, -0.0094, -0.0115],
        [-0.0158,  0.0301, -0.0223,  ..., -0.0194, -0.0009,  0.0238],
        ...,
        [ 0.0035,  0.0335,  0.0249,  ...,  0.0230,  0.0258, -0.0262],
        [-0.0345,  0.0155, -0.0178,  ...,  0.0357,  0.0137,  0.0007],
        [-0.0071,  0.0290,  0.0182,  ...,  0.0297,  0.0172, -0.0019]],
       requires_grad=True)
偏置项: Parameter containing:
tensor([ 0.0324,  0.0022, -0.0309, -0.0158, -0.0260, -0.0338, -0.0045, -0.0188,
        -0.0037,  0.0122], requires_grad=True)

2. 卷积层


卷积层是卷积神经网络中的核心层之一,用于从输入数据中提取空间特征。卷积层通过滑动窗口(卷积核)在输入上进行局部感知,并输出对应的特征图。PyTorch中的卷积层包括二维卷积层和三维卷积层,分别用于处理二维和三维数据。

下面是一个创建二维卷积层的示例代码:

import torch
import torch.nn as nn

# 定义输入通道数、输出通道数和卷积核大小
in_channels = 3
out_channels = 16
kernel_size = 3

# 创建二维卷积层
conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size)

# 打印二维卷积层的权重和偏置项
print("权重:", conv_layer.weight)
print("偏置项:", conv_layer.bias)
权重: Parameter containing:
tensor([[[[-3.6062e-02, -1.0831e-01, -1.8235e-01],
          [-1.1229e-01,  1.1231e-01,  1.2790e-01],
          [-7.6108e-02,  1.1414e-01,  1.4705e-01]],

         [[-4.1806e-02, -8.6680e-05,  5.5572e-02],
          [-1.7353e-01,  9.9678e-02,  8.4714e-02],
          [-1.4960e-01, -1.2091e-01, -1.5792e-01]],

         [[ 1.6896e-01,  1.6563e-01,  1.2287e-01],
          [-2.4582e-02,  2.3866e-02,  5.7180e-02],
          [-3.6860e-02,  4.8184e-02, -1.4439e-01]]],


        [[[ 9.8774e-02,  1.5412e-01,  1.4060e-01],
          [ 7.1751e-03, -4.9612e-02, -1.5085e-01],
          [-5.4055e-02, -9.9365e-02, -1.3711e-01]],

         [[ 1.7693e-01,  1.3813e-01, -5.0148e-02],
          [ 7.8184e-02,  1.2830e-01, -1.8312e-01],
          [-1.1992e-01, -1.9194e-01, -1.9078e-01]],

         [[-5.3462e-02, -1.2867e-01, -1.0925e-02],
          [ 3.5722e-03,  2.0914e-02,  8.7383e-03],
          [ 1.3804e-01,  1.2241e-01, -1.4674e-01]]],


        [[[-5.8650e-02, -1.8996e-01,  1.5064e-01],
          [-1.7766e-01, -4.7220e-02, -1.2675e-01],
          [-2.3848e-02, -1.0913e-01, -1.0859e-01]],

         [[-1.0538e-01, -4.3229e-02, -1.6024e-01],
          [-1.9046e-01,  6.5022e-02,  1.1860e-01],
          [ 6.7025e-02, -1.8756e-01, -1.8292e-01]],

         [[ 6.2805e-02,  1.1809e-02,  4.6006e-02],
          [ 8.0074e-02,  1.0206e-01, -1.0466e-01],
          [-1.2855e-01, -1.4031e-01, -8.3772e-02]]],


        [[[-1.2358e-01, -2.1925e-02, -8.4578e-02],
          [ 1.1147e-02, -6.3246e-02,  5.2268e-02],
          [ 4.6418e-02,  2.1144e-02,  1.7721e-01]],

         [[ 9.0155e-02, -1.2519e-01, -3.4674e-02],
          [ 1.7345e-02, -1.6761e-01,  1.8707e-01],
          [ 9.0477e-02, -9.1471e-04, -9.4137e-02]],

         [[ 1.5679e-01, -1.3293e-01,  1.7856e-01],
          [ 8.7667e-02, -1.0920e-01,  8.0001e-02],
          [-2.5766e-02,  9.9656e-02,  4.1608e-02]]],


        [[[-1.9037e-01, -1.3460e-01, -5.1148e-02],
          [ 1.0287e-01, -1.7130e-01, -4.0268e-02],
          [ 1.1129e-01,  5.1074e-02,  1.8437e-01]],

         [[-1.0889e-02,  2.7962e-02,  1.4198e-01],
          [-7.4635e-02,  1.3345e-01, -2.9000e-02],
          [ 4.9944e-03, -1.4798e-03, -2.9579e-02]],

         [[ 1.3687e-01, -9.0204e-02,  1.5016e-01],
          [ 1.8816e-02, -6.1756e-02,  5.5918e-02],
          [ 1.7162e-01,  2.7002e-02, -1.7337e-01]]],


        [[[ 1.8760e-01,  1.2394e-01, -1.7309e-01],
          [ 1.7666e-01,  1.7402e-01, -3.9097e-02],
          [ 1.7780e-01, -2.3493e-02,  1.1324e-01]],

         [[-1.9171e-01, -1.1542e-01, -9.1078e-02],
          [-3.2620e-03,  1.1829e-01, -1.8681e-01],
          [-1.9202e-01,  1.2009e-01,  1.6312e-01]],

         [[ 1.3745e-01,  8.4450e-03, -6.3764e-02],
          [-1.6661e-01, -1.6183e-01, -6.5783e-02],
          [-7.4341e-02,  1.4031e-01, -1.6009e-01]]],


        [[[-1.6315e-01,  1.1231e-01,  1.0717e-01],
          [-8.3239e-02, -1.8107e-01, -1.0144e-01],
          [-1.2889e-01, -1.6415e-01,  1.0282e-01]],

         [[-1.0547e-01,  7.3083e-02,  9.0126e-02],
          [ 9.1563e-02,  9.0577e-04,  6.3363e-03],
          [-6.0492e-02,  1.2098e-02,  1.7538e-01]],

         [[-3.6617e-02, -1.4061e-01,  5.4565e-02],
          [-7.8747e-02, -1.1793e-01,  2.9076e-02],
          [-1.5725e-01,  1.6877e-01, -1.3358e-01]]],


        [[[ 9.7572e-02, -3.5398e-02, -1.4387e-02],
          [ 4.2864e-02, -4.4962e-02, -9.2191e-02],
          [ 9.2448e-02, -1.2951e-01,  1.7036e-01]],

         [[-4.1832e-02, -3.5799e-02,  1.8343e-02],
          [ 1.0235e-01,  4.9717e-02,  1.5122e-01],
          [-8.0757e-02,  1.5591e-01, -1.4577e-01]],

         [[ 1.0169e-01,  1.2279e-02,  7.3877e-03],
          [ 3.2893e-02,  4.8011e-02, -1.5046e-01],
          [-3.1043e-02, -7.3972e-03,  1.0526e-01]]],


        [[[ 1.6907e-01, -3.4342e-02,  4.6242e-02],
          [ 3.6345e-02,  1.6763e-01,  8.2050e-02],
          [-9.2567e-02,  6.7075e-02, -7.4998e-02]],

         [[-1.9099e-01, -1.3073e-01,  2.1420e-02],
          [-1.9151e-01, -1.5382e-01,  4.5285e-02],
          [ 1.0206e-01, -3.8949e-03,  5.0588e-02]],

         [[ 2.6557e-03, -5.3072e-02, -6.2959e-02],
          [-1.3186e-01, -3.5519e-02,  1.0446e-01],
          [-1.7102e-01,  1.0622e-01,  1.8678e-01]]],


        [[[-2.2548e-03, -3.1101e-02,  6.7139e-02],
          [-1.8991e-01, -9.2259e-02,  5.4250e-02],
          [-6.7490e-02, -3.0403e-02,  1.3894e-01]],

         [[-2.9711e-02,  1.3089e-01,  1.4989e-01],
          [-1.4463e-01, -8.4889e-02,  8.4860e-02],
          [-6.6112e-03,  1.7246e-01,  3.1585e-02]],

         [[ 8.4949e-02, -1.2271e-01, -5.8491e-02],
          [-1.5170e-01,  8.9345e-02,  4.9451e-02],
          [ 1.2050e-01, -1.6071e-01,  2.3907e-02]]],


        [[[-1.2493e-02, -4.9467e-02, -9.0157e-02],
          [ 5.8240e-02, -1.3586e-01,  1.0429e-01],
          [-2.5696e-02, -6.6934e-02, -1.3256e-01]],

         [[-1.8781e-01, -9.4382e-02,  9.0596e-02],
          [ 1.6392e-01, -4.5904e-02, -1.2793e-01],
          [ 1.8197e-01,  1.6007e-01, -1.4013e-01]],

         [[-1.3479e-01,  1.7448e-01, -1.6815e-02],
          [-8.6067e-02,  1.5120e-01, -1.6255e-01],
          [-1.4240e-01,  1.8225e-01,  1.0323e-01]]],


        [[[ 1.8674e-01,  1.3040e-01,  1.0223e-01],
          [-1.4688e-01,  1.7013e-02,  7.8249e-02],
          [-8.5060e-03,  1.5221e-02, -1.3749e-01]],

         [[-2.3502e-02,  1.8105e-01,  1.1986e-01],
          [ 1.0078e-01,  1.2930e-01, -4.1422e-02],
          [-9.1035e-02, -1.5103e-01, -1.4178e-01]],

         [[ 1.5914e-03,  1.9122e-02,  8.1261e-02],
          [-2.5986e-02, -3.2954e-02, -1.2220e-01],
          [-1.4379e-01, -1.3325e-01, -3.8443e-02]]],


        [[[ 6.1646e-02,  3.9501e-03, -5.7518e-02],
          [ 1.2571e-01,  5.7255e-02,  1.6580e-01],
          [ 5.7418e-02, -1.4850e-01,  8.0357e-02]],

         [[-1.1739e-01,  6.6844e-02,  7.2799e-02],
          [ 1.9003e-01,  1.6125e-01,  2.1432e-02],
          [ 8.1119e-02,  9.2799e-03, -1.6096e-01]],

         [[ 1.1459e-01, -1.6446e-01,  9.0816e-02],
          [ 1.0114e-01, -1.7420e-01, -8.8830e-02],
          [-2.9408e-02, -1.0182e-01, -1.0072e-01]]],


        [[[ 1.6403e-01,  1.8487e-01, -2.2400e-02],
          [ 1.5823e-01, -5.3700e-02, -1.6983e-01],
          [ 7.7607e-02, -1.2672e-01,  2.8621e-02]],

         [[-1.5368e-01, -2.8543e-02, -1.6217e-01],
          [ 5.8100e-02,  7.3457e-02,  6.9796e-02],
          [ 6.5893e-02, -9.0955e-02,  1.1849e-01]],

         [[ 1.4581e-02, -1.0972e-01,  1.5005e-01],
          [-1.0861e-01, -4.4161e-02,  1.0241e-01],
          [-1.4663e-01,  1.6888e-01,  1.8902e-02]]],


        [[[-1.3665e-01,  1.8165e-01, -5.8072e-02],
          [ 9.3178e-02,  1.9131e-01, -3.2562e-02],
          [-1.1652e-01, -1.2072e-02, -1.2702e-01]],

         [[-7.8824e-03,  1.6049e-01,  1.3276e-01],
          [-2.4556e-02,  7.2487e-02, -1.1772e-02],
          [ 7.0178e-02,  1.6885e-01,  2.4307e-02]],

         [[-3.9709e-03, -1.5664e-01,  4.2569e-02],
          [ 1.1845e-01,  1.9123e-01, -1.5515e-01],
          [ 1.3468e-02,  1.2187e-01, -9.4735e-02]]],


        [[[-2.0543e-02, -6.6244e-02,  5.8826e-02],
          [ 1.5542e-01, -9.3407e-02,  3.8018e-02],
          [-4.2657e-02, -1.5728e-01, -1.8638e-01]],

         [[-1.0184e-01, -1.3268e-01, -5.1986e-02],
          [-1.0289e-01, -4.3372e-02,  1.2301e-02],
          [-7.8460e-02,  1.4338e-01,  1.2939e-01]],

         [[-3.9168e-02, -9.8682e-02,  3.0624e-02],
          [-1.1603e-01, -8.8590e-02,  1.9069e-01],
          [-2.4992e-02,  2.8455e-02,  1.0765e-01]]]], requires_grad=True)
偏置项: Parameter containing:
tensor([ 0.1031, -0.1123,  0.0991,  0.0320, -0.0007, -0.0248,  0.0734, -0.0895,
        -0.1511,  0.0188,  0.1813,  0.0153, -0.0721,  0.1908, -0.0800, -0.1570],
       requires_grad=True)

3. 池化层


池化层用于减小特征图的空间维度,降低模型的参数数量,并增强模型的平移不变性。最大池化和平均池化是常用的池化方式,它们分别选择局部区域中的最大值和平均值作为输出。

下面是一个创建最大池化层的示例代码:

import torch
import torch.nn as nn

# 定义池化区域大小和步幅
kernel_size = 2
stride = 2

# 创建最大池化层
pool_layer = nn.MaxPool2d(kernel_size, stride)

# 打印最大池化层的参数
print("池化区域大小:", pool_layer.kernel_size)
print("步幅:", pool_layer.stride)
池化区域大小: 2
步幅: 2

4. 循环神经网络层


循环神经网络(Recurrent Neural Network, RNN)层用于处理序列数据,具有记忆性和上下文感知能力。RNN层通过在时间步之间共享权重,实现对序列的逐步处理,并输出相应的隐藏状态。

下面是一个创建RNN层的示例代码:

import torch
import torch.nn as nn

# 定义输入特征维度、隐藏状态维度和层数
input_size = 10
hidden_size = 20
num_layers = 2

# 创建RNN层
rnn_layer = nn.RNN(input_size, hidden_size, num_layers)

# 打印RNN层的参数
print("输入特征维度:", rnn_layer.input_size)
print("隐藏状态维度:", rnn_layer.hidden_size)
print("层数:", rnn_layer.num_layers)
输入特征维度: 10
隐藏状态维度: 20
层数: 2

5. 转置卷积层


转置卷积层,也被称为反卷积层,用于实现上采样操作,将低维特征图转换为高维特征图。转置卷积层通过反向卷积操作将输入特征图映射到更大的输出特征图。

下面是一个创建转置卷积层的示例代码:

import torch
import torch.nn as nn

# 定义输入通道数、输出通道数和卷积核大小
in_channels  = 3
out_channels = 16
kernel_size  = 3

# 创建转置卷积层
transconv_layer = nn.ConvTranspose2d(in_channels, out_channels, kernel_size)

# 打印转置卷积层的权重和偏置项
print("权重:", transconv_layer.weight)
print("偏置项:", transconv_layer.bias)
权重: Parameter containing:
tensor([[[[ 8.0337e-02, -5.3290e-02, -1.4227e-02],
          [-1.9393e-04,  6.7705e-02,  2.3693e-02],
          [-5.0714e-02, -5.4791e-02,  7.8475e-02]],

         [[ 7.3623e-02,  4.6865e-02,  1.3385e-02],
          [ 6.5051e-02,  3.6389e-02,  5.5755e-02],
          [ 4.0090e-02, -3.1917e-02,  1.9548e-02]],

         [[-8.0792e-02, -2.7151e-02,  3.4276e-02],
          [-6.5219e-02, -3.2363e-03,  5.4203e-02],
          [-3.7077e-02,  9.2059e-03,  5.2119e-02]],

         [[-2.1103e-02, -2.7336e-02, -2.6543e-02],
          [-4.4347e-02, -5.8803e-02, -4.8358e-02],
          [ 3.4405e-04, -6.5506e-02,  3.0627e-02]],

         [[-8.1671e-02,  8.1619e-02,  1.8560e-02],
          [-1.9866e-02,  2.6681e-02,  5.6004e-03],
          [ 1.3659e-02,  3.6211e-02, -5.3699e-02]],

         [[-4.5724e-02, -7.5193e-02, -3.8209e-03],
          [ 6.6623e-02,  1.5908e-03, -3.2417e-02],
          [ 7.7808e-03, -1.4226e-02,  6.9678e-02]],

         [[ 6.7328e-03,  1.3855e-02, -7.9766e-02],
          [ 6.6506e-03, -5.7292e-02,  5.5111e-02],
          [-3.2646e-02,  2.9270e-02,  6.3490e-02]],

         [[ 6.5985e-02, -1.8153e-02,  6.3027e-02],
          [-3.6988e-02,  4.7413e-02, -7.4083e-02],
          [-7.3442e-02,  6.3578e-02,  5.4356e-02]],

         [[ 8.3311e-02, -5.5200e-03,  3.2794e-02],
          [-6.7711e-02, -2.3519e-02, -4.3690e-03],
          [-4.6577e-02,  1.9896e-02, -1.2318e-02]],

         [[ 5.1379e-02, -6.1027e-02,  1.6910e-02],
          [-5.2694e-02,  7.7022e-02, -2.4261e-02],
          [-3.7686e-02, -3.3537e-02,  3.7553e-02]],

         [[ 6.0146e-02, -7.9421e-02,  1.4621e-02],
          [ 8.9704e-03, -1.8245e-02, -1.0750e-02],
          [ 3.6439e-02, -3.8996e-02, -3.5425e-02]],

         [[ 7.2467e-02, -5.6918e-02, -5.1322e-03],
          [ 3.4404e-02,  3.5839e-02,  7.9543e-02],
          [-5.1992e-02, -3.8758e-02, -5.0793e-02]],

         [[-5.1583e-03, -2.6107e-02, -7.5185e-02],
          [ 5.3435e-02,  3.8434e-02, -3.4169e-02],
          [ 3.1811e-02,  6.2372e-02, -8.0381e-02]],

         [[-4.1993e-03,  8.2090e-02,  7.8353e-02],
          [ 5.4283e-02, -2.5811e-02, -1.2738e-02],
          [-4.0004e-03, -2.9511e-02,  7.0177e-02]],

         [[ 5.0226e-03, -5.3602e-02, -1.6094e-02],
          [-2.9816e-02, -7.8850e-02, -1.2259e-02],
          [ 6.9906e-02,  8.3293e-02,  3.4800e-02]],

         [[ 4.5200e-02,  2.5704e-02, -6.9073e-02],
          [ 9.3102e-03, -6.0777e-02,  7.9753e-02],
          [ 4.9606e-02, -2.8049e-02, -6.2391e-02]]],


        [[[-4.1973e-02,  5.8681e-02, -1.8394e-04],
          [ 2.5699e-03, -6.9590e-02, -4.7623e-02],
          [ 7.3279e-02,  5.3475e-02,  5.3339e-02]],

         [[-5.5722e-03,  8.2688e-02, -6.8631e-02],
          [-5.9470e-02,  3.2761e-02, -8.4429e-03],
          [ 8.2403e-02,  7.4375e-03, -4.2118e-02]],

         [[-3.7643e-02,  1.7817e-02,  5.1786e-02],
          [ 5.1679e-02,  6.8788e-02, -1.8910e-02],
          [ 3.9397e-02,  5.3203e-02,  7.3478e-02]],

         [[-5.7531e-02, -7.5814e-02, -3.0943e-02],
          [ 1.9954e-02,  4.9566e-02,  4.4010e-02],
          [-3.8031e-02, -3.4986e-02, -2.7309e-02]],

         [[-5.8555e-02, -1.4236e-02,  7.7856e-02],
          [ 4.3719e-02, -4.2244e-02, -1.5429e-02],
          [ 6.5436e-02,  6.0987e-02, -5.0919e-02]],

         [[-1.5866e-02,  7.6553e-02,  6.2690e-02],
          [-2.9115e-02,  2.9367e-02,  2.2441e-02],
          [ 5.6212e-04, -2.9069e-02,  6.0923e-02]],

         [[-4.4368e-02, -7.6175e-02, -5.5346e-02],
          [ 3.7252e-02,  4.2245e-02, -6.7618e-02],
          [ 6.7123e-02,  1.3725e-02, -5.2082e-02]],

         [[ 8.0160e-02, -4.8511e-02, -4.6862e-02],
          [ 8.2586e-02, -5.8021e-02,  5.9757e-02],
          [ 1.3346e-02, -1.9400e-02, -4.4925e-02]],

         [[ 1.1365e-02,  3.8480e-02, -3.1870e-02],
          [-6.3343e-02,  4.9030e-04, -5.2625e-02],
          [-4.8018e-02,  3.2482e-02, -2.8751e-02]],

         [[ 2.8370e-02,  4.2409e-02, -3.4847e-02],
          [-5.7933e-02, -3.4665e-03,  7.8222e-03],
          [ 5.8363e-03, -7.3786e-03, -3.4644e-02]],

         [[-8.2729e-02, -1.9254e-02,  2.4401e-02],
          [ 5.2219e-02,  6.7396e-02, -7.9750e-02],
          [ 8.2343e-02,  1.7320e-02,  3.9825e-02]],

         [[ 1.2627e-02,  6.3698e-02,  2.5438e-02],
          [-7.2730e-03,  7.3225e-02,  5.8660e-03],
          [ 7.7996e-02, -4.5768e-02, -2.3353e-02]],

         [[-2.8095e-02,  6.1991e-02, -1.3111e-02],
          [ 4.0495e-02, -3.7304e-02,  3.0700e-02],
          [ 2.0243e-02, -2.1378e-02, -8.1052e-02]],

         [[-9.0821e-03,  5.9058e-02, -2.0641e-02],
          [-3.8654e-02, -5.8057e-02, -1.7079e-02],
          [-5.6377e-02, -3.2291e-02,  9.0084e-03]],

         [[ 5.3056e-03,  3.0105e-02, -4.4222e-02],
          [ 3.3466e-03,  2.9610e-02,  1.3245e-02],
          [-7.6440e-02,  7.3664e-02,  2.5589e-02]],

         [[ 6.4978e-02,  6.0809e-02,  1.5091e-02],
          [-2.5361e-02, -7.1247e-02, -1.7257e-02],
          [-2.4302e-02,  5.5621e-02, -1.2518e-03]]],


        [[[ 3.5745e-03,  8.7403e-05,  1.0212e-02],
          [-6.5297e-03,  2.0753e-02, -1.4180e-02],
          [ 7.2512e-03,  2.6745e-02, -5.2314e-02]],

         [[ 3.0589e-02, -7.1941e-02, -4.8909e-02],
          [-5.8514e-02, -6.5524e-02, -1.3371e-02],
          [ 4.3143e-02,  1.2035e-02,  5.6237e-03]],

         [[ 6.5425e-02, -8.0787e-02,  5.0200e-02],
          [-2.7970e-02,  4.9760e-03, -2.1178e-02],
          [ 5.1334e-02,  4.0819e-02,  4.2965e-02]],

         [[-4.1454e-03, -4.7567e-02, -2.6652e-02],
          [ 4.4646e-02, -4.0077e-02, -5.8836e-02],
          [-1.9060e-02,  7.6014e-03, -2.9705e-02]],

         [[ 3.1283e-02,  6.6525e-02,  5.7250e-02],
          [ 2.7015e-02, -4.6071e-02, -7.8722e-02],
          [ 6.0060e-02,  6.4585e-02,  2.0736e-02]],

         [[ 7.6352e-02,  1.8075e-03,  8.6711e-03],
          [ 1.2982e-02, -7.1410e-02, -5.2776e-02],
          [-7.0903e-02,  7.8236e-03, -2.5906e-02]],

         [[ 2.4674e-02,  1.0393e-03,  5.3176e-03],
          [-2.6809e-02,  5.4121e-02, -8.2610e-02],
          [-5.1189e-02, -5.2251e-03,  1.1300e-02]],

         [[ 4.1785e-02, -2.9654e-02, -2.3086e-02],
          [ 2.3844e-02,  2.2076e-02,  1.2446e-03],
          [-4.2982e-02, -1.2210e-02, -7.3293e-02]],

         [[-2.8980e-02, -7.6738e-02,  2.2373e-02],
          [-9.5916e-03,  2.3340e-02,  7.7785e-02],
          [-6.2898e-02,  5.5991e-02,  4.2896e-02]],

         [[ 3.7710e-02, -7.2854e-02, -2.3186e-02],
          [ 6.1990e-02,  5.5058e-02,  2.2754e-03],
          [-7.0830e-02,  6.1900e-02, -7.2029e-02]],

         [[-2.5010e-02,  1.9187e-02, -1.7299e-02],
          [ 7.4672e-03,  4.5133e-04,  4.3270e-02],
          [ 3.0643e-02,  4.8289e-02,  4.5050e-04]],

         [[ 7.2089e-02, -4.4900e-02, -8.0196e-04],
          [ 5.1503e-02,  6.7008e-02,  3.6474e-02],
          [ 6.4141e-02, -9.5832e-04,  1.8785e-02]],

         [[ 1.7959e-02, -2.6235e-02, -5.3580e-02],
          [ 2.9562e-02,  6.5374e-02,  3.7110e-02],
          [-5.8930e-02, -7.7408e-03,  6.9015e-02]],

         [[ 3.9906e-02, -6.9693e-02,  6.6588e-02],
          [-4.7758e-04,  2.0158e-02, -1.6543e-02],
          [-1.6675e-02, -3.6007e-02, -2.8361e-02]],

         [[ 6.0041e-02,  4.7118e-02, -7.6530e-02],
          [-1.4904e-02, -7.1457e-02, -7.9006e-02],
          [ 1.5107e-02, -1.0651e-02, -6.9624e-02]],

         [[ 1.2723e-02, -4.7516e-02, -5.4287e-02],
          [-4.5850e-02, -3.2351e-02,  4.9753e-02],
          [-5.8311e-02, -4.2825e-02, -3.8202e-02]]]], requires_grad=True)
偏置项: Parameter containing:
tensor([-0.0486,  0.0357, -0.0204, -0.0689,  0.0183,  0.0416, -0.0358, -0.0796,
         0.0273,  0.0273, -0.0195, -0.0703,  0.0743,  0.0803,  0.0407, -0.0437],
       requires_grad=True)

6. 归一化层


归一化层用于调整神经网络的激活值分布,提升模型的收敛速度和泛化能力。常用的归一化层包括批归一化(Batch Normalization)和层归一化(Layer Normalization)。

下面是一个创建批归一化层的示例代码:

import torch
import torch.nn as nn

# 定义特征维度
num_features = 16

# 创建批归一化层
bn_layer = nn.BatchNorm2d(num_features)

# 打印批归一化层的参数
print("特征维度:", bn_layer.num_features)
print("均值:", bn_layer.running_mean)
print("方差:", bn_layer.running_var)
特征维度: 16
均值: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
方差: tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

7. 激活函数层


激活函数层用于引入非线性变换,增加神经网络的表达能力。常用的激活函数包括ReLU、Sigmoid和Tanh等。

下面是一个使用ReLU激活函数的示例代码:

import torch
import torch.nn as nn

# 创建激活函数层(ReLU)
activation_layer = nn.ReLU()

# 定义输入张量
input_tensor = torch.randn(10)

# 对输入张量进行激活函数变换
output_tensor = activation_layer(input_tensor)

# 打印输出张量
print("输出张量:", output_tensor)
输出张量: tensor([0.0000, 0.0000, 1.7597, 0.0000, 0.9116, 1.0483, 0.9865, 0.0000, 0.3673,
        0.7428])

三、数据加载与预处理

在深度学习任务中,数据的加载和预处理是非常重要的步骤。PyTorch提供了强大的数据加载和预处理工具,使得我们能够高效地处理各种类型的数据。这里将介绍PyTorch中的数据加载和预处理方法,并提供使用示例。

1. 数据加载

PyTorch中的数据加载主要通过torch.utils.data模块实现。该模块提供了DatasetDataLoader两个核心类,分别用于定义数据集和数据加载器。

🚩Dataset

Dataset类是一个抽象类,用于表示数据集。我们可以继承该类并实现自定义的数据集。在自定义数据集中,我们需要实现两个方法:__len____getitem____len__方法返回数据集的样本数量,__getitem__方法根据索引返回单个样本。

以下是一个自定义数据集的示例:

import torch
from torch.utils import data
from torch.utils.data import Dataset

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        # 初始化数据集
        self.data_list = data_list

    def __len__(self):
        # 返回数据集大小
        return len(self.data_list)

    def __getitem__(self, index):
        # 根据索引获取样本
        sample = self.data_list[index]
        return sample

在上述示例中,MyDataset类接受一个数据列表作为输入,并实现了__len____getitem__方法。

🚩DataLoader

torch.utils.data.DataLoader是PyTorch中一个重要的类,用于高效加载数据集。它可以处理数据的批次化、打乱顺序、多线程数据加载等功能。以下是一个简单的示例:

import torch.utils.data as data

my_dataset = MyDataset([1, 2, 3, 4, 5])

my_dataloader = data.DataLoader(my_dataset, 
                                batch_size=4, 
                                shuffle=True)

for batch in my_dataloader:
    print(batch)
tensor([4, 5, 2, 1])
tensor([3])

在这个示例中,我们首先创建了一个MyDataset实例my_dataset,它包含了一个整数列表。然后,我们使用DataLoader类创建了一个数据加载器my_dataloader,它将my_dataset作为输入,并将数据分成大小为4的批次,并对数据进行随机化。最后,遍历my_dataloader,并打印出每个批次的数据。

总结一下,torch.utils.data.Dataset用于构建数据集,torch.utils.data.DataLoader用于加载数据集,并对数据进行批量处理和随机化。下面是一个完整的示例,展示了如何使用这两个类来加载和处理数据:

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        # 初始化数据集
        self.data_list = data_list

    def __len__(self):
        # 返回数据集大小
        return len(self.data_list)

    def __getitem__(self, index):
        # 根据索引获取样本
        sample = self.data_list[index]
        return sample

my_dataset = MyDataset([1, 2, 3, 4, 5])

my_dataloader = data.DataLoader(my_dataset, 
                                batch_size=4, 
                                shuffle=True)

for batch in my_dataloader:
    print(batch)
tensor([3, 1, 5, 2])
tensor([4])

除了上述介绍的基本用法,torch.utils.data模块还有许多其他的功能和选项。下面介绍一些常用的选项和功能。

2. 数据预处理

数据预处理是在将数据输入模型之前对数据进行的一系列操作,以提高模型的性能和准确性。PyTorch提供了多种数据预处理方法,包括常见的数据变换、标准化、图像增强等。以下是一些常见的数据预处理方法:

🚩Tensor转换

将数据转换为torch.Tensor类型是数据预处理的第一步。torch.Tensor是PyTorch中表示张量的主要数据类型。

import torch  

data = [1, 2, 3, 4, 5]  
tensor = torch.tensor(data)  

🚩数据变换

数据变换是对数据进行形状调整或维度变换的操作。PyTorch提供了一系列的数据变换方法,如torchvision.transforms模块中的ResizeToTensor等。

from torchvision import transforms  

transform = transforms.Compose([  
    transforms.Resize((224, 224)),  
    transforms.ToTensor()  
])  

# 对数据进行变换  
transformed_data = transform(data)  

🚩数据标准化

数据标准化是对数据进行平均值和标准差的缩放,以使得数据具有零均值和单位方差。这通常用于提高模型的收敛性和稳定性。

import torchvision.transforms as transforms  

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],  
                                 std=[0.229, 0.224, 0.225])  

# 对图像进行标准化  
normalized_image = normalize(image)  

🚩图像增强

图像增强是对图像进行变换或添加噪声,以增加训练数据的多样性和鲁棒性。PyTorch提供了torchvision.transforms模块中的多种图像增强方法,如随机裁剪、翻转、旋转等。

import torchvision.transforms as transforms  

transform = transforms.Compose([  
    transforms.RandomCrop(224),  
    transforms.RandomHorizontalFlip(),  
    transforms.RandomRotation(30)  
])  

# 对图像进行增强  
transformed_image = transform(image)  

本节介绍了PyTorch中的数据加载和预处理方法。通过自定义数据集和数据加载器,我们可以高效地加载和处理数据。同时,PyTorch提供了多种数据预处理方法,如数据变换、标准化和图像增强,以提高模型的性能和准确性。

四、模型训练与验证

1. 模型训练

PyTorch中的模型训练主要涉及以下几个步骤:

  1. 准备数据:首先,我们需要准备好训练数据和对应的标签。可以使用torch.utils.data模块中的DatasetDataLoader类来加载和批量处理数据。
  2. 定义模型:接下来,我们需要定义模型的结构。可以使用torch.nn模块中的各种层和模型来构建自己的神经网络模型。
  3. 定义损失函数:为了训练模型,我们需要定义损失函数来度量模型预测结果与真实标签之间的差异。可以使用torch.nn模块中的各种损失函数,如均方误差(MSE)、交叉熵损失等。
  4. 定义优化器:为了更新模型的参数,我们需要选择一个优化器来优化模型的损失函数。可以使用torch.optim模块中的各种优化器,如随机梯度下降(SGD)、Adam等。
  5. 训练模型:在每个训练迭代中,我们需要执行以下步骤:
  • 前向传播:将输入数据通过模型,得到模型的输出结果。
  • 计算损失:将模型的输出结果与真实标签计算损失函数的值。
  • 反向传播:根据损失函数的梯度,计算模型参数的梯度。
  • 参数更新:使用优化器根据梯度信息更新模型的参数。

以下是一个简单的模型训练示例:

import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader  

# 准备数据  
train_dataset = MyDataset(train_data)  
train_dataloader = DataLoader(train_dataset,  
                              batch_size=64,  
                              shuffle=True)  

# 定义模型  
model = MyModel()  

# 定义损失函数  
loss_fn = nn.CrossEntropyLoss()  

# 定义优化器  
optimizer = optim.SGD(model.parameters(), lr=0.01)  

# 模型训练  
for epoch in range(num_epochs):  
    for batch in train_dataloader:  
        inputs, labels = batch  

        # 前向传播  
        outputs = model(inputs)  

        # 计算损失  
        loss = loss_fn(outputs, labels)  

        # 反向传播  
        optimizer.zero_grad()  
        loss.backward()  

        # 参数更新  
        optimizer.step()  

在上述示例中,我们使用自定义的数据集和数据加载器准备训练数据,定义了模型、损失函数和优化器,并在每个训练迭代中执行了前向传播、计算损失、反向传播和参数更新的步骤。

2. 模型验证

在模型训练之后,我们需要对模型进行验证以评估其性能和准确性。模型验证的步骤与模型训练类似,但不需要进行参数更新。

以下是一个简单的模型验证示例:

# 准备验证数据  
val_dataset = MyDataset(val_data)  
val_dataloader = DataLoader(val_dataset, batch_size=64)  

# 模型验证  
model.eval()  # 设置模型为评估模式  

with torch.no_grad():  # 禁止梯度计算  
    for batch in val_dataloader:  
        inputs, labels = batch  

        # 前向传播  
        outputs = model(inputs)  

        # 在这里可以对模型输出进行后处理,如计算准确率、绘制预测结果等  

在上述示例中,我们使用自定义的验证数据集和数据加载器准备验证数据,并使用model.eval()将模型设置为评估模式。然后,在验证数据上进行前向传播,并根据需要对模型输出进行后处理。

介绍了PyTorch中的模型训练和验证方法。通过准备数据、定义模型、损失函数和优化器,以及执行训练和验证循环,我们可以高效地训练和评估深度学习模型。

闯关练习

👉练习1: 请使用 .DataLoader 加载列表 [12,1,2,3,4,5]my_dataloaderbatch_size 设置为3,不打乱数据,输出my_dataloader,并将第一行输出放到answer1中

answer1 = 0

👉练习2: 请使用 .DataLoader 加载列表 [12,1,2,3,4,5]my_dataloaderbatch_size 设置为4,不打乱数据,输出my_dataloader,并将第最后一行输出放到answer2中

answer2 = 0

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值