# CNN: LeNet-5-style image classifier
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# Suppress all warnings
import warnings
warnings.filterwarnings("ignore")
def load_data():
    """Build the training and test data loaders.

    Both splits are read with ``datasets.MNIST`` from ``./data/FashionMNIST``
    (nothing is downloaded) and every image is converted with ``ToTensor``.

    Returns:
        tuple: ``(train_loader, test_loader)``, two ``DataLoader`` objects that
        each yield shuffled batches of 64 samples.
    """
    # NOTE(review): the root directory is named FashionMNIST but the dataset
    # class is datasets.MNIST -- confirm which dataset is actually intended.
    data_root = "./data/FashionMNIST"

    def _loader_for(use_train_split):
        # One dataset + loader pair per split; settings are identical otherwise.
        split = datasets.MNIST(
            root=data_root,
            train=use_train_split,
            transform=transforms.ToTensor(),
            download=False,
        )
        return torch.utils.data.DataLoader(
            dataset=split,
            batch_size=64,
            shuffle=True,
        )

    return _loader_for(True), _loader_for(False)
class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier with batch-normalised dense layers.

    Two conv -> ReLU -> max-pool stages are followed by two fully connected
    stages that end in 10 output scores per sample. The first dense layer
    expects 16 * 5 * 5 flattened features, which is what e.g. a
    single-channel 28x28 input produces after the two convolutional stages.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 6 channels, 3x3 kernel, stride 1, padding 2, then 2x2 pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=3, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 2: 6 -> 16 channels, 5x5 kernel without padding, then 2x2 pool.
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Dense stage 1: flattened features -> 120 units with batch norm.
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(),
        )
        # Dense stage 2: 120 -> 84 units with batch norm, then 10 class scores.
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return the 10 class scores per sample for the image batch ``x``."""
        for stage in (self.conv1, self.conv2):
            x = stage(x)
        # Collapse channel and spatial dimensions into one feature vector.
        flat = x.view(x.size(0), -1)
        return self.fc2(self.fc1(flat))
def train_model(model, train_loader, test_loader, criterion, optimizer):
    """Train ``model`` for one epoch, report test accuracy and save its weights.

    Args:
        model: The ``nn.Module`` to optimise. Batches are moved to the device
            its parameters live on before being fed to it.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        test_loader: Iterable of ``(inputs, labels)`` evaluation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.

    Side effects:
        Prints the mean training loss every 100 batches and the test accuracy
        after the epoch, then writes ``model.state_dict()`` to
        ``./para/FashionMNIST_CNN.pth`` (creating the directory if needed).
        The model is left in evaluation mode.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Feed every batch on the device that holds the model's parameters, so the
    # function also works when the caller moved the model to a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(1):
        # Training mode must be (re-)enabled explicitly: batch-norm layers only
        # use and update batch statistics while the module is in train mode.
        model.train()
        sum_loss = 0.0
        for batch, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            out = model(inputs)
            loss = criterion(out, labels)

            # Backward pass; gradients must be cleared on every iteration
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print the mean loss once every 100 training batches
            sum_loss += loss.item()
            if batch % 100 == 99:
                print('[Epoch:%d, batch:%d] train loss: %.03f' % (epoch + 1, batch + 1, sum_loss / 100))
                sum_loss = 0.0

        # Evaluate in eval mode and without autograd bookkeeping.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for test_inputs, labels in test_loader:
                test_inputs = test_inputs.to(device)
                labels = labels.to(device)
                outputs_test = model(test_inputs)
                _, predicted = torch.max(outputs_test, 1)  # highest-scoring class
                total += labels.size(0)  # number of evaluated samples
                correct += (predicted == labels).sum().item()  # correct predictions
        if total:  # an empty test loader has no accuracy to report
            print('第{}个epoch的识别准确率为:{}%'.format(epoch + 1, 100 * correct / total))

    # Save the model parameters
    path = "./para/FashionMNIST_CNN.pth"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    torch.save(model.state_dict(), path)
def run():
    """Load the data, build a ``LeNet5`` model and train it with Adam.

    The model is placed on the CUDA device when one is available and on the
    CPU otherwise. Training uses cross-entropy loss and a learning rate of 1e-3.
    """
    train_loader, test_loader = load_data()

    # NOTE(review): the training loop must feed batches on this same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LeNet5().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    train_model(model, train_loader, test_loader, criterion, optimizer)
# Script entry point: start training only when executed directly, not on import.
if __name__ == "__main__":
    run()