多分类-mnist数据集-Pytorch实现
model.parameters()
for param in model.parameters(): # model.parameters() 返回 generate 迭代器
print(type(param), param.size())
out:
<class 'torch.nn.parameter.Parameter'> torch.Size([100, 784])
<class 'torch.nn.parameter.Parameter'> torch.Size([100])
<class 'torch.nn.parameter.Parameter'> torch.Size([10, 100])
<class 'torch.nn.parameter.Parameter'> torch.Size([10])
由 model.parameters() 的输出结果,可以看出返回的是 各个层的 权重w 和 偏置 b。
在 单个训练集循环中,模型在训练集上做出预测,并反向更新(backpropagates)预测误差来调整模型的参数。
下载数据集
training_data = datasets.MNIST(
root='data', # 数据下载后存储的根目录文件夹名称
train=True,
download=True,
transform=ToTensor(),
)
test_data = datasets.MNIST(
root='data',
train=False,
download=True,
transform=ToTensor(),
)
加载数据集
batch_size = 60 # 初始化训练轮数
# 加载 mnist 数据集 : 返回 containing batch_size=10 features and labels respectively
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True) # 共60000条数据
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True) # 共10000条数据
测试:
for X, y in test_dataloader:
print(f"Shape of X[N, C, H, W]: {
X.shape}")
print(f"Shape of y: {
y.size()}")
break
out:
Shape of X[N, C, H, W]: torch.Size([60, 1, 28, 28])
Shape of y: torch.Size([60])
机器是否支持GPU
# 训练集是否使用 CPU 或者 GPU device 进行训练
device = "cuda" if torch.cuda.is_available else "cpu"
定义神经网络
class NeuralNetwork(nn.Module):
""":arg
在Pytorch中定义神经网络,我们创建 NeuralNetwork 类 继承nn.Module。我们定义在 __init__ 函数中定义网络的层数,
在 forword() 中明确 数据怎样通过 网络。为了加速神经网络的训练,如果GPU可用就把它放到GPU上。
init : 初始化神经网络的曾说以及训练数据要通过的方法
"""
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 30), # z = x·w^T + b, 参数表示:(输入层神经元个数, 输出层神经元个数)
nn.Sigmoid(),
nn.Linear(30, 10),
nn.Sigmoid()
)
# 前向传播
def forward(self, x):
# print('-'*6 + 'forward' + '-'*6)
x = self.flatten(x) # 把 (1, 28, 28) 转为 (1, 784)
# print('----flatten()-----')
# print(x.shape)
logits = self.linear_relu_stack(x)
return logits
pass
使用GPU进行训练
model = NeuralNetwork().to(device) # 使用 GPU 训练
print(model)
优化模型参数
loss_fn = nn