05 - 全连接层
概念
nn.Linear
在PyTorch中是用于创建全连接层(全连接神经网络层)的类。全连接层也被称为密集连接层或者全连接神经网络层,是深度学习中常用的一种网络层类型。
全连接层的作用是将上一层的所有节点都连接到当前层的每个节点,这样每个节点都与上一层的所有节点相连,形成了完全连接的网络结构。这种结构可以学习到输入数据之间的复杂非线性关系,从而提高模型的表达能力。
在PyTorch中,使用nn.Linear
可以方便地创建全连接层,指定输入特征的维度和输出特征的维度。例如,下面是一个创建全连接层的示例:
import torch
import torch.nn as nn
# 定义一个全连接层,输入特征维度为10,输出特征维度为5
linear_layer = nn.Linear(in_features=10, out_features=5)
# 打印全连接层的权重和偏置项
print("全连接层的权重:")
print(linear_layer.weight)
print("全连接层的偏置项:")
print(linear_layer.bias)
# 输入数据
input_data = torch.randn(1, 10)
# 计算全连接层的输出
output = linear_layer(input_data)
print("全连接层的输出:")
print(output)
在这个示例中,我们创建了一个输入维度为10、输出维度为5的全连接层,并打印了其权重和偏置项。然后,我们将输入数据传递给全连接层,计算得到输出。这个输出就是经过全连接层计算后得到的结果。
示例
import torch
import torchvision
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=True,
transform=torchvision.transforms.ToTensor())
# drop_last=True防止最后因为数据不一样多的报错
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.linear = Linear(196608, 10)
def forward(self, input):
output = self.linear(input)
return output
net = MyNet()
for data in dataloader:
imgs, targets = data
print(imgs.shape)
output = torch.reshape(imgs, (1, 1, 1, -1))
# torch.flatten的作用是将输入张量(tensor)按照指定的维度(或者默认按照第一个维度)展平为一维张量。
# 就是说使用self.linear = Linear(196608, 10)的效果跟torch.flatten的效果是一样的
print(output.shape)
output = net(output)
print(output.shape)
torch.flatten的作用是将输入张量(tensor)按照指定的维度(或者默认按照第一个维度)展平为一维张量。就是说使用self.linear = Linear(196608, 10)的效果跟torch.flatten的效果是一样的。