运用BP模型实现猫狗数据集的分类
数据集下载
首先,我们要先下载好要分类的数据集,下载网址如下:
该数据集是Kaggle在2013年公开的猫狗数据集,该数据集总共25000张图片,猫狗各12500张。
部分图片如下:
我们下载的只是最基本的图片数据,还需要自行的创建我们的数据集,一般而言我们都是通过创建类的方法来实现。
导入库
from PIL import Image # 这行代码从Pillow库中导入了Image模块,它提供了许多用于打开、操作和保存图像的函数。
import numpy as np
from torch.utils.data import Dataset # Dataset类是torch.utils.data模块中的一个抽象类,用于表示一个数据集
from torchvision import transforms
import os # os模块提供了与操作系统交互的函数,例如读取目录内容、检查文件是否存在等。
import torch
import matplotlib.pyplot as plt
import matplotlib
#设置字体为楷体
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
在这里,我们应该了解各个库的作用和用法,这些库在下面都会用到
自定义类来封装数据集
在这里,我们需要用到__init__,__len__,__getitem__三个魔术方法,下面,我们就先了解一下这三个方法。(有了解的可以直接跳过!!)
__init__
__init__方法,称为构造方法
特性:
在创建类对象时会自动执行,
在创建类对象时,会传入参数给__init__使用
下面是个简单的例子:
class student:
# 可写可不写
# name=None
# age=None
# sex=None
def __init__(self,name,age,sex):
self.name=name
self.age=age
self.sex=sex
stu=student("小红",18,"女")
print("信息添加成功")
简单了解一下就好
__len__
__len__ 是Python中的一个魔术方法,用于定义类的实例对象的长度。 该方法用于返回一个对象的长度,通常是容器类型的对象,如字符串、列表、元组和字典等。
__getitem__
__getitem__方法用于获取序列对象中指定索引位置的元素,通常使用中括号 []运算符调用。 该方法接收一个索引作为参数,并返回序列对象中指定索引位置的元素。 如果指定的索引超出了序列对象的范围,应该抛出IndexError异常
接下来,我们来了解各个魔术方法下代码的含义
init部分
def __init__(self,root_dir,lable_dir):
self.root_dir=root_dir # 文件主路径dataset/train
self.label_dir=lable_dir # 分路径 cat 和 dog
self.path=os.path.join(self.root_dir,self.label_dir) # 将文件路径的两部分连接起来,使path成为完整路径
self.img_path=os.listdir(self.path) # 查看
self.transform = transforms.Compose([ # 包含:
transforms.Resize((224,224)), # 统一大小为224*224
transforms.ToTensor() # 转化为Tensor类型
])
self.img_path获取该完整路径下的图片(如dataset\train\cat目录下的图片)
transforms.compose:
orchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起:
比如说:Resize和ToTensor
Resize它可以更改PIL类型的图片数据的大小,本质上其实更改的是像素多少
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
# 定义图像的路径
path = "C:\\Users\\yangy\\Pictures\\Scre