kaggle Dogs_and_Cats实践记录

这篇博客记录了作者在kaggle Dogs_and_Cats竞赛中使用Transfer Learning的方法,特别是通过Resnet18进行深度学习的实践过程。首先介绍了数据下载,然后详细阐述了如何利用transfer learning对训练集和验证集进行准备,接着对resnet18进行fine-tuning,训练得到97%~98%的准确率。最后,博主分享了预测代码,确保对测试集图片排序后生成submission.csv。
摘要由CSDN通过智能技术生成

数据下载

kaggle Dogs_and_Cats数据集下载,可以在配置好kaggle 好之后,从相应的Competition:Dogs_and_Cats 中下载数据。

原数据包含了train.zip & test1.zip & sampleSubmission.csv

解压后,train.zip包含25000张图片,12500为Dogs, 12500为Cats, test1.zip解压之后为待预测的数据,12500张。

transfer learning

总体思路: Dogs, Cats与ImageNet中的数据有重合,因此,采用 transfer learning的方式来做。

train-set, dev-set准备

首先,训练需使用train.zip中的数据,为更快地操作,结合torchvision ImageFolder(folder.py)类,将原数据分别放进dogs , cats目录,并且使用下述脚本,将train分成train-set, dev-set.

import os
import sys
import shutil
import glob
import os.path as osp

dev_num = 2500

def split_data(root, target_dir):
    img_files = glob.glob(root+"/*.jpg")
    dev_files = img_files[-dev_num:]
    print('dev_files:', dev_files)

    for dev_file in dev_files:
        shutil.move(dev_file, osp.join(target_dir, osp.basename(dev_file)))
        


if __name__ == "__main__":
    root = sys.argv[1]
    target_dir = sys.argv[2]
    split_data(root, target_dir)

fine-tuning resnet18

构建模型,因要分类的数据不复杂,现采用resnet18作为基础网格,并且借用已有的预训练权重,只对最后一层fc层进行fine-tune.代码如下:

import torch
import torchvision

import torch.optim as optim
from torch.utils.data import *
from torchvision import models, transforms, datasets
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter(log_dir='runs')


def imshow(inp, title=None):
    """ Imshow for Tensor"""
    print(inp.size())
    inp = inp.numpy().transpose((1,2,0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    inp = inp*std+mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.show()

def test_show_data():
    # Test Showing, Get a batch of training data
    inputs, classes = next(iter(train_dataloader))
    # Make a grid from batch
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值