完整文件:https://github.com/JintuZheng/Blog-/blob/master/Demo_LogicRegression_MNIST.py
包导入准备
import torchvision.datasets
import torchvision.transforms
import torch.utils.data
import matplotlib.pyplot as plt
import torch.nn
import torch.optim
from debug import ptf_tensor
设置超参数
# Hyperparameters超参数
BATCH_SIZE=100
NUM_EPOCHS=5
DEVICE='cuda:0'
数据集下载
########################## 训练集的准备 ##############################################
train_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#root:下载数据存放到哪里,train:下载训练集还是测试集,transfrom:数据转化的形式
test_dataset=torchvision.datasets.MNIST(root='D:/DataTmp/mnist',train=False, transform=torchvision.transforms.ToTensor(),download=True)
【1】设置dataloader,分批读取数据,因为我们没办法一次训练过多数据
#由于数据集里面有上万条数据,我们需要分批从数据集读取数据
train_dataloader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE)
print('The len of train dataset={}'.format(len(train_dataset)))
test_dataloader=torch.utils.data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE)
print('The len of test dataset={}'.format(len(test_dataset)))
【2】查看数据格式
for images,labels in train_dataloader:
print('The images size is {}',format(images.size()))
print('The labels size is {}'.format(labels.size()))
break #本循环就是执行一次
线性分类器准备
【1】构建一层的分类器
fc=torch.nn.Linear(28*28,10) #只使用一层线性分类器
fc.to(DEVICE)#如果用CPU去掉
【2】构建损失函数
criterion=torch.nn.CrossEntropyLoss()
【3】根据假设函数的参数构建优化器
optimizer=torch.optim.Adam(fc.parameters())
开始迭代训练
for epoch in range(NUM_EPOCHS):
for idx, (images,labels) in enumerate(train_dataloader):
x =images.reshape(-1,28*28)
x=x.to(DEVICE)# 如果用CPU去掉
labels=labels.to(DEVICE)# 如果用CPU去掉
optimizer.zero_grad() #梯度清零
preds=fc(x) #计算预测
loss=criterion(preds,labels) #计算损失
loss.backward() # 计算参数梯度
optimizer.step() # 更新迭代梯度
if idx % 100 ==0:
print('epoch={}:idx={},loss={:g}'.format(epoch,idx,loss))
检验最后的正确率
correct=0
total=0
for idx,(images,labels) in enumerate(test_dataloader):
x =images.reshape(-1,28*28) #对所有的图片进行reshape size(m,28*28)
x=x.to(DEVICE)
labels=labels.to(DEVICE)
preds=fc(x)
predicted=torch.argmax(preds,dim=1) #在dim=1中选取max值的索引
if idx ==0:
print('x size:{}'.format(x.size()))
print('preds size:{}'.format(preds.size()))
print('predicted size:{}'.format(predicted.size()))
total+=labels.size(0)
correct+=(predicted == labels).sum().item()
#print('##########################\nidx:{}\npreds:{}\nactual:{}\n##########################\n'.format(idx,predicted,labels))
accuracy=correct/total
print('{:1%}'.format(accuracy))
参数数据的保存和复原
#保存
torch.save(fc.state_dict(), 'D:/DataTmp/mnist/tst.pth')
fc=torch.nn.Linear(28*28,10) #只使用一层线性分类器
#复原
fc.to(DEVICE)#如果用CPU去掉
fc.load_state_dict(torch.load('D:/DataTmp/mnist/tst.pth'))