线性层/全连接层

1.什么是线性层

        作用类似于经典的全连接层(Fully Connected Layer),常用于连接神经网络中的不同层次。主要功能是执行线性变换,也就是对输入数据进行矩阵乘法并添加一个偏置(bias)。

         线性变换:有线性层将输入张量乘以权重矩阵并加上偏置项。如图所受,这种线性变换可以被表示为: 

g = xW^{T} + b

        也就是每个g都是由与它连线的x乘以相应权重k再加上偏值b。

        x 是输入张量,通常是二维的,形状为 (N,in_features),其中 N 是批量大小(batch size),in_features 是输入的特征数。

        W 是权重矩阵,形状为 (out_features,in_features)。

        b 是偏置向量,形状为 (out_features)。

        特征提取: 线性层通过权重矩阵的学习,能够从输入数据中提取不同的特征。不同的神经元可能会学习到不同的特征,使得网络能够更好地对输入数据进行表示和分类。

        信息聚合: 线性层可以聚合来自上一层的所有信息,并通过权重矩阵和偏置项的调整,将这些信息转化为下一层所需的输入。

        复杂模型的构建: 线性层是构建复杂神经网络的基础组件。通过堆叠多个线性层和非线性激活函数,可以构建出能够处理复杂任务的深度神经网络。

2.Linear

import torch
import torch.nn as nn

# 创建一个线性层,输入特征数为 in_features,输出特征数为 out_features
linear = nn.Linear(in_features, out_features)

就是将一个变量从原来的 input feature=4096映射到output feature = 1000

比如以下代码帮助理解

# 假设输入张量 x 有 3 个特征
x = torch.tensor([[1.0, 2.0, 3.0]])

# 创建一个线性层,输入特征数为 3,输出特征数为 2
linear = nn.Linear(3, 2)

# 前向传播,计算输出
output = linear(x)
print(output)

        得到的每次结果基本都是随机的,比如此次结果tensor([[0.3921, 1.1069]], grad_fn=<AddmmBackward0>)。grad_fn 是 PyTorch 自动求导机制的一部分,表示这个张量是由一个或多个张量经过某种操作生成的,并且可以用于反向传播计算梯度。
        <AddmmBackward0> 具体表示这个操作是通过矩阵乘法(mm)和加法(Add)得到的。即 Addmm 是 PyTorch 中一个函数,用于描述线性层的操作:先做矩阵乘法(mm),再加上偏置(Add)。

3.对CIFAR10数据集操作

import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
from torch.nn import Linear

dataset = torchvision.datasets.CIFAR10('E:\\PyCharm_Project\\Pytorch_2.3.1\\PytorchVision\\dataset' , train = False , transform = torchvision.transforms.ToTensor(),download = True)
dataloader = DataLoader(dataset , batch_size = 64)

class Zilliax(nn.Module):
    def __init__(self):
        super(Zilliax , self).__init__()
        self.linear1 = Linear(196608 , 10)
        
    def forward(self , input):
        output = self.linear1(input)
        return output
    
z = Zilliax()

for data in dataloader:
    imgs,targets = data
    print(imgs.shape)
    print("________________________")
#    output = torch.reshape(imgs , (1,1,1,-1))     # -1的意思是要代码自己计算,打印出来是[1,1,1,196608]
    output = torch.flatten(imgs)        # 摊平,一向箔,摊成一维的tensor,打印出来是[196608]
    print(output.shape)
    print("________________________")
    output = z(output)
    print(output.shape)      # 会在处理打包末尾多余图片时出现错误,因为不满足196608
    print("________________________")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值