import torch
from torch import nn,optim
import random
import numpy as np
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import os
input_size=28*28
num_classes=10
learnning_rate=0.001
num_epochs=5
batch_size=100
#1、准备数据
#训练数据
train_data=torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True)
#创建训练数据迭代器
train_loader=torch.utils.data.DataLoader(dataset=train_data,
batch_size=batch_size,
shuffle=True)
#测试数据
test_data=torchvision.datasets.MNIST(root='../../data',
train=False,
transform=transforms.ToTensor(),
download=True)
#测试数据迭代器(我这个是一次加载完)
test_loader=torch.utils.data.DataLoader(dataset=test_data,
batch_size=len(test_data))
is_loadmodel=0 #是否加载已有模型
if is_loadmodel==0:
#2、建立模型
model=nn.Linear(input_size,num_classes)
criterion=nn.CrossEntropyLoss() #损失函数
optimizer=optim.SGD(model.parameters(),lr=learnning_rate) #优化器
print("w:",model.weight)
print('b:',model.bias)
#3、训练模型
total_steps=len(train_loader)
for epoch in range(num_epochs):
for i,(images,labels) in enumerate(train_loader):
images=torch.reshape(images,shape=(-1,input_size)) #重组图像矩阵
outputs=model(images) #输出
loss=criterion(outputs,labels) #损失
optimizer.zero_grad() #梯度置零
loss.backward() #反向求导数
optimizer.step() #更新模型参数
if (i+1)%100==0:
print('epochs:{}/{},steps:{}/{}'.format(epoch+1,num_epochs,
i+1,total_steps))
torch.save(model,'model.ckpt')
torch.save(model.state_dict(),'model_dict.ckpt')
else:
model=torch.load('model.ckpt')
#4、测试模型
total_test=len(test_data)
with torch.no_grad():
for images,labels in test_loader:
images=torch.reshape(images,shape=(-1,input_size))
outputs=model(images)
_,predit=torch.max(outputs,1)
correct=(predit==labels).sum()
print('Accuracy:',correct.detach().numpy()/total_test)
#5、可视化
#显示前20个数据
for i in range(0,20,5):
s_index=150
for index in range(5):
image,label=test_data[i+index]
image=torch.squeeze(image,0)
s_index+=1
plt.subplot(s_index)
plt.imshow(image)
plt.text(0,-1,'label={}'.format(label))
plt.show()
我学Pytorch之三~~~Pytorch最简单的MINIST数据实现
最新推荐文章于 2023-05-15 19:53:21 发布