【猫狗数据集】pytorch训练猫狗数据集之创建数据集

数据集下载地址:

链接:https://pan.baidu.com/s/1tJQIY0ob2EyQn3cDipPkow?pwd=7gch 
提取码:7gch

猫狗数据集的分为训练集25000张,在训练集中猫和狗的图像是混在一起的,pytorch读取数据集有两种方式,第一种方式是将不同类别的图片放于其对应的类文件夹中,另一种是实现读取数据集类,该类继承torch.utils.Dataset,并重写__getitem__和__len__。

先将猫和狗从训练集中区分开来,分别放到dog和cat文件夹下

import os

file_dir = r"E:\development_file\jupyter_notebook\pytorch\data\dogs-vs-cats\train"
path = r"E:\development_file\jupyter_notebook\pytorch\data\dogs-vs-cats"

#将某类图片移动到该类的文件夹下
def img_to_file(path):
    print("=========开始移动图片============")
     #如果没有dog类和cat类文件夹,则新建
    if not os.path.exists(path+"/dog"):
            os.makedirs(path+"/dog")
    if not os.path.exists(path+"/cat"):
            os.makedirs(path+"/cat")
    file_name_list = os.listdir(file_dir)
    print("共:{}张图片".format(len(file_name_list)))
    for imgName in file_name_list:
        # 去除后缀
        img = imgName.replace(".jpg","")
         #将图片移动到指定的文件夹中
        if img.split(".")[0] == "cat":
            shutil.move(file_dir+"\\"+imgName,path+"\\cat")
        if img.split(".")[0] == "dog":
            shutil.move(file_dir+"\\"+imgName,path+"\\dog")
    print("=========移动图片完成============")    
img_to_file(path)

 然后从dog中和cat中分别抽取1250张,共2500张图片作为测试集。

import random
train_path = r"E:\development_file\jupyter_notebook\pytorch\data\catsdogs\train"
test_path = r"E:\development_file\jupyter_notebook\pytorch\data\catsdogs\val"
def split_train_test(fileDir,tarDir):

        if not os.path.exists(tarDir):
            os.makedirs(tarDir)
        pathDir = os.listdir(fileDir)    #取图片的原始路径
        filenumber=len(pathDir)
        rate=0.1    #自定义抽取图片的比例,比方说100张抽10张,那就是0.1
        picknumber=int(filenumber*rate) #按照rate比例从文件夹中取一定数量图片
        sample = random.sample(pathDir, picknumber)  #随机选取picknumber数量的样本图片
        print("=========开始移动图片============")
        print("移动了:"+str(picknumber)+"张")
        for name in sample:
                shutil.move(fileDir+name, tarDir)
        print("=========移动图片完成============")
split_train_test(train_path+'\\dog\\',test_path+'\\dog\\')  
split_train_test(train_path+'\\cat\\',test_path+'\\cat\\')  

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一种流行的深度学习框架,可以用于构建卷积神经网络(CNN)等模型。在分类任务中,我们可以使用PyTorch训练一个CNN模型来对的图像进行分类。 首先,我们需要准备一个分类的数据集。可以在网上找到已经标注好的图像数据集,例如Kaggle上的大战数据集。这个数据集包含了数千张的图像,以及它们对应的标签。 接下来,我们需要导入必要的PyTorch库和模块,例如torch、torchvision等。 然后,我们需要定义一个CNN模型。可以使用PyTorch提供的nn模块来搭建一个简单的CNN网络,包括卷积层、池化层和全连接层等。可以根据具体任务的需求和网络结构进行调整。 在搭建好网络之后,我们需要定义损失函数和优化器。对于分类任务,可以使用交叉熵损失函数来衡量预测结果和真实标签的差异,并选择适当的优化器,如SGD、Adam等来更新模型的参数。 接下来,我们可以开始训练模型。将数据集分为训练集和测试集,使用训练集来迭代地更新模型参数,计算损失函数并通过反向传播算法更新模型。在每个epoch结束后,使用测试集来评估模型的性能,如准确率、精确率、召回率等。 最后,我们可以使用训练好的模型对新的图像进行分类预测。将图像传入模型中,得到对应的预测结果,即的标签。 总结来说,PyTorch可以用于搭建CNN模型进行分类任务。需要准备好分类的数据集,在训练过程中使用损失函数和优化器来更新模型参数,并使用测试集来评估模型性能。最终可以使用训练好的模型对新的图像进行分类预测。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值