1.线性神经网络模型
2.Linear Layers参数介绍
3.代码实战
- 实现与下图中 Fully Connected（全连接层）类似的部分（图中数据与代码无关）
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
class My_module(nn.Module):
    """Minimal network consisting of a single fully connected (nn.Linear) layer.

    The default ``in_features`` of 196608 equals 64 * 3 * 32 * 32, i.e. a whole
    batch of 64 images of shape 3x32x32 flattened into one vector; the default
    ``out_features`` is 10.

    Args:
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
        in_features: Size of the last dimension of the input. Defaults to 196608.
        out_features: Size of the last dimension of the output. Defaults to 10.
    """

    def __init__(self, *args, in_features: int = 196608, out_features: int = 10, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Maps torch.Size([..., in_features]) -> torch.Size([..., out_features]);
        # with the defaults: torch.Size([196608]) -> torch.Size([10]).
        self.linear = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``; its last dimension must equal ``in_features``."""
        output = self.linear(x)
        return output
my_module = My_module()
# train=False selects the CIFAR-10 test split; the torchvision default
# (train=True) would load the training split despite the variable name.
test_dataset = datasets.CIFAR10(root="datasets", train=False, transform=transforms.ToTensor(), download=True)
# drop_last=True is required here: the loop below flattens the *whole* batch into
# one vector, and the Linear layer expects exactly 64 * 3 * 32 * 32 = 196608 values.
# The final batch is partial (10000 % 64 == 16 images -> 49152 values), which would
# raise a shape-mismatch error in the forward pass.
test_dataloader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=True, drop_last=True)
for data in test_dataloader:
    imgs, labels = data
    # torch.flatten() with the default start_dim=0 flattens the entire batch:
    # torch.Size([64, 3, 32, 32]) -> torch.Size([196608]).
    # (A per-sample fully connected layer would instead use
    # torch.flatten(imgs, start_dim=1) -> [64, 3072] with nn.Linear(3072, 10).)
    imgs_tensor = torch.flatten(imgs)
    imgs_linear = my_module(imgs_tensor)  # torch.Size([196608]) -> torch.Size([10])