第四次作业:猫狗大战挑战赛

一、使用VGG模型进行猫狗大战

内容包含: 1.主要步骤解读 2. 代码 及解释 3.运行截图 4.知识点补充 5.遇到的问题及解决方法

import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torchvision
from torchvision import models,transforms,datasets
import time
import json


# 判断是否存在GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Using gpu: %s ' % torch.cuda.is_available())
Using gpu: False 

1. 下载数据

Jeremy Howard 提供了数据的下载,链接为:http://files.fast.ai/data/dogscats.zip

在他整理的数据集中,猫和狗的图片放在单独的文件夹中, 同时还提供了一个Validation数据。如果没有GPU设备,请减少用做训练的图像数据量即可。

因为这个代码需要在colab上跑,速度会相对较慢。因此,我们重新整理了数据,制作了新的数据集,训练集包含1800张图(猫的图片900张,狗的图片900张),测试集包含2000张图。下载地址为:http://fenggao-image.stor.sinaapp.com/dogscats.zip

import wget
#引入基本的库
url='http://fenggao-image.stor.sinaapp.com/dogscats.zip'
file_name=wget.download(url)
print(file_name)

import zipfile
import os
src_path='dogscats.zip'
target_path='dogscats'
if(not os.path.isdir(target_path+"/Images")):
    z = zipfile.ZipFile(src_path, 'r')
    z.extractall(path=target_path)
    z.close()
else:
    print("文件已解压")

``

问题及解决方法

注意:这里要变化一下,如果直接用! wget +地址来下载的话,报错
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0Qy9HIzr-1666228621305)(C:\Users\asus\AppData\Roaming\Typora\typora-user-images\image-20221020090911645.png)]

修改成这个样子
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ZsihNkWZ-1666228621307)(C:\Users\asus\AppData\Roaming\Typora\typora-user-images\image-20221020090855393.png)]
OK

2. 数据处理

datasets 是 torchvision 中的一个包,可以用做加载图像数据。它可以以多线程(multi-thread)的形式从硬盘中读取数据,使用 mini-batch 的形式,在网络训练中向 GPU 输送。在使用CNN处理图像时,需要进行预处理。图片将被整理成 224 × 224 × 3 224\times 224 \times 3 224×224×3 的大小,同时还将进行归一化处理。

torchvision 支持对输入数据进行一些复杂的预处理/变换 (normalization, cropping, flipping, jittering 等)。

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

vgg_format = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

data_dir = './dogscats'

dsets = {
   x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}

dset_sizes = {
   x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes

通过下面代码可以查看 dsets 的一些属性解析:

print(dsets[‘train’].classes)#分的文件夹的名字下的类别
print(dsets[‘train’].class_to_idx)#按顺序为这些类别定义索引为0,1,2,…
print(dsets[‘train’].imgs[:5])#返回从文件下中得到的图片的路径以及其类别
print('dset_sizes: ', dset_sizes)#返回该文件夹的内的图片数量

# 通过下面代码可以查看 dsets 的一些属性

print(dsets['train'].classes)
print(dsets['train'].class_to_idx)
print(dsets['train'].imgs[:5])
print('dset_sizes: ', dset_sizes)

运行结果:
在这里插入图片描述

使用DataLoader解析:

torch.utils.data.DataLoader主要是对数据进行batch的划分,
batch_size批训练集的大小
num_workers:多线程来读数据
shuffle:是否打乱数据
使用DataLoader的好处是,可以快速的迭代数据。

loader_train = torch.utils.data.DataLoader(dsets['train'], batch_size
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值