import os
import sys
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.optim as optim
'''
作者:小宇
时间:2022.4.10
'''
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
class DataMake(object):
    """Build the MNIST train/test datasets and preview training images.

    Both splits are resized to 32x32 and converted to tensors. MNIST is
    downloaded into ``./data`` if it is not already there.
    """

    def __init__(self):
        self.data_train = MNIST(
            './data',
            download=True,
            transform=self._make_transform(),
        )
        self.data_test = MNIST(
            './data',
            train=False,
            download=True,
            transform=self._make_transform(),
        )

    @staticmethod
    def _make_transform():
        """Return the shared preprocessing pipeline: resize to 32x32, then convert to a tensor."""
        return transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ])

    def create_data(self, batch_size=16):
        """Wrap the train/test datasets in DataLoaders.

        Args:
            batch_size: samples per batch for both loaders (default 16).

        Returns:
            ``(data_train_loader, data_test_loader)``. The training loader is
            reshuffled every epoch. The test loader keeps a fixed order:
            evaluation accuracy does not depend on order, and a fixed order
            makes evaluation runs reproducible.
        """
        data_train_loader = DataLoader(self.data_train, batch_size=batch_size, shuffle=True)
        data_test_loader = DataLoader(self.data_test, batch_size=batch_size, shuffle=False)
        return data_train_loader, data_test_loader

    def show_mnist_image(self, num_of_images=60):
        """Plot randomly drawn training images in a grid with 10 columns.

        Args:
            num_of_images: how many images to show (default 60, i.e. 6 rows).
        """
        # Draw a batch of exactly the size we need. Reusing the 16-sample
        # training loader and then indexing up to 59 raised IndexError.
        loader = DataLoader(self.data_train, batch_size=num_of_images, shuffle=True)
        imgs, _ = next(iter(loader))
        n_cols = 10
        # Ceil division, in case the dataset yields fewer images than requested.
        n_rows = (imgs.shape[0] + n_cols - 1) // n_cols
        plt.figure()
        for index, img in enumerate(imgs):
            plt.subplot(n_rows, n_cols, index + 1)
            plt.axis('off')
            plt.imshow(img.numpy().squeeze(), cmap='gray_r')
        plt.show()
class LeNet(nn.Module):
    """LeNet-5 style classifier for 32x32 single-channel images.

    Input: tensor of shape (N, 1, 32, 32). Output: (N, 10) raw class
    scores (logits), suitable for ``nn.CrossEntropyLoss``.

    ReLU is applied after every conv and hidden fully-connected layer.
    Without it, fc1 -> fc2 -> fc3 collapse into a single linear map.
    The activations are applied functionally, so the submodules and
    state_dict keys are the same as before. A checkpoint trained with the
    old activation-free forward() will still load, but should be retrained.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1))  # -> (N, 6, 28, 28)
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))  # -> (N, 6, 14, 14)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), stride=(1, 1), padding=(0, 0), dilation=(1, 1))  # -> (N, 16, 10, 10)
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))  # -> (N, 16, 5, 5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)  # -> (N, 120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)  # -> (N, 84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)  # -> (N, 10)

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dim: (N, 400)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No softmax here: CrossEntropyLoss applies log-softmax itself.
        return self.fc3(x)
def train_net(epochs=10, lr=0.05, save_path="models.pth"):
    """Train LeNet on MNIST with plain SGD and save the whole model object.

    Args:
        epochs: number of passes over the training set (default 10).
        lr: SGD learning rate (default 0.05).
        save_path: file the trained model is written to with ``torch.save``
            (default "models.pth", the file ``test_model`` loads).
    """
    get_data = DataMake()
    data_train_loader, _ = get_data.create_data()
    model = LeNet().to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    n_batches = len(data_train_loader)
    for epoch in range(epochs):
        # Running statistics are per epoch. Previously they accumulated across
        # epochs while the divisor (batch_idx + 1) reset, so the printed loss
        # became meaningless after the first epoch.
        train_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(data_train_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)  # index of the highest logit = predicted class
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx % 100 == 0:
                print('Epoch %d/%d' % (epoch + 1, epochs), batch_idx, n_batches,
                      'Loss: %.3f | Acc: %.3f%%(%d/%d)' % (train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    torch.save(model, save_path)
def _load_full_model(path):
    """Load an entire nn.Module that was saved with ``torch.save(model, path)``."""
    import inspect

    # torch.save(model) pickles the whole module, so torch.load must unpickle
    # arbitrary objects. PyTorch >= 2.6 defaults to weights_only=True, which
    # rejects that. Older releases do not accept the keyword at all.
    # SECURITY: unpickling can execute code -- only load checkpoints you created.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(path, map_location=device, weights_only=False)
    return torch.load(path, map_location=device)


def test_model(model_path="models.pth"):
    """Evaluate a saved LeNet on the MNIST test split.

    Prints the running loss/accuracy every 2 batches, then a final summary.

    Args:
        model_path: file written by ``train_net`` (default "models.pth").

    Returns:
        Test accuracy as a percentage (0.0 if the test loader is empty).
    """
    get_data = DataMake()
    _, data_test_loader = get_data.create_data()
    model = _load_full_model(model_path).to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()
    n_batches = len(data_test_loader)
    test_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(data_test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)  # index of the highest logit = predicted class
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if batch_idx % 2 == 0:
                print(batch_idx, n_batches, 'Loss: %.3f | Acc: %.3f%%(%d/%d)' % (test_loss / (batch_idx + 1), 100. * correct / total, correct, total))
    # Guard against an empty loader so the summary never divides by zero.
    avg_loss = test_loss / n_batches if n_batches else 0.0
    accuracy = 100. * correct / total if total else 0.0
    print('Test summary -> Loss: %.3f | Acc: %.3f%%(%d/%d)' % (avg_loss, accuracy, correct, total))
    return accuracy
if __name__ == '__main__':
    # Train only when no checkpoint exists yet, so a fresh checkout does not
    # crash with FileNotFoundError. Otherwise just evaluate the saved model.
    if not os.path.exists("models.pth"):
        train_net()
    test_model()
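# 6、LeNet网络搭建、模型训练、模型测试 (Part 6: building the LeNet network, training the model, testing the model)
# 最新推荐文章于 2024-08-02 11:53:03 发布 (web-page residue: "latest recommended article published 2024-08-02 11:53:03")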