神经网络——线性层

一、线性层简介:

线性层输入的参数 input_features 就是输入的x,输出 out_features 就是输出的g。g与x的关系为: 

其中 W是权重,b是偏置(参数bias为True时才存在) ,weight和bais都是从分布中采样初始化,经过训练得最终结果

在从vgg模型中将224*224*3的图片转化为了1*1*4096的大小,经训练得到了1*1*1000 的结果

import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision

tensor_trans = transforms.ToTensor()

test_data = torchvision.datasets.CIFAR10(root="./dataloader",train=False,transform=tensor_trans,download=True)

dataloader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)

# 创建一个神经网络——线性层
class MyModule(nn.Module):

    def __init__(self,in_features,out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(self.in_features,self.out_features)

    def forward(self,imgs):
        imgs = self.linear(imgs)
        return imgs

linear = MyModule(196608,10)
# 输入是196608的原因是:后序代码会利用flatten方法将图片变为[196608,1,1,1]的格式,因此输入是196608

for each in dataloader:
    imgs,labels = each
    print(imgs.shape)
    output = torch.flatten(imgs)
    # flatten的作用:输入的照片,本质是多维矩阵,flatten的作用就是将多位矩阵变成一个类似与一维数组的结构并输出
    print(output.shape)
    output = linear(output)
    # 将用flatten方法加工过后的图片再放入线性层加工,使得n个输入变成m个输出(通常n > m),这就是线性层的功能
    print(output.shape)

线性层工作图示:

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值