1.什么是线性层
作用类似于经典的全连接层(Fully Connected Layer),常用于连接神经网络中的不同层次。主要功能是执行线性变换,也就是对输入数据进行矩阵乘法并添加一个偏置(bias)。
线性变换:有线性层将输入张量乘以权重矩阵并加上偏置项。如图所受,这种线性变换可以被表示为:
也就是每个g都是由与它连线的x乘以相应权重k再加上偏值b。
x 是输入张量,通常是二维的,形状为 (N,in_features),其中 N 是批量大小(batch size),in_features 是输入的特征数。
W 是权重矩阵,形状为 (out_features,in_features)。
b 是偏置向量,形状为 (out_features)。
特征提取: 线性层通过权重矩阵的学习,能够从输入数据中提取不同的特征。不同的神经元可能会学习到不同的特征,使得网络能够更好地对输入数据进行表示和分类。
信息聚合: 线性层可以聚合来自上一层的所有信息,并通过权重矩阵和偏置项的调整,将这些信息转化为下一层所需的输入。
复杂模型的构建: 线性层是构建复杂神经网络的基础组件。通过堆叠多个线性层和非线性激活函数,可以构建出能够处理复杂任务的深度神经网络。
2.Linear
import torch
import torch.nn as nn
# 创建一个线性层,输入特征数为 in_features,输出特征数为 out_features
linear = nn.Linear(in_features, out_features)
就是将一个变量从原来的 input feature=4096映射到output feature = 1000
比如以下代码帮助理解
# 假设输入张量 x 有 3 个特征
x = torch.tensor([[1.0, 2.0, 3.0]])
# 创建一个线性层,输入特征数为 3,输出特征数为 2
linear = nn.Linear(3, 2)
# 前向传播,计算输出
output = linear(x)
print(output)
得到的每次结果基本都是随机的,比如此次结果tensor([[0.3921, 1.1069]], grad_fn=<AddmmBackward0>)。grad_fn 是 PyTorch 自动求导机制的一部分,表示这个张量是由一个或多个张量经过某种操作生成的,并且可以用于反向传播计算梯度。
<AddmmBackward0> 具体表示这个操作是通过矩阵乘法(mm)和加法(Add)得到的。即 Addmm 是 PyTorch 中一个函数,用于描述线性层的操作:先做矩阵乘法(mm),再加上偏置(Add)。
3.对CIFAR10数据集操作
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import Linear
dataset = torchvision.datasets.CIFAR10('E:\\PyCharm_Project\\Pytorch_2.3.1\\PytorchVision\\dataset' , train = False , transform = torchvision.transforms.ToTensor(),download = True)
dataloader = DataLoader(dataset , batch_size = 64)
class Zilliax(nn.Module):
def __init__(self):
super(Zilliax , self).__init__()
self.linear1 = Linear(196608 , 10)
def forward(self , input):
output = self.linear1(input)
return output
z = Zilliax()
for data in dataloader:
imgs,targets = data
print(imgs.shape)
print("________________________")
# output = torch.reshape(imgs , (1,1,1,-1)) # -1的意思是要代码自己计算,打印出来是[1,1,1,196608]
output = torch.flatten(imgs) # 摊平,一向箔,摊成一维的tensor,打印出来是[196608]
print(output.shape)
print("________________________")
output = z(output)
print(output.shape) # 会在处理打包末尾多余图片时出现错误,因为不满足196608
print("________________________")