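下面的代码是用CNN识别MNIST的示例；如果要识别Fashion-MNIST，只需将数据集换成Fashion-MNIST（torchvision.datasets.FashionMNIST）即可（Normalize中的均值和标准差最好也换成Fashion-MNIST对应的统计值）。
第一个全连接层的输入神经元个数如何确定（即nn.Linear(1600, 128)中的1600是怎么来的），可以参考我的另一篇博客。简单来说：28×28的输入经过3×3卷积变为26×26，池化后为13×13，再经3×3卷积变为11×11，池化后为5×5（向下取整），最后一层有64个通道，所以展平后为64×5×5=1600。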
import torch,torchvision
import torch.nn as nn
# Model definition
class CNNMnist(nn.Module):
    """Two-stage CNN classifier for single-channel 28x28 images (MNIST-style).

    Shape trace for a 28x28 input (3x3 conv without padding, 2x2 max-pool):
    28 -> conv -> 26 -> pool -> 13 -> conv -> 11 -> pool -> 5.
    The last conv stage has 64 channels, so the flattened feature vector has
    64 * 5 * 5 = 1600 entries -- the input width of the first Linear layer.
    Other input sizes would need a different value there.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as a plain Sequential listing,
        # so parameter initialization and state_dict keys are unchanged.
        conv_layers = []
        for in_ch, out_ch in ((1, 32), (32, 64)):
            conv_layers += [nn.Conv2d(in_ch, out_ch, 3), nn.ReLU(), nn.MaxPool2d(2, 2)]
        self.feature = nn.Sequential(*conv_layers)

        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return raw (un-softmaxed) scores for the 10 classes for batch ``x``."""
        return self.classifier(self.feature(x))
net = CNNMnist()

# Load the datasets. The Normalize mean/std (0.1307, 0.3081) are presumably
# the MNIST pixel statistics; recompute them when switching datasets
# (e.g. to Fashion-MNIST).
apply_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=True, download=True, transform=apply_transform)
test_dataset = torchvision.datasets.MNIST(root='./data/mnist', train=False, download=False, transform=apply_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False)

# Use the GPU if one is available, otherwise the CPU.
# The model is moved to its device BEFORE the optimizer is built, as the
# torch.optim documentation recommends, so the optimizer is constructed over
# the parameters that will actually be trained.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

# Loss function and optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# Train the model.
print('training on: ', device)
def test(test_loader):
    """Evaluate the module-level ``net`` on ``test_loader`` and print the results.

    Reads the module-level ``net``, ``criterion`` and ``device``.

    Args:
        test_loader: iterable yielding ``(data, target)`` batches.

    Returns:
        ``(accuracy, mean_loss)``: accuracy as a fraction in [0, 1] and the
        per-sample mean loss over the whole set, or ``None`` if the loader
        yields no samples.
    """
    net.eval()
    correct = 0
    total = 0
    loss_total = 0.0
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = net(data)
            batch_size = len(target)
            # Assumes criterion averages over the batch (the default
            # reduction of CrossEntropyLoss); scaling back to a sum keeps a
            # short final batch from being over-weighted in the overall mean.
            loss_total += criterion(output, target).item() * batch_size
            correct += torch.sum(torch.argmax(output, dim=1) == target).item()
            total += batch_size
    if total == 0:
        # Nothing to average over; avoid dividing by zero.
        print('test acc: n/a (empty test set)')
        return None
    accuracy = correct / total
    mean_loss = loss_total / total
    print('test acc: %.2f%%, loss: %.4f' % (100 * accuracy, mean_loss))
    return accuracy, mean_loss
def train(log_interval=200):
    """Run one training epoch of the module-level ``net`` over ``train_loader``.

    Reads the module-level ``net``, ``train_loader``, ``criterion``,
    ``optimizer`` and ``device``.

    Args:
        log_interval: print the current batch loss whenever
            ``batch % log_interval == 0``. A falsy value (0/None) disables
            the per-batch log line.

    Returns:
        Mean of the per-batch losses over the epoch, or ``None`` if
        ``train_loader`` yielded no batches.
    """
    net.train()
    # Accumulate as a detached tensor so summing does not force a
    # device-to-host sync every batch; converted once at the end.
    loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for batch, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = net(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach()
        num_batches += 1
        if log_interval and batch % log_interval == 0:
            print('\tbatch: %d, loss: %.4f' % (batch, loss.item()))
    if num_batches == 0:
        return None
    return (loss_sum / num_batches).item()
# Main loop: each epoch is one full training pass followed by a test-set
# evaluation.
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    print('epoch: %d' % epoch)
    train()
    test(test_loader)
实验结果：