一、线性层简介:
线性层输入的参数 input_features 就是输入的x,输出 out_features 就是输出的g。g与x的关系为:
其中 W是权重,b是偏置(参数bias为True时才存在) ,weight和bais都是从分布中采样初始化,经过训练得最终结果
在从vgg模型中将224*224*3的图片转化为了1*1*4096的大小,经训练得到了1*1*1000 的结果
import torch
from torch import nn
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
tensor_trans = transforms.ToTensor()
test_data = torchvision.datasets.CIFAR10(root="./dataloader",train=False,transform=tensor_trans,download=True)
dataloader = DataLoader(dataset=test_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
# 创建一个神经网络——线性层
class MyModule(nn.Module):
def __init__(self,in_features,out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.linear = nn.Linear(self.in_features,self.out_features)
def forward(self,imgs):
imgs = self.linear(imgs)
return imgs
linear = MyModule(196608,10)
# 输入是196608的原因是:后序代码会利用flatten方法将图片变为[196608,1,1,1]的格式,因此输入是196608
for each in dataloader:
imgs,labels = each
print(imgs.shape)
output = torch.flatten(imgs)
# flatten的作用:输入的照片,本质是多维矩阵,flatten的作用就是将多位矩阵变成一个类似与一维数组的结构并输出
print(output.shape)
output = linear(output)
# 将用flatten方法加工过后的图片再放入线性层加工,使得n个输入变成m个输出(通常n > m),这就是线性层的功能
print(output.shape)
线性层工作图示: