浅层神经网络的图像分类pytorch

图像分类的pytorch实现

1.数据集读入

使用鱼和猫两类图像,将训练用的数据集放入trian/fish和train/cat文件夹中,同理放入验证和测试的数据集

(1)建立训练数据集

import torchvision
from torchvision import transforms
train_data_path = './train/'
transforms = transforms.Compose([
       transforms.Resize((64,64)),
       transforms.ToTensor(),
       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225] ),
        ])  #裁剪为统一分辨率、将图像数据转化为张量、归一化
train_data =         torchvision.datasets.ImageFolder(root=train_data_path,transform=transforms)
#ImageFolder的作用是将每个目录下的图片定义为一个标签

(2)建立验证和测试数据集

val_data_path = './val/'
val_data =         torchvision.datasets.ImageFolder(root=val_data_path,transform=transforms)

test_data_path = './test/'
test_data =         torchvision.datasets.ImageFolder(root=test_data_path,transform=transforms)
训练集用于训练过程中更新模型
验证集用于评价模型的泛化能力(不是与训练数据的拟合程度!!!),不同来直接更新模型
测试集对模型的性能作出评价

(3)建立数据加载器

#batch_size的选择原则:为了尽可能提高GPU的利用率,通过改变batch_size的大小观察GPU利用率情况
import torch.utils.data
batch_size = 64
train_data_loader = torch.utils.data.DataLoader(train_data,batch_size=batch_size)
val_data_loader  = torch.utils.data.DataLoader(val_data, batch_size=batch_size)
test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

2 创建一个浅层神经网络

import torch.nn as nn
import torch.nn.functional as F
class SimpleNet(nn.Module):
        def __init__(self):
            super(SimpleNet, self).__init__()
            self.fc1 = nn.Linear(12288, 84) #注意这里的12288=64×64×3
            self.fc2 = nn.Linear(84, 50)
            self.fc3 = nn.Linear(50,2) #最后输出2分类
       
        def forward(self, x):
            x = x.view(-1, 12288)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = F.softmax(self.fc3(x)) #这里注意最后一次需要softmax()函数输出,但是可以不写这个函数,后面会提到
            return x
 simplenet = SimpleNet()

3 损失函数和优化器

损失函数用于确定预测与实际标签的差别,然后利用这个信息更新权重

多分类任务常用的损失函数为交叉熵损失函数CrossEntropyLoss(),回归任务常用的损失函数为MSELoss(),也可以定义自己的损失函数

loss_fn = nn.CrossEntropyLoss()
#由于交叉熵损失函数的封装中包含了softmax(),因此forward()方法变为:
def forward(self, x):
            x = x.view(-1, 12288)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)

优化器的作用是寻找合适参数使得损失函数的值尽可能小,常见的优化器包括:SGD、AdaGrad、AMSProp、Adam,最常用的是Adam优化器,Adam对每个参数使用了一个学习率,并根据参数的变化调整学习率

import torch.optim as optim
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)
#Ir为学习率,一般从0.001开始取

4 训练

建立一个通用的训练代码,使损失函数和优化器可以作为参数传递

for epoch in range(epochs):
    for batch in train_loader:
         optimizer.zero_grad() #每次循环后将梯度置0
         input, target = batch
         output = model(input)
         loss = loss_fn(output,target)
         loss.backward()
         optimizer.step() #更新所有参数,用在梯度被backward()计算好之后
  
  ####使用GPU
  if torch.cuda.is_available():
         device = torch.device('cuda')
  else:
         device = torch.device('cpu')
  model.to(device)
  #这里的model变量就是前面的simplenet

做一个训练整合(较通用)

def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device):
      for epoch in range(epochs):
           training_loss = 0.0
           valid_loss = 0.0
           model.train()
           for batch in train_loader:
                optimizer.zero_grad()
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model(inputs)
                loss = loss_fn(output, target)
                loss.backward()
                optimizer.step()
                training_loss += loss.data.item()
            training_loss /= len(train_loader.dataset)
            
            model.eval()
            num_correct = 0
            num_examples = 0
            for batch in val_loader:
                 inputs, targets = batch
                 inputs = inputs.to(device)
                 targets = targets.to(device)
                 output = model(inputs)
                 loss = loss_fn(output,targets)
                 valid_loss += loss.data.item()
                 correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets)
                 num_correct += torch.sum(correct).item()
                 num_examples += correct.shape[0]
             valid_loss /= len(val_loader.dataset)
             
             
              print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy =                         {:.2f}'.format(epoch, training_loss,valid_loss, num_correct / num_examples)) 

5 预测

from PIL import Image
labels = ['cat','fish']

img = Image.open(\"./val/fish/100_1422.JPG\")
img = transforms(img).to(device)
img = torch.unsqueeze(img, 0)  #在张量前面增加一次批次为1的新维度,使其变为一个新的张量

simplenet.eval()
prediction = F.softmax(simplenet(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction])

6 模型的保存与加载

(1)可以直接使用torch.save()执行,但是这种方法如果在后续改变了模型结构,就可能会出问题

#保存
torch.save(simplenet,'/tmp/simplenet')
#加载
simplenet = torch.load('/tmp/simplenet')

(2) 保存模型的state_dict (常用)

#保存
torch.save(simplenet.state_dict(), "/tmp/simplenet")
#加载
simplenet = SimpleNet()
simplenet_state_dict = torch.load("/tmp/simplenet")
simplenet.load_state_dict(simplenet_state_dict)
  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值