import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import torchvision
import torchvision.transforms as tranfsorms
import matplotlib.pyplot as plt
import numpy
# Hyperparameters for the training run below:
#   lr     - SGD learning rate
#   gamma  - SGD momentum (passed to fit as `gamma`; 0 = plain SGD)
#   epochs - number of full passes over the training set
#   bs     - DataLoader batch size
lr, gamma, epochs, bs = 0.15, 0, 10, 128
# Instantiate the FashionMNIST training split from a local copy (no download).
# The root is a raw string: in a normal literal "\j" and "\M" are invalid escape
# sequences (SyntaxWarning on Python 3.12+); the resulting path is unchanged.
mnist = torchvision.datasets.FashionMNIST(
    root=r"D:\jupyterDemo\MINST-FASHION数据集",
    download=False,
    train=True,
    transform=tranfsorms.ToTensor(),
)
# Explore the data: wrap the dataset in a shuffled DataLoader and print the
# shapes of the first (x, y) batch it yields, then stop.
batchdata = DataLoader(mnist, batch_size=bs, shuffle=True)
for x, y in batchdata:
    print(x.shape, y.shape, sep="\n")
    break
# Network input width: total element count of a single image tensor.
# Network output width: number of distinct labels present in the targets.
input_ = mnist.data[0].nelement()
output_ = mnist.targets.unique().numel()
# Neural-network architecture
class Model(nn.Module):
    """Two-layer bias-free MLP: Linear -> ReLU -> Linear -> log-softmax.

    Args:
        in_features: number of input features once each sample is flattened.
        out_features: number of output classes.
    """

    def __init__(self, in_features=10, out_features=2):
        super().__init__()
        self.linear1 = nn.Linear(in_features, 128, bias=False)
        self.output = nn.Linear(128, out_features, bias=False)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, out_features)."""
        # Flatten every sample to a vector of length in_features; -1 lets PyTorch
        # infer the batch dimension. The width used to be hard-coded as 28*28,
        # which broke any model built with in_features != 784. reshape (rather
        # than view) also accepts non-contiguous inputs.
        x = x.reshape(-1, self.linear1.in_features)
        sigma1 = torch.relu(self.linear1(x))
        # log_softmax pairs with nn.NLLLoss used by the training loop.
        sigma2 = F.log_softmax(self.output(sigma1), dim=1)
        return sigma2
# Training function
def fit(net, batchdata, lr=0.15, epochs=5, gamma=0):
    """Train ``net`` with SGD (with momentum) and print progress.

    Args:
        net: module whose output is per-class log-probabilities, because the
            loss is ``nn.NLLLoss``.
        batchdata: DataLoader yielding ``(x, y)`` batches. ``len(batchdata)``
            and ``len(batchdata.dataset)`` are used in the progress report.
        lr: learning rate.
        epochs: number of full passes over ``batchdata``.
        gamma: SGD momentum coefficient (0 means plain SGD).

    A progress line is printed every 125 batches and after the last batch of
    each epoch. The reported accuracy is cumulative over every batch seen since
    training started (all epochs). It is measured on each batch's forward pass,
    i.e. before that batch's weight update.
    """
    criterion = nn.NLLLoss()
    opt = optim.SGD(net.parameters(), lr=lr, momentum=gamma)  # momentum SGD
    total = epochs * len(batchdata.dataset)  # samples to be seen in total
    last_batch = len(batchdata) - 1
    correct = 0  # correct predictions so far (accumulated across epochs)
    samples = 0  # samples seen so far (accumulated across epochs)
    for epoch in range(epochs):
        for batch_idx, (x, y) in enumerate(batchdata):
            # --- core training step -----------------------------------
            y = y.view(x.shape[0])
            # Call the module (not .forward) so registered hooks run.
            sigma = net(x)
            loss = criterion(sigma, y)
            loss.backward()
            opt.step()
            opt.zero_grad()
            # --- accuracy bookkeeping ---------------------------------
            yhat = sigma.argmax(dim=1)  # predicted labels
            correct += (yhat == y).sum().item()
            samples += x.shape[0]
            if (batch_idx + 1) % 125 == 0 or batch_idx == last_batch:
                print("Epoch{}:[{}/{}({:.0f}%)] Loss:{:.6f},Accuracy:{:.3f}".format(
                    epoch + 1,
                    samples,
                    total,
                    100 * samples / total,
                    loss.item(),
                    100 * correct / samples,
                ))
# Train and evaluate. The global RNG is seeded before the model is built so
# that the weight initialisation is repeatable across runs.
torch.manual_seed(1412)
net = Model(input_, output_)
fit(
    net,
    batchdata,
    lr=lr,
    epochs=epochs,
    gamma=gamma,
)
# 运行结果 (run results):