pytorch 猫狗大战

数据集下载:Dogs vs. Cats
(原文此处为示意图,图片未能随文字导出)
需要导入的包

import os
import shutil
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import models
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Pick the first GPU when CUDA is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Script-wide constants.
input_size = 224            # side length fed to the network
batch_size = 128            # samples per mini-batch
save_path = './weights.pt'  # checkpoint file for the trained weights

数据整理

下载下来的压缩文件里有两个文件夹:train,test1
train 中有 25000 张图片,其中 12500 张猫, 12500 张狗

# Inspect the raw Kaggle training folder: files are named
# "dog.<n>.jpg" / "cat.<n>.jpg", all mixed together under ./train/.
data_file = os.listdir('./train/')
print(len(data_file)) # 25000
dog_file = [name for name in data_file if name.startswith('dog')]
cat_file = [name for name in data_file if name.startswith('cat')]
print(len(dog_file),len(cat_file)) # 12500 12500
print(cat_file[:3]) # ['cat.0.jpg', 'cat.1.jpg', 'cat.10.jpg']

下面将文件目录整理成如下形式:

├─test1
├─train
│  ├─cat
│  └─dog
└─val
    ├─cat
    └─dog

其中 train 包含原训练集中 90% 的图片,val 包含原数据集中 10% 的图片,作为验证集。猫狗图片分别放在各自的文件夹下。

root = os.getcwd()
print('current path:',root) # current path: D:\Downloads\dogs-vs-cats

# Create the target tree {train,val}/{dog,cat}; swallowing FileExistsError
# lets the script be re-run without failing on existing folders.
for split in ('train', 'val'):
    for label in ('dog', 'cat'):
        try:
            os.makedirs(os.path.join(root, split, label))
        except FileExistsError:
            pass

# 移动图片,shutil.move 相当于剪切操作
def _split_and_move(files, label):
    """Move *files* out of ./train into the class folders.

    The first 90% go to train/<label>, the remaining 10% to val/<label>.
    shutil.move is a cut-and-paste: the originals disappear from ./train.
    Relies on the module-level ``root`` computed above.
    """
    cutoff = 0.9 * len(files)
    for i, fname in enumerate(files):
        src = os.path.join(root, 'train', fname)
        split = 'train' if i < cutoff else 'val'
        shutil.move(src, os.path.join(root, split, label))

# The original duplicated dog/cat loops collapsed into one helper call each.
_split_and_move(dog_file, 'dog')
_split_and_move(cat_file, 'cat')

数据读取

对上面这种目录结构,pytorch 中自带的 ImageFolder 可以直接读取:

# Training pipeline: random scale+crop to input_size plus a horizontal
# flip for augmentation, then map pixel values into [-0.5, 0.5].
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(input_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])
train_set = ImageFolder('train', transform=transform_train)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

# Validation pipeline: deterministic resize only.  Note that Resize takes
# a [H, W] pair, unlike RandomResizedCrop which takes a single side length.
transform_val = transforms.Compose([
    transforms.Resize([input_size, input_size]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])
val_set = ImageFolder('val', transform=transform_val)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)

# Sanity checks on what ImageFolder produced.
print(train_set.class_to_idx)  # {'cat': 0, 'dog': 1}
print(len(train_set.imgs))  # 22500
print(train_set[1][0].size())  # torch.Size([3, 224, 224])
print(val_set[1][0].size())  # torch.Size([3, 224, 224])

来看 RandomResizedCrop 的数据增强效果:缩放,水平翻转,裁剪

# Visualize one augmented sample: the tensor is (3, 224, 224) with values
# in [-0.5, 0.5] after Normalize, so move channels last for imshow and
# shift values back into [0, 1].
plt.imshow(np.transpose(train_set[1][0],[1,2,0])+0.5) # tensor size: (3,224,224) range:[-0.5,0.5]
plt.show()

(原文此处为三张数据增强效果示例图:缩放、水平翻转、裁剪后的样本,图片未能随文字导出)

训练最后一层

# Transfer learning: start from an ImageNet-pretrained ResNet-18 and
# freeze every pretrained weight so only the new head is trained.
transfer_model = models.resnet18(pretrained=True)
for frozen_param in transfer_model.parameters():
    frozen_param.requires_grad = False

# Replace the 1000-way ImageNet classifier with a fresh 2-way (cat/dog)
# fully-connected layer; its weights are trainable by default.
in_dim = transfer_model.fc.in_features
transfer_model.fc = nn.Linear(in_dim, 2)
# print(transfer_model)

net = transfer_model.to(device)

criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3)

定义训练和验证函数,需要特别注意,测试前调用 model.eval()
因为网络中的有些结构的行为在训练和测试时是不一样的,比如 dropout
一开始没有注意这个细节,结果在验证集上的准确率只有 50+%

def train():
    """Run one epoch of training over ``train_loader``.

    Fix: the original body mixed a tab (``net.train()`` line) with 4-space
    indentation, which is a TabError in Python 3.  Uses the module-level
    ``net``, ``criterion``, ``optimizer`` and ``device``.
    """
    net.train()  # training-mode behavior (dropout active, batch-norm updating)
    batch_num = len(train_loader)
    running_loss = 0.0
    for i, data in enumerate(train_loader, start=1):
        # Move the mini-batch to the training device (GPU when available).
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Report the mean loss over each window of 20 batches.
        running_loss += loss.item()
        if i % 20 == 0:
            print(
                'batch:{}/{} loss:{:.3f}'.format(i, batch_num, running_loss / 20))
            running_loss = 0.0


#测试函数
def validate():
    """Compute and print top-1 accuracy of ``net`` on ``val_loader``.

    Fix: the original body mixed a tab (``net.eval()`` line) with 4-space
    indentation, which is a TabError in Python 3.
    """
    # eval() is essential: layers such as dropout and batch-norm behave
    # differently in training vs. inference mode; skipping this was the
    # cause of the ~50% accuracy mentioned in the article.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for data in val_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the test images: %d %%' %
          (100 * correct / total))

开始训练,只需1轮遍历数据集就达到了 95%的准确率,可见迁移学习的强大威力。

# Train the classifier head, validating after every epoch, then checkpoint.
n_epoch = 10
for epoch in range(1, n_epoch + 1):
    print('epoch {}'.format(epoch))
    train()
    validate()
torch.save(net.state_dict(), save_path)

'''
打印结果:
epoch 1
batch:20/176 loss:0.678
batch:40/176 loss:0.583
batch:60/176 loss:0.507
batch:80/176 loss:0.459
batch:100/176 loss:0.415
batch:120/176 loss:0.381
batch:140/176 loss:0.362
batch:160/176 loss:0.346
Accuracy on the test images: 95 %
epoch 2
batch:20/176 loss:0.323
batch:40/176 loss:0.300
batch:60/176 loss:0.290
batch:80/176 loss:0.280
batch:100/176 loss:0.280
batch:120/176 loss:0.267
batch:140/176 loss:0.254
batch:160/176 loss:0.262
Accuracy on the test images: 96 %
epoch 3
batch:20/176 loss:0.257
batch:40/176 loss:0.252
batch:60/176 loss:0.237
batch:80/176 loss:0.228
batch:100/176 loss:0.234
batch:120/176 loss:0.234
batch:140/176 loss:0.228
batch:160/176 loss:0.223
Accuracy on the test images: 96 %
'''

如果想继续提升,还可以训练所有层:

# Fine-tune every layer.  Bug fix: the original only set
# requires_grad=True on all parameters but kept using the optimizer that
# was built over net.fc.parameters(), so the backbone layers were never
# actually updated.  The optimizer must be rebuilt over ALL parameters.
net.load_state_dict(torch.load(save_path))

for param in net.parameters():
    param.requires_grad = True

# Rebind the module-level optimizer so train() now updates all layers.
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

n_epoch = 10
for epoch in range(n_epoch):
    print('epoch {}'.format(epoch+1))
    train()
    validate()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值