import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import sys
def load_data_fashion_mnist(batch_size):
    """Build the Fashion-MNIST train/test DataLoaders.

    Both splits are stored under ``~/Datasets/FashionMNIST`` (downloaded if
    missing) and converted to tensors with ``transforms.ToTensor()``.

    Args:
        batch_size: Samples per mini-batch, used for both loaders.

    Returns:
        ``(train_iter, test_iter)``: the training loader shuffles every epoch,
        the test loader keeps the dataset order.
    """
    def _split(is_train):
        # One dataset split; ``is_train`` selects train (True) or test (False).
        return torchvision.datasets.FashionMNIST(
            root='~/Datasets/FashionMNIST',
            train=is_train,
            download=True,
            transform=transforms.ToTensor(),
        )

    train_set = _split(True)
    test_set = _split(False)

    # No worker subprocesses on Windows, 4 elsewhere (presumably to avoid
    # multiprocessing issues with DataLoader workers on Windows).
    workers = 0 if sys.platform.startswith('win') else 4

    train_iter = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
    test_iter = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=workers)
    return train_iter, test_iter
def get_fashion_mnist_labels(labels):
    """Translate numeric Fashion-MNIST class indices into text names.

    Args:
        labels: Iterable of class indices; each element only needs to be
            accepted by ``int()`` (Python ints, NumPy scalars, ...).

    Returns:
        List of label strings in the same order as ``labels``.
    """
    names = ('t-shirt', 'trouser', 'pullover', 'dress', 'coat',
             'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot')
    return list(map(lambda idx: names[int(idx)], labels))
def show_fashion_mnist(images, labels):
    """Display images in a single row, each titled with its label.

    Args:
        images: Sequence of image tensors, each holding 28 * 28 values
            (e.g. a ``(1, 28, 28)`` slice of a ToTensor-converted batch).
        labels: Sequence of titles, one per image; ``zip`` stops at the
            shorter of ``images`` and ``labels``.
    """
    # squeeze=False always returns a 2-D array of Axes. Without it, a single
    # image yields a lone Axes object, which zip() cannot iterate over.
    _, axes = plt.subplots(1, len(images), figsize=(12, 12), squeeze=False)
    for ax, img, lbl in zip(axes[0], images, labels):
        # reshape (unlike view) also accepts non-contiguous tensors.
        ax.imshow(img.reshape((28, 28)).numpy())
        ax.set_title(lbl)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()
# Mini-batch size and the Fashion-MNIST loaders used for training/testing.
batch_size = 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)

# Model dimensions: 784 input features per flattened image, 10 classes.
num_inputs, num_outputs = 784, 10

# Weights: Gaussian noise (mean 0, std 0.01) drawn with NumPy's global RNG,
# converted to float32. Bias: zeros. Both are leaves tracked by autograd.
W = torch.from_numpy(np.random.normal(0, 0.01, (num_inputs, num_outputs))).float()
b = torch.zeros(num_outputs, dtype=torch.float)
W.requires_grad_(True)
b.requires_grad_(True)
def softmax(X):
    """Row-wise softmax along ``dim=1``.

    Each row is shifted by its maximum before exponentiating. The shift
    cancels in the ratio, so the result is mathematically unchanged. It
    stops ``exp`` from overflowing to ``inf``, which would otherwise turn
    the output into ``nan`` for large scores.

    Args:
        X: Tensor of scores, normalised along ``dim=1``.

    Returns:
        Tensor of the same shape whose entries along ``dim=1`` are
        non-negative and sum to 1.
    """
    X_shifted = X - X.max(dim=1, keepdim=True).values
    X_exp = X_shifted.exp()
    partition = X_exp.sum(dim=1, keepdim=True)
    return X_exp / partition
def net(X):
    """Softmax-regression forward pass using the module-level ``W`` and ``b``.

    ``X`` is flattened to rows of ``num_inputs`` features, mapped through
    ``X @ W + b``, then normalised with ``softmax``.
    """
    flat = X.view(-1, num_inputs)
    logits = flat @ W + b
    return softmax(logits)
def cross_entropy(y_hat, y):
    """Per-sample cross-entropy from predicted class probabilities.

    Args:
        y_hat: 2-D tensor of probabilities, one row per sample.
        y: Class indices, one per row of ``y_hat`` (int64, as ``gather``
            requires).

    Returns:
        Tensor of shape ``(N, 1)`` holding ``-log(y_hat[i, y[i]])``.
    """
    picked = torch.gather(y_hat, 1, y.view(-1, 1))
    return picked.log().neg()
def accuracy(y_hat, y):
    """Fraction of rows whose arg-max along ``dim=1`` equals the label in ``y``.

    Returns:
        Mean accuracy as a Python float.
    """
    predicted = torch.argmax(y_hat, dim=1)
    correct = predicted.eq(y).float()
    return correct.mean().item()
def evaluate_accuracy(data_iter, net):
    """Return the fraction of samples in ``data_iter`` that ``net`` classifies correctly.

    Args:
        data_iter: Iterable of ``(X, y)`` mini-batches.
        net: Callable mapping a batch ``X`` to per-class scores, with classes
            along ``dim=1``.

    Returns:
        Accuracy over all samples as a Python float.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    acc_sum, n = 0.0, 0
    # Evaluation never back-propagates, so don't build an autograd graph for
    # every batch. This saves memory and time when net's parameters require grad.
    with torch.no_grad():
        for X, y in data_iter:
            acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
            n += y.shape[0]
    return acc_sum / n
def sgd(params, lr, batch_size):
    """In-place mini-batch SGD step: ``param -= lr * param.grad / batch_size``.

    The gradient is divided by ``batch_size`` because the training loop in
    this script back-propagates a summed (not averaged) loss.

    Args:
        params: Iterable of leaf tensors to update in place.
        lr: Learning rate.
        batch_size: Divisor that turns the summed gradient into a mean.
    """
    # Update under no_grad instead of through ``.data``. ``.data`` bypasses
    # autograd's in-place modification tracking; no_grad keeps it intact.
    with torch.no_grad():
        for param in params:
            if param.grad is None:
                # No gradient was computed for this tensor; leave it unchanged
                # (the same policy torch.optim optimizers follow).
                continue
            param -= lr * param.grad / batch_size
# Training schedule: passes over the training data, and the SGD step size.
num_epochs = 5
lr = 0.1
def train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,
              params=None, lr=None, optimizer=None):
    """Train ``net`` for ``num_epochs`` epochs, printing loss and accuracy each epoch.

    Exactly one update mechanism is used: ``optimizer`` if given, otherwise
    the hand-written ``sgd`` applied to ``params`` with learning rate ``lr``.

    Args:
        net: Callable mapping a batch ``X`` to class scores along ``dim=1``.
        train_iter: Iterable of ``(X, y)`` training mini-batches.
        test_iter: Iterable of ``(X, y)`` batches used for test accuracy.
        loss: Callable ``loss(y_hat, y)`` returning per-sample losses; they
            are summed before back-propagation.
        num_epochs: Number of passes over ``train_iter``.
        batch_size: Nominal mini-batch size, kept for interface
            compatibility. The ``sgd`` step divides by each batch's actual
            length instead.
        params: Tensors updated by ``sgd`` (required when ``optimizer`` is None).
        lr: Learning rate for ``sgd``.
        optimizer: Optional optimizer exposing ``zero_grad()``/``step()``;
            takes precedence over ``params``/``lr``.
    """
    for epoch in range(num_epochs):
        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0
        for X, y in train_iter:
            y_hat = net(X)
            l = loss(y_hat, y).sum()

            # Zero the gradients accumulated by the previous step.
            if optimizer is not None:
                optimizer.zero_grad()
            elif params is not None:
                for param in params:
                    # Check each tensor: some may not have a .grad yet.
                    if param.grad is not None:
                        param.grad.zero_()

            l.backward()
            if optimizer is None:
                # Divide by the real batch length. With DataLoader's default
                # drop_last=False, the last batch can be smaller than
                # batch_size. Dividing its summed loss by batch_size would
                # then take a smaller-than-mean step.
                sgd(params, lr, y.shape[0])
            else:
                optimizer.step()

            train_l_sum += l.item()
            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d,loss %.4f,train acc %.3f,test acc %.3f'
              % (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))
# Train the from-scratch model with the hand-written SGD step.
train_ch3(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)

# Show the first nine test images titled "true label \n predicted label".
# Use the built-in next(): the Python-2 style ``.next()`` method is not
# available on Python 3 DataLoader iterators in current PyTorch.
X, y = next(iter(test_iter))
true_labels = get_fashion_mnist_labels(y.numpy())
with torch.no_grad():
    pred_labels = get_fashion_mnist_labels(net(X).argmax(dim=1).numpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
show_fashion_mnist(X[0:9], titles[0:9])
# Pytorch学习笔记——softmax模型 (PyTorch study notes: the softmax model)
# 最新推荐文章于 2024-05-16 11:25:01 发布 (scraped blog-page metadata)