图像分类的pytorch实现
1.数据集读入
使用鱼和猫两类图像,将训练用的数据集放入train/fish和train/cat文件夹中,同理放入验证和测试的数据集
(1)建立训练数据集
import torchvision
from torchvision import transforms
train_data_path = './train/'
transforms = transforms.Compose([
transforms.Resize((64,64)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225] ),
]) #resize to a uniform resolution, convert the image data to a tensor, then normalize (ImageNet mean/std)
# NOTE(review): this assignment rebinds the name `transforms`, shadowing the
# torchvision.transforms module imported above. Later snippets rely on this
# rebound pipeline object, but a distinct name (e.g. img_transforms) would be clearer.
train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=transforms)
#ImageFolder assigns one class label per sub-directory (here: cat/ and fish/)
(2)建立验证和测试数据集
# Validation and test sets reuse the same preprocessing pipeline (`transforms`) as training.
val_data_path = './val/'
val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=transforms)
test_data_path = './test/'
test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=transforms)
| 训练集 | 用于训练过程中更新模型 |
|---|---|
| 验证集 | 用于评价模型的泛化能力(不是与训练数据的拟合程度!!!),不能用来直接更新模型 |
| 测试集 | 对模型的性能作出评价 |
(3)建立数据加载器
# batch_size guideline: to maximise GPU utilisation, vary batch_size and watch GPU usage.
import torch.utils.data

batch_size = 64
# shuffle=True: reshuffle training samples every epoch. ImageFolder lists one class
# directory after the other, so without shuffling each batch would contain a single class.
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation-only loaders do not need shuffling.
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)
2 创建一个浅层神经网络
import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
    """Small fully-connected classifier for 64x64 RGB images, 2 output classes.

    The forward pass returns raw logits: nn.CrossEntropyLoss (used later in this
    file) applies log-softmax internally, so applying softmax inside forward()
    as well would double-apply it — this matches the corrected forward() shown
    in the loss-function section.
    """

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(12288, 84)  # 12288 = 64 * 64 * 3 flattened pixels
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, 2)      # two output classes: cat, fish

    def forward(self, x):
        # Flatten any (batch, 3, 64, 64) input to (batch, 12288).
        x = x.view(-1, 12288)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Return logits; apply F.softmax(..., dim=1) only at prediction time.
        return self.fc3(x)

simplenet = SimpleNet()
3 损失函数和优化器
损失函数用于确定预测与实际标签的差别,然后利用这个信息更新权重
多分类任务常用的损失函数为交叉熵损失函数CrossEntropyLoss(),回归任务常用的损失函数为MSELoss(),也可以定义自己的损失函数
loss_fn = nn.CrossEntropyLoss()  # standard loss for multi-class classification
# CrossEntropyLoss already includes (log-)softmax, so forward() should return raw logits:
def forward(self, x):
    """Forward pass returning raw logits (softmax is folded into CrossEntropyLoss)."""
    x = x.view(-1, 12288)  # flatten to (batch, 64*64*3)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # Bug fix: the original snippet computed the logits but never returned them.
    return self.fc3(x)
优化器的作用是寻找合适参数使得损失函数的值尽可能小,常见的优化器包括:SGD、AdaGrad、AMSProp、Adam,最常用的是Adam优化器,Adam对每个参数使用了一个学习率,并根据参数的变化调整学习率
import torch.optim as optim
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)
# lr is the learning rate; 0.001 is a common starting value for Adam
4 训练
建立一个通用的训练代码,使损失函数和优化器可以作为参数传递
# One full pass over the training data per epoch.
for epoch in range(epochs):
    for batch in train_loader:
        inputs, targets = batch
        optimizer.zero_grad()            # clear gradients left over from the previous step
        predictions = model(inputs)
        loss = loss_fn(predictions, targets)
        loss.backward()                  # populate parameter gradients
        optimizer.step()                 # apply the update once gradients are computed
####使用GPU
# Prefer CUDA when available, otherwise fall back to the CPU
# (single conditional expression replacing the original if/else).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# `model` here refers to the simplenet defined earlier
做一个训练整合(较通用)
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    """Generic training loop: train on train_loader, evaluate on val_loader each epoch.

    Args:
        model: the nn.Module to optimise (must already accept inputs of the loader's shape).
        optimizer: a torch.optim optimizer over model.parameters().
        loss_fn: loss taking (logits, targets), e.g. nn.CrossEntropyLoss().
        train_loader / val_loader: DataLoaders yielding (inputs, targets) batches.
        epochs: number of full passes over the training data.
        device: device to move batches to; given a default so it can follow the
            defaulted `epochs` parameter (the original signature was a SyntaxError).
    """
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0

        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            # Bug fix: the original called loss_fn(output, target) — `target` is undefined.
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        # no_grad: validation needs no gradient bookkeeping (saves memory and time).
        with torch.no_grad():
            for batch in val_loader:
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, targets)
                valid_loss += loss.item()
                # Predicted class = argmax over the class dimension; compare with labels.
                correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
                num_correct += torch.sum(correct).item()
                num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,valid_loss, num_correct / num_examples))
5 预测
from PIL import Image

labels = ['cat','fish']
# Bug fix: the original line contained literal backslash-escaped quotes (\"...\"),
# which is a syntax error in Python source.
img = Image.open("./val/fish/100_1422.JPG")
# Apply the same preprocessing pipeline used for training, on the same device as the model.
img = transforms(img).to(device)
img = torch.unsqueeze(img, 0)  # add a batch dimension of size 1: (C,H,W) -> (1,C,H,W)
simplenet.eval()
with torch.no_grad():  # inference only — no gradient tracking needed
    prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction])
6 模型的保存与加载
(1)可以直接使用torch.save()执行,但是这种方法如果在后续改变了模型结构,就可能会出问题
# Save the entire model object (pickles the class together with the weights)
torch.save(simplenet,'/tmp/simplenet')
# Load it back
# NOTE(review): loading a whole pickled model breaks if the class definition changes later;
# newer PyTorch versions also default torch.load to weights_only=True — confirm against your torch version.
simplenet = torch.load('/tmp/simplenet')
(2) 保存模型的state_dict (常用)
# Save only the parameters (state_dict) — robust to later changes in the model code
torch.save(simplenet.state_dict(), "/tmp/simplenet")
# Load: recreate the architecture first, then restore the saved parameters into it
simplenet = SimpleNet()
simplenet_state_dict = torch.load("/tmp/simplenet")
simplenet.load_state_dict(simplenet_state_dict)