今天学习了神经网络-线性层及其他层。这段代码的主要目的是为了实现一个简单的神经网络模型,并在 CIFAR-10 数据集上进行预测。
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)
class Yang(nn.Module):
def __init__(self):
super(Yang, self).__init__()
self.linear1 = Linear(196608, 10)
def forward(self, input):
output = self.linear1(input)
return output
yang = Yang()
for data in dataloader:
imgs, targets = data
print(imgs.shape)
output = torch.flatten(imgs)
print(output.shape)
output = yang(output)
print(output.shape)
1. `import torch`:导入 PyTorch 库,这是一个用于机器学习和深度学习的开源库。
2. `import torchvision`:导入 torchvision 库,这个库包含了许多常用的数据集和预训练模型。
3. `from torch import nn`:从 PyTorch 库中导入神经网络模块。
4. `from torch.nn import Linear`:从神经网络模块中导入线性层,这是神经网络中常用的一种层。
5. `from torch.utils.data import DataLoader`:从 PyTorch 库中导入 DataLoader,这是一个可以使数据加载更加容易和快速的工具。
6. `dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(), download=True)`:这行代码从 torchvision 数据集中加载 CIFAR-10 数据集。这个数据集包含了 60000 张 32x32 的彩色图像,共有 10 个类别。`train=False` 表示加载的是测试集,`transform=torchvision.transforms.ToTensor()` 表示将加载的图像转换为 PyTorch Tensor,`download=True` 表示如果数据集未在指定路径下找到,就下载数据集。
7. `dataloader = DataLoader(dataset, batch_size=64)`:创建一个 DataLoader 对象,这个对象会生成一个迭代器,每次从数据集中提取一批数据(批大小为64)。