图片分类:png图片文件夹的处理(MNIST图片文件)

用pytorch框架,把文件夹中的png图片进行读取,并分成比例为8:2的训练集和测试集

以下是文件夹的形式,图片已经按类别分别放入不同的文件夹里了

 接下来就对文件图片进行划分训练集和测试集

1.导入包

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms    # 图像预处理包
from PIL import Image      # 读取图片
import os   # 打开文件夹,获取所有图片的地址

 2.定义函数读取文件夹中的图片

class Mydata(Dataset):
    """Dataset of digit images stored as ``<root_dir>/<label>/<image file>``.

    Every entry found in the ten label folders ``'0'`` .. ``'9'`` is recorded
    in ``datalist`` as a ``(label, path)`` tuple, grouped by label in ascending
    order. A subset is then taken from both ends of that list (the first and
    the last ``samples_per_end`` entries); only images whose PIL mode equals
    ``mode`` are kept in ``new_dataset``, which is what the dataset serves.
    Entries rejected by the mode check are also removed from ``datalist``.

    NOTE(review): ``datalist`` is grouped by label, so when the folders are
    large the head/tail subset only contains the lowest and highest labels.
    Shuffle or sample per label if all ten digits are needed.

    Args:
        root_dir: Directory that contains the label sub-folders.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        samples_per_end: How many entries to take from each end of
            ``datalist`` (default 10000, i.e. up to 20000 images in total).
        mode: PIL image mode an image must have to be kept (default ``'L'``).

    Raises:
        ValueError: If ``samples_per_end`` is negative.
        OSError: If a label folder is missing (raised by ``os.listdir``).
    """

    # Sub-folder names, one per class; the folder name doubles as the label.
    CLASS_NAMES = tuple(str(digit) for digit in range(10))

    def __init__(self, root_dir, transform=None, samples_per_end=10000, mode='L'):
        super().__init__()
        if samples_per_end < 0:
            raise ValueError('samples_per_end must be non-negative')

        self.transform = transform
        self.root_dir = root_dir
        self.datalist = self._list_samples(root_dir, self.CLASS_NAMES)
        print(len(self.datalist))

        # Keep only the candidates with the requested mode; remember the rest
        # so they can be pruned from datalist in a single pass afterwards.
        self.new_dataset = []
        rejected = set()
        for sample in self._select_candidates(self.datalist, samples_per_end):
            if self._image_mode(sample[1]) == mode:
                self.new_dataset.append(sample)
            else:
                rejected.add(sample)
        if rejected:
            self.datalist = [s for s in self.datalist if s not in rejected]

    @staticmethod
    def _list_samples(root_dir, class_names):
        """Return ``(label, path)`` tuples for every entry of each class folder.

        Entries are listed class by class in the order of ``class_names``; names
        inside a folder are sorted so the result does not depend on the
        arbitrary order returned by ``os.listdir``.
        """
        samples = []
        for label in class_names:
            for img_name in sorted(os.listdir(os.path.join(root_dir, label))):
                samples.append((label, os.path.join(root_dir, label, img_name)))
        return samples

    @staticmethod
    def _select_candidates(samples, samples_per_end):
        """Return the first and the last ``samples_per_end`` items of ``samples``.

        The tail starts after the head, so a short list never yields the same
        item twice, and the very last item is included.
        """
        head = samples[:samples_per_end]
        tail_start = max(len(head), len(samples) - samples_per_end)
        return head + samples[tail_start:]

    @staticmethod
    def _image_mode(img_path):
        """Return the PIL mode of the image at ``img_path``, closing the file."""
        with Image.open(img_path) as img:
            return img.mode

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        ``image`` is the output of ``transform`` when one was given, otherwise
        the loaded PIL image. ``label`` is a scalar integer tensor built from
        the folder name.
        """
        label, img_path = self.new_dataset[item]
        with Image.open(img_path) as img:
            img.load()  # read the pixels now so the file can be closed right away
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(int(label))

    def __len__(self):
        """Return the number of samples that passed the mode filter."""
        return len(self.new_dataset)

   

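 注意:if img.mode != 'L' 这一句代码在写的时候要注意。'L' 是 PIL 中 8 位灰度图的模式名,我这里的 png 图片都是灰度图;图片的 mode 取决于图片的颜色模式,而不是文件后缀。如果你的图片是其他模式(例如彩色的 JPG 通常是 'RGB'),要把 'L' 换成对应的模式,否则会出现以下错误:mydataset 的长度为 0,也就是所有图片都满足 img.mode != 'L',没有任何图片被添加到 new_dataset 中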

 那么,如何判断‘L’应改成什么呢?

通过 debug:在上面的代码处打上断点(小红点),调试过程中会看到蓝色高亮的那一行,在 28x28 的前面有 mode=L;也可以直接用 print(img.mode) 查看。其他类型的图片可以通过这个方法判断它的 mode,这样其他已经分类好的图片文件夹也可以用相同的方法处理啦

3.对图片进行处理(transform)

# Preprocessing applied to every image before it is handed to the model.
transform_img = transforms.Compose([
    # A two-element size resizes to exactly 28x28; the aspect ratio is NOT preserved.
    transforms.Resize([28, 28]),
    # Convert the PIL image to a float tensor and rescale 0-255 to 0-1.
    transforms.ToTensor(),
    # Shift 0-1 to -1~1. Mydata only keeps mode 'L' (single-channel) images,
    # so mean/std carry one entry; three entries do not match a 1-channel tensor.
    transforms.Normalize((0.5,), (0.5,)),
])

4.导入文件夹并进行划分

# Build the dataset and split it 8:2 into training and test subsets.
path = r'C:\Users\33165\PycharmProjects\pythonProject\minst\data\all_data'
mydataset = Mydata(path, transform=transform_img)
num_samples = len(mydataset)
print(num_samples)  # 20000 when every sampled image passes the mode filter
# Derive the split sizes from the real dataset length: random_split raises
# ValueError when the lengths do not sum to len(dataset), which hard-coded
# 16000/4000 would do as soon as a single image is filtered out.
num_train = num_samples * 8 // 10
train_data, test_data = random_split(mydataset, lengths=[num_train, num_samples - num_train])

5.对训练集和测试集进行封装

# Wrap both subsets in DataLoaders that yield shuffled mini-batches of 128
# samples, loaded in the main process (num_workers=0).
train_data = DataLoader(
    dataset=train_data,
    num_workers=0,
    shuffle=True,
    batch_size=128,
)
test_data = DataLoader(
    dataset=test_data,
    num_workers=0,
    shuffle=True,
    batch_size=128,
)

完整代码:

import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms    # 图像预处理包
from PIL import Image      # 读取图片
import os   # 打开文件夹,获取所有图片的地址

# 数据处理
class Mydata(Dataset):
    """Dataset of digit images stored as ``<root_dir>/<label>/<image file>``.

    Every entry found in the ten label folders ``'0'`` .. ``'9'`` is recorded
    in ``datalist`` as a ``(label, path)`` tuple, grouped by label in ascending
    order. A subset is then taken from both ends of that list (the first and
    the last ``samples_per_end`` entries); only images whose PIL mode equals
    ``mode`` are kept in ``new_dataset``, which is what the dataset serves.
    Entries rejected by the mode check are also removed from ``datalist``.

    NOTE(review): ``datalist`` is grouped by label, so when the folders are
    large the head/tail subset only contains the lowest and highest labels.
    Shuffle or sample per label if all ten digits are needed.

    Args:
        root_dir: Directory that contains the label sub-folders.
        transform: Optional callable applied to each PIL image in
            ``__getitem__``.
        samples_per_end: How many entries to take from each end of
            ``datalist`` (default 10000, i.e. up to 20000 images in total).
        mode: PIL image mode an image must have to be kept (default ``'L'``).

    Raises:
        ValueError: If ``samples_per_end`` is negative.
        OSError: If a label folder is missing (raised by ``os.listdir``).
    """

    # Sub-folder names, one per class; the folder name doubles as the label.
    CLASS_NAMES = tuple(str(digit) for digit in range(10))

    def __init__(self, root_dir, transform=None, samples_per_end=10000, mode='L'):
        super().__init__()
        if samples_per_end < 0:
            raise ValueError('samples_per_end must be non-negative')

        self.transform = transform
        self.root_dir = root_dir
        self.datalist = self._list_samples(root_dir, self.CLASS_NAMES)
        print(len(self.datalist))

        # Keep only the candidates with the requested mode; remember the rest
        # so they can be pruned from datalist in a single pass afterwards.
        self.new_dataset = []
        rejected = set()
        for sample in self._select_candidates(self.datalist, samples_per_end):
            if self._image_mode(sample[1]) == mode:
                self.new_dataset.append(sample)
            else:
                rejected.add(sample)
        if rejected:
            self.datalist = [s for s in self.datalist if s not in rejected]

    @staticmethod
    def _list_samples(root_dir, class_names):
        """Return ``(label, path)`` tuples for every entry of each class folder.

        Entries are listed class by class in the order of ``class_names``; names
        inside a folder are sorted so the result does not depend on the
        arbitrary order returned by ``os.listdir``.
        """
        samples = []
        for label in class_names:
            for img_name in sorted(os.listdir(os.path.join(root_dir, label))):
                samples.append((label, os.path.join(root_dir, label, img_name)))
        return samples

    @staticmethod
    def _select_candidates(samples, samples_per_end):
        """Return the first and the last ``samples_per_end`` items of ``samples``.

        The tail starts after the head, so a short list never yields the same
        item twice, and the very last item is included.
        """
        head = samples[:samples_per_end]
        tail_start = max(len(head), len(samples) - samples_per_end)
        return head + samples[tail_start:]

    @staticmethod
    def _image_mode(img_path):
        """Return the PIL mode of the image at ``img_path``, closing the file."""
        with Image.open(img_path) as img:
            return img.mode

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        ``image`` is the output of ``transform`` when one was given, otherwise
        the loaded PIL image. ``label`` is a scalar integer tensor built from
        the folder name.
        """
        label, img_path = self.new_dataset[item]
        with Image.open(img_path) as img:
            img.load()  # read the pixels now so the file can be closed right away
        if self.transform is not None:
            img = self.transform(img)
        return img, torch.tensor(int(label))

    def __len__(self):
        """Return the number of samples that passed the mode filter."""
        return len(self.new_dataset)

# Preprocessing applied to every image before it is handed to the model.
transform_img = transforms.Compose([
    # A two-element size resizes to exactly 28x28; the aspect ratio is NOT preserved.
    transforms.Resize([28, 28]),
    # Convert the PIL image to a float tensor and rescale 0-255 to 0-1.
    transforms.ToTensor(),
    # Shift 0-1 to -1~1. Mydata only keeps mode 'L' (single-channel) images,
    # so mean/std carry one entry; three entries do not match a 1-channel tensor.
    transforms.Normalize((0.5,), (0.5,)),
])
# Build the dataset and split it 8:2 into training and test subsets.
path = r'C:\Users\33165\PycharmProjects\pythonProject\minst\data\all_data'
mydataset = Mydata(path, transform=transform_img)
num_samples = len(mydataset)
print(num_samples)  # 20000 when every sampled image passes the mode filter
# Derive the split sizes from the real dataset length: random_split raises
# ValueError when the lengths do not sum to len(dataset), which hard-coded
# 16000/4000 would do as soon as a single image is filtered out.
num_train = num_samples * 8 // 10
train_data, test_data = random_split(mydataset, lengths=[num_train, num_samples - num_train])
# Wrap both subsets in DataLoaders that yield shuffled mini-batches of 128
# samples, loaded in the main process (num_workers=0).
train_data = DataLoader(
    dataset=train_data,
    num_workers=0,
    shuffle=True,
    batch_size=128,
)
test_data = DataLoader(
    dataset=test_data,
    num_workers=0,
    shuffle=True,
    batch_size=128,
)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值