python制作自定义数据集(cifar+txt版)

第一种(CIFAR)

readCifar.py

这里主要是拿制作的样本格式跟原样本格式比对,输出一下

import pickle
import numpy as np
import chardet


def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='latin-1')
    return dict


cc1 = unpickle("E:/pythonwork/makemycifar/img2array.bin")
print(cc1)
cc2 = unpickle("E:/pythonwork/makemycifar/data_batch_0")
print(cc2)

第一个字典{}里是最终完善后的,第二个是原格式:
在这里插入图片描述
可以看到,数据集由字典dic构成,格式为
data(数据,用矩阵表示)、labels(类标签)、filenames(图片名称)、batch_label(整个模块标签)

cc1为我改写的样本,cc2为cifar数据集的格式。经过改写,下方原样本格式是基本一样,可行。

dump.py

这里是获得四个标签

from PIL import Image
import numpy as np
import pickle, glob, os
# 加载标签
def seplabel(fname):
    filestr = fname.split(".")[0]
    label = int(filestr.split("_")[0])
    return label
    
arr = [[]]
n = 1

for infile in glob.glob("E:/pythonwork/makemycifar/zuichutup/*.png"):
    file, ext = os.path.splitext(infile)
    Img = Image.open(infile)
    print(Img.mode, file)  # 图片尺寸和文件名(用于调试过程中定位错误)

    if Img.mode != 'RGB':
        Img = Img.convert('RGB')

    width = Img.size[0]
    height = Img.size[1]
    # print('{} imagesize is:{} X {}'.format(n, width, height))
    n += 1  # zzz

    if width > height:
        tmp = int((width - height) / 2)
        Img = Img.crop([width - height - tmp, 0, width - tmp, height])
    if height > width:
        tmp = int((height - width) / 2)
        Img = Img.crop([0, height - width - tmp, width, height - tmp])
    size = 32, 32
    Img.thumbnail(size, Image.ANTIALIAS)

    r, g, b = Img.split()
    r_array = np.array(r).reshape([1024])
    g_array = np.array(g).reshape([1024])
    b_array = np.array(b).reshape([1024])
    merge_array = np.concatenate((r_array, g_array, b_array))
    if arr == [[]]:
        arr = [merge_array]
        continue  # 拼接
    arr = np.concatenate((arr, [merge_array]), axis=0)  # 打乱顺序
    # 定义标签和文件名
labels = []
filenames = []

imglist = glob.glob(r"E:/pythonwork/makemycifar/zuichutup/*.png")  # 将图片存入列表
print(imglist[0])

num = len(imglist)  # 获取路径下图片的长度 11
for i in range(0, num):
    imglist2 = os.path.basename(imglist[i])  # 获取路径下图片名称

    labels.append(seplabel(imglist2))
    print("aaa")
    filenames.append(imglist2)
    print(imglist2)
# 加入字典
dic = {'data': arr, 'labels': labels, 'filenames': filenames, 'batch_label': 'testing batch 1 of 1'}

f = open('E:/pythonwork/makemycifar/img2array.bin', 'wb')
pickle.dump(dic, f)

为了保证数据集无纰漏,我一步步测试,把数据一点点抠出来比对,发现没问题后,装入list,最后放入dic字典中。
在这里插入图片描述

第二种(CIFAR)

也是cifar格式的制作过程参考:参考(点←)。
在这里插入图片描述
第一个跟第三个
这里放的原图,运行后自动处理成小图了

在这里插入图片描述
这里放原格式,最后就是会直接修改他的原格式
在这里插入图片描述
这里放原图
在这里插入图片描述
这里放原格式。
然后是代码:

# -*- coding: utf-8 -*-
import numpy as np
from numpy import *
from PIL import Image
# import operator
from os import listdir
# import sys
import pickle
import time

# import random
data = {}
list1 = []
list2 = []
list3 = []
list4 = []


def seplabel(fname):
    filestr = fname.split(".")[0]
    label = int(filestr.split("_")[0])
    return label


def img_tra():
    for k in range(0, num):
        currentpath = folder_init + "/" + imglist[k]
        im = Image.open(currentpath)
        # width=im.size[0]
        # height=im.size[1]
        x_s = 32
        y_s = 32
        out = im.resize((x_s, y_s), Image.ANTIALIAS)
        print('sss')
        print(str(imglist[k]))
        print('sss')
        out.save(folder_changed + "/" + str(imglist[k]))

def addWord(theIndex, word, adder):
    theIndex.setdefault(word, []).append(adder)


def mkcf():
    global data
    global list1
    global list2
    global list3
    global list4
    for k in range(0, num):
        currentpath = folder_changed + "/" + imglist[k]
        im = Image.open(currentpath)
        with open(binpath, 'a') as f:
            for i in range(0, 32):
                for j in range(0, 32):
                    cl = im.getpixel((i, j))
                    # print(imglist[k])
                    # print(type(cl[0]))
                    # with open(binpath, 'a') as f:
                    # print(str(cl[0]))
                    list1.append(cl[0])
                    # print(list1)

            for i in range(0, 32):
                for j in range(0, 32):
                    cl = im.getpixel((i, j))
                    list1.append(cl[1])

            for i in range(0, 32):
                for j in range(0, 32):
                    cl = im.getpixel((i, j))
                    list1.append(cl[2])
        list2.append(list1)
        list1 = []

        f.close()
        print("image" + str(k + 1) + "saved.")
        lfilenames.append(imglist[k].encode('utf-8'))
    arr22 = np.array(list2, dtype=np.uint8)

    print(lj1.shape)

    arr2 = np.concatenate((lj1, arr22))
    data['batch_label'.encode('utf-8')] = 'training batch 5 of 5'.encode('utf-8')

    data.setdefault('labels'.encode('utf-8'), llabel)
    data.setdefault('data'.encode('utf-8'), arr2)

    data.setdefault('filenames'.encode('utf-8'), lfilenames)
    # addWord(cifar10,'filenames'.encode('utf-8'),list3)
    output = open(binpath, 'wb')
    pickle.dump(data, output)
    output.close()


def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='iso-8859-1')
    return dict



cc = unpickle("E:/pythonwork/finaldata/yuanshi/data_batch_5")

total = 10000
lj1 = cc['data']


def delpar():
    global lj1
    del cc['labels'][i - n]
    del cc['filenames'][i - n]
    lj1 = np.delete(lj1, i - n, 0)


lj = []
n = 0
for i in range(0, total):
    if cc['labels'][i - n] == 9:
        delpar()
        n = n + 1
llabel = cc['labels']
lfilenames = cc['filenames']

folder_init = "E:/pythonwork/finaldata/init"
folder_changed = "E:/pythonwork/finaldata/data_batch_5"
print("****")
print(folder_changed)
imglist = listdir(folder_changed)
print("****")
print(imglist)
print("****")
num = len(imglist)
print("****")
print(num)
print("****")
img_tra()  #这个方法目录下的是把你的图片处理了再保存,所以直接把图片放到他的目录下就可以了
# label=[]
for i in range(0, num):
    lj = seplabel(imglist[i])
    llabel.append(lj)
    # print(label)
    # binpath = "F:/tiaotong/get/data_batch_5"
binpath = "E:/pythonwork/finaldata/get/data_batch_5"

mkcf()

在这里插入图片描述
在这里插入图片描述
看更新时间,看来成功了。
在这里插入图片描述
就是不知道为什么多个’b’

第三种(txt格式,个人常用)

#!/usr/bin/env python
# coding: utf-8
import os, random, shutil
from PIL import Image
import numpy as np
from torchvision import datasets, transforms
import torchvision
import torch

# 创建Map,这个Map的Key是图片的路径,而他的Value就是Label
dataMap = {}
base_path = os.getcwd()
# phonepicture目录下有若干个,表示类别
categories = os.listdir(r'E:/pythonwork/phonepicture')

print(categories)
rawdata_path = os.path.join(base_path, r"E:/pythonwork/phonepicture")
print(rawdata_path)
for c in categories:
    # a_image_folder里面存的是当前文件夹的路径,如果再加上图片名就是该图片的绝对路径了,而这正是我们想要的。
    a_image_folder = os.path.join(rawdata_path, c)
    # 某一类别的图片名字(不是绝对路径)都在image_files里面了,而c就是他的类别名
    image_files = os.listdir(a_image_folder)
    for image in image_files:
        dataMap[a_image_folder + "\\" + image] = c

# 将其存储为Txt文件
with open(r"E:/pythonwork/phonetype.txt", 'w') as f:
    for k, v in dataMap.items():
        f.write(k + ' ' + v + '\n')

# 读取Txt文件
with open(r"E:/pythonwork/phonetype.txt", 'r') as f:
    theList = f.readlines()
print(theList[1].split()[0])
print(theList[1].split()[1])


class MyDataset(torch.utils.data.Dataset):  # 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self, root, transform=None, target_transform=None):
        super(MyDataset, self).__init__()
        # 按照传入的路径和txt文本参数,打开这个文本,并读取内容
        fh = open(root, 'r')
        # 创建List
        self.dataList = fh.readlines()
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # 这个Index应该就是你给的List的len,他都有了嘛,所以他可以随机取或者怎么的
        # 但是这里的getitem真的就是只取了一个!item,这里就是 一行 (绝对路径加Label)
        path_label = self.dataList[index]
        path, label = path_label.split()[0], path_label.split()[1]
        #         print(path_label)
        #         print(path)
        #         print(label)

        img = Image.open(path)
        # img.show()
        img = transforms.ToTensor()(img)
        return img, label

    def __len__(self):  # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.dataList)


myData = MyDataset(r"E:/pythonwork/phonetype.txt")
# 注意Dataset和Dataloader的区别
train_dataloader = torch.utils.data.DataLoader(myData, batch_size=1, shuffle=True)
image, label = iter(train_dataloader).next()
torchvision.transforms.ToPILImage()(image.squeeze(0)).show()
print(label)

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sinysama

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值