CNN数据集——自己建立数据集要点

前言

新学cnn,尝试建立数据集。这是过了几天的描述,时间再长估计就要忘记了。
参考文章Pytorch学习(三)定义自己的数据集及加载训练
我是建立图片数据集,图片为四位验证码。
图片已经准备好,图片名称为图中的四位验证码。

0S22.jpg
图片名称为0S22.jpg

1.要点总结

1. 将训练集、测试集、验证集的图片放在三个文件夹中,尽量保证这三个文件夹不在移动

比例我忘记了,大概是6:2:2吧,以后查到了我再改

2.将图片切割,生成真正的训练集

由于四位验证码的可能性太大,节点太多(26+10)**4 = 1679616个节点,仅最后一层就这么多节点,所以将图片切割,生成单个字母的photo_cut,这样训练最后一层仅36个节点,大大减少计算量,最后检测四位验证码也就0.2秒(循环检测的,以后有时间学一下进程,直接四个进程检测)

将test_photo的图片保存在test_cut_photo中

文件夹需要提前建立,要不会报错


from PIL import Image
import os

def cut_photo(file_name_from, name, save_name):
    """
    输入图片名称,按固定位置将图片分割为四部分,以便更好地识别字母/数字
    :param name: 图片中字符名称
    :return:
    """
    img = Image.open(str(file_name_from)+"/" + str(name) + ".jpg")
    if img.mode == "P":
        img = img.convert('RGB')
    # name = '1D61'
    list_x_start = [2, 18, 30, 44]
    list_x_end = [16, 32, 44, 58]

    for i in range(4):
        x_start = list_x_start[i]
        x_end = list_x_end[i]
        y_start = 4
        y_end = 26
        box = [x_start, y_start, x_end, y_end]
        # print(box)
        # box1 = (0, 0, 16, 28)
        # box2 = (16, 0, 32, 28)
        # box3 = (32, 0, 46, 28)
        # box4 = (46, 0, 60, 28)
        region = img.crop(box)
        # region.show()
        region.resize((22, 22))
        file_name = str(save_name)+'/' + str(name) + str(i) + '_' + str(name[i]) + '.jpg'
        # print(file_name)

        region.save(file_name)


def get_photo_name(file_name_from):
    """
    获取photo文件下图片的名字
    :return:图片名称列表
    """
    photo_name_chr = os.listdir('./' + str(file_name_from))
    list_photo_name = []


    return photo_name_chr


if __name__ == '__main__':

    list_photo_name = get_photo_name('test_photo')
    for name in list_photo_name:
        cut_photo('test_photo', name[0:4], 'test_cut_photo')

直接将切割后的图片保存,不需要返回
切割后的图片名称保存格式[四位验证码]+[切割字符在的位置]+[切割字符的名称]

3.测试集、验证集同理

可以建立列表一次循环解决。

4.将切割好的图片数据集建立文件,包含图片位置和标签两个信息。

一开始学习时pytorch官方文档给的数据集是cv文件,我以为只能识别cv文件,其实什么文件都可以,我就是用的txt文件,excel应该也可以
注意一点,图片位置应该和标签在同一行,中间加个空格或者’,'逗号随意。

D:/python/report_new_2233/photo_cut/00163_6.jpg  6
D:/python/report_new_2233/photo_cut/00170_0.jpg  0
D:/python/report_new_2233/photo_cut/00171_0.jpg  0

由于有三个训练集、测试集、验证集,所以需要三个txt文件。

import os
from PIL import Image
from photo_cut import *


def make_name(file_name_from, img_file_txt):
    """
    生成图片、标签文件对应的文件
    :return:
    """
    photo_name_chr = os.listdir('./' + str(file_name_from))
    # print(photo_name_chr)
    with open(img_file_txt, 'w') as fp:
        for i in photo_name_chr:
            root_name = 'D:/python/report_new_2233/' + str(file_name_from) + '/' + str(i)
            fp.write(root_name)
            fp.write('  ')
            fp.write(i[-5])
            fp.write('\n')


def show_photo():
    """
    查看图片路径是否正确
    :return:
    """
    with open('name.txt', 'r') as fp2:
        a = fp2.readline().split()
        # print(a[0])
        img = Image.open(a[0])
        img.show()

保存txt文件后一定要读一行测试一下文件路径是否正确,我当时忘记了’/’

此时数据集建立完成,但教程还没完,因为pytorch读取也很难受

5.需要建立一个类,读取数据集txt文件

类中包含三中方法(隐藏方法,是叫这个吧)
主要是第三个方法__getitem__(self, index),他负责读取图片地址和标签,返回图片张量和标签张量。

6.重点事项

!!!注意大坑,返回图片格式为[通道数, 长或宽,长或宽]如果用Image.open方法读取的图片维度为[ 长或宽,长或宽,通道数],是不能卷积

想转换维度,要么直接Tensor转换,要么转置

class Mydataset(Dataset):
    """
    TypeError: show() takes 1 positional argument but 2 were given
    现在直接读取6张图片,为什么

    """

    def __init__(self, txt_name, train=True, transform=None, target_tranform=None, loader=default_loader, identify=None):
        super(Mydataset, self).__init__()
        # self.img_label =
        line_list = []
        if train:
            file_name = r'D:/python/pytorch_learn/install/name.txt'
        else:
            file_name = r'D:/python/pytorch_learn/install/test_ph.txt'
        if identify:
            file_name = r'D:/python/pytorch_learn/install/CESHI.TXT'
            print('这是验证集')
        with open(file_name, 'r') as fp:
            lines = fp.readlines()
            for i in lines:
                line = i.split()
                line_list.append(line)

        self.img = line_list
        self.transform = transform
        self.target_transform = target_tranform
        self.loader = loader

    def __len__(self):
        return len(self.img)

    def __getitem__(self, index):
        path, label = self.img[index]

        img = self.loader(path)  # 读取出来维度为3, 22, 14 ?? 无法正常显示图片
        # img = Image.open(path)
        img = np.array(img, dtype=np.float32)
        # # img = np.array(img, dtype=np.float32).reshape(14, 22, 3)

        # img = read_image(path)  # 读取数据不需要transform转换
        # print(img.size())  # torch.Size([3, 22, 14]) ————需要转至
        # img = np.array(img)  # 924
        # print(img.size)
        label_one_hot = one_hot(label)  # 不是独热编码,只是字符都变成了数字

        if self.transform:
            # img = torch.from_numpy(img)  # 所以numpy转Tensor唯独不变
            # print(img.size())
            img = self.transform(img)  # Totensor维度改变 torch.Size([3, 22, 14])

            # print(img.size())
            # img = img.squeeze(0)  # 只能去除维度=1的维度
            # img = img.T  # 现在可以正常显示,但是图片完全翻转了
            # print(img.size())
            # 转置前维度为(3, 22, 14), 转置后torch.Size([14, 22, 3])
            # print(label, img)
        return img, label_one_hot, path

if __name__ == "__main__":
    # mydata = Mydataset('name.txt')
    # print(mydata.__getitem__(5))

    train_data = Mydataset(
        txt_name='name.txt',  # 暂时未使用
        train=True,
        transform=ToTensor()
    )
    # data_train = DataLoader(train_data, batch_size=64)
    img, label, path = train_data[6]
    photo = Image.open(path)
    plt.title(label)
    plt.imshow(photo)
    plt.show()

返回三个变量的原因

正常情况下 __getitem__应该只返回图片张量和标签,但是我的图片经过Tensor转换后无法正常显示,因为图片Tensor转换后维度为[通道数, 长或宽,长或宽],正常图片显示维度为[ 长或宽,长或宽,通道数]

2021.10.22补充
写rnn分类图片的时候突然发现的代码,可以直接显示ToTensor转换后的图片
利用dataset的数据集提取可以做到

	train_data = torchvision.datasets.FashionMNIST(
        root='D\python\pytorch_learn\install\data',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=False,
    )
	plt.imshow(train_data.data[0].numpy(), cmap='gray')
    plt.show()

这个代码可以将显示Tensor的维度的代码

torch.Size([1, 28, 28])
# 一般读取图片的格式为[ 28, 28, 1]即通道数在尺寸后面

在这里插入图片描述

图片数据必须为float格式

直接读取的图片像素为整数,需要转换为np的同时调整为float再tranform要不然无法正常读取。
这一点最坑的是他不显示问题,就是预测不出来。不管什么图片,预测结果是’RRRR’。就是找不到问题。

plot显示图片数据,需要先imshow再show,直接show无法正常显示。

这点我还未找到原因,有时间再说。
或者也可以使用

plt.ion()

打开自动显示图像模式(我自己还没试)

也可以参考这篇文章链接: matplotlib中plot.show()不显示图片的问题.

在这里插入图片描述

pytorch自己会转换独热编码,你不能输入独热编码,不然会报错

输入必须为数字,不能是字符‘a’

但是也不能输入字符,必须将输入转换为数字,因为’a’无法进行矩阵运算,也无法转换为独热编码,需要你自己将字符转换为数字,怎么转化随意,比如我将’a’作为11,这样0-9,a-z转换为0-36。

有多少种结果转换多少个数字

一开始我想讲字符转为ASCii码,毕竟方便,z转换为96,此时生成的独热编码维度为[36, 96],大大浪费了计算资源。
z转换为36,独热编码维度为[36, 36],就很舒服。

  • 7
    点赞
  • 66
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值