做一个图像分类任务(一)准备数据集

环境准备

推荐按照原作者推荐的环境:代码测试云GPU环境:GPU RTX 3060、CUDA v11.2;虽然本博主没有用。

准备数据集

这里的数据集是在百度上下载的。

对于深度学习的数据集,为了避免数据的分布而影响到模型的训练,数据集应当要尽可能包含目标物体的各类场景,并且各种场景下的图像的数量尽可能相近,这样训练出来的模型才具有很好的泛化性能。数据集应该包含:

  1. 不同尺寸、比例的图像

  2. 不同拍摄环境(光照、设备、拍摄角度、遮挡、远近、大小)

  3. 不同形态(完整西瓜、切瓣西瓜、切块西瓜)

  4. 不同部位(全瓜、瓜皮、瓜瓤、瓜子)

  5. 不同时期(瓜秧、小瓜、大瓜)

  6. 不同背景(人物、菜地、抠图)

  7. 不同图像域(照片、漫画、剪贴画、油画)

这里的数据集存放是在\dataset\子目录,如果需要自建自己的数据集,需要按照同样的数据结果对数据进行处理,否则可能会出现数据导入错误。

爬取数据代码

# coding:gbk
import os
import time
import requests
import urllib3
urllib3.disable_warnings()

# 进度条库
from tqdm import tqdm

cookies = {
'BDqhfp': '%E7%8B%97%E7%8B%97%26%26NaN-1undefined%26%2618880%26%2621',
'BIDUPSID': '06338E0BE23C6ADB52165ACEB972355B',
'PSTM': '1646905430',
'BAIDUID': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
'BDORZ': 'B490B5EBF6F3CD402E515D22BCDA1598',
'H_PS_PSSID': '35836_35105_31254_36024_36005_34584_36142_36120_36032_35993_35984_35319_26350_35723_22160_36061',
'BDSFRCVID': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
'H_BDCLCKID_SF': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
'BDSFRCVID_BFESS': '8--OJexroG0xMovDbuOS5T78igKKHJQTDYLtOwXPsp3LGJLVgaSTEG0PtjcEHMA-2ZlgogKK02OTH6KF_2uxOjjg8UtVJeC6EG0Ptf8g0M5',
'H_BDCLCKID_SF_BFESS': 'tJPqoKtbtDI3fP36qR3KhPt8Kpby2D62aKDs2nopBhcqEIL4QTQM5p5yQ2c7LUvtynT2KJnz3Po8MUbSj4QoDjFjXJ7RJRJbK6vwKJ5s5h5nhMJSb67JDMP0-4F8exry523ioIovQpn0MhQ3DRoWXPIqbN7P-p5Z5mAqKl0MLPbtbb0xXj_0D6bBjHujtT_s2TTKLPK8fCnBDP59MDTjhPrMypomWMT-0bFH_-5L-l5js56SbU5hW5LSQxQ3QhLDQNn7_JjOX-0bVIj6Wl_-etP3yarQhxQxtNRdXInjtpvhHR38MpbobUPUDa59LUvEJgcdot5yBbc8eIna5hjkbfJBQttjQn3hfIkj0DKLtD8bMC-RDjt35n-Wqxobbtof-KOhLTrJaDkWsx7Oy4oTj6DD5lrG0P6RHmb8ht59JROPSU7mhqb_3MvB-fnEbf7r-2TP_R6GBPQtqMbIQft20-DIeMtjBMJaJRCqWR7jWhk2hl72ybCMQlRX5q79atTMfNTJ-qcH0KQpsIJM5-DWbT8EjHCet5DJJn4j_Dv5b-0aKRcY-tT5M-Lf5eT22-usy6Qd2hcH0KLKDh6gb4PhQKuZ5qutLTb4QTbqWKJcKfb1MRjvMPnF-tKZDb-JXtr92nuDal5TtUthSDnTDMRhXfIL04nyKMnitnr9-pnLJpQrh459XP68bTkA5bjZKxtq3mkjbPbDfn02eCKuj6tWj6j0DNRabK6aKC5bL6rJabC3b5CzXU6q2bDeQN3OW4Rq3Irt2M8aQI0WjJ3oyU7k0q0vWtvJWbbvLT7johRTWqR4enjb3MonDh83Mxb4BUrCHRrzWn3O5hvvhKoO3MA-yUKmDloOW-TB5bbPLUQF5l8-sq0x0bOte-bQXH_E5bj2qRCqVIKa3f',
'indexPageSugList': '%5B%22%E7%8B%97%E7%8B%97%22%5D',
'cleanHistoryStatus': '0',
'BAIDUID_BFESS': '104BD58A7C408DABABCAC9E0A1B184B4:FG=1',
'BDRCVFR[dG2JNJb_ajR]': 'mk3SLVN4HKm',
'BDRCVFR[-pGxjrCMryR]': 'mk3SLVN4HKm',
'ab_sr': '1.0.1_Y2YxZDkwMWZkMmY2MzA4MGU0OTNhMzVlNTcwMmM2MWE4YWU4OTc1ZjZmZDM2N2RjYmVkMzFiY2NjNWM4Nzk4NzBlZTliYWU0ZTAyODkzNDA3YzNiMTVjMTllMzQ0MGJlZjAwYzk5MDdjNWM0MzJmMDdhOWNhYTZhMjIwODc5MDMxN2QyMmE1YTFmN2QyY2M1M2VmZDkzMjMyOThiYmNhZA==',
'delPer': '0',
'PSINO': '2',
'BA_HECTOR': '8h24a024042g05alup1h3g0aq0q',
}

headers = {
'Connection': 'keep-alive',
'sec-ch-ua': '" Not;A Brand";v="99", "Google Chrome";v="97", "Chromium";v="97"',
'Accept': 'text/plain, */*; q=0.01',
'X-Requested-With': 'XMLHttpRequest',
'sec-ch-ua-mobile': '?0',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.99 Safari/537.36',
'sec-ch-ua-platform': '"macOS"',
'Sec-Fetch-Site': 'same-origin',
'Sec-Fetch-Mode': 'cors',
'Sec-Fetch-Dest': 'empty',
'Referer': 'https://image.baidu.com/search/index?tn=baiduimage&ipn=r&ct=201326592&cl=2&lm=-1&st=-1&fm=result&fr=&sf=1&fmq=1647837998851_R&pv=&ic=&nc=1&z=&hd=&latest=&copyright=&se=1&showtab=0&fb=0&width=&height=&face=0&istype=2&dyTabStr=MCwzLDIsNiwxLDUsNCw4LDcsOQ%3D%3D&ie=utf-8&sid=&word=%E7%8B%97%E7%8B%97',
'Accept-Language': 'zh-CN,zh;q=0.9',
}

# 爬取图像并且保存到本地中
def craw_single_data(keyword,DOWNLOAD_NUM=200):
    """
    keyword:参数,表示要爬取的数据的关键词;
    DOWNLOAD_NUM:参数,表示要爬取的数据的数量;
    返回值:无;
    爬取的数据会保存在当前文件的同一层目录喜爱的 dataset/keyword 目录中,本函数调用一次可以和生成一个keyword的数据
    """
    if os.path.exists('dataset/'+keyword):
        print('文件夹 dataset/{} 已经存在,之后爬取的数据将直接保存在该文件夹中'.format(keyword))
    else:
        os.makedirs('dataset/{}'.format(keyword))
        print('新建文件夹 dataset/{}'.format(keyword))

    count=1
    with tqdm(total=DOWNLOAD_NUM,position=0,leave=True) as pbar:
        # 爬取第几张
        num=0
        # 是否继续爬取
        FLAG=True

        while FLAG:
            page=30*count
            params=(
                ('tn', 'resultjson_com'),
                ('logid', '12508239107856075440'),
                ('ipn', 'rj'),
                ('ct', '201326592'),
                ('is', ''),
                ('fp', 'result'),
                ('fr', ''),
                ('word', f'{keyword}'),
                ('queryWord', f'{keyword}'),
                ('cl', '2'),
                ('lm', '-1'),
                ('ie', 'utf-8'),
                ('oe', 'utf-8'),
                ('adpicid', ''),
                ('st', '-1'),
                ('z', ''),
                ('ic', ''),
                ('hd', ''),
                ('latest', ''),
                ('copyright', ''),
                ('s', ''),
                ('se', ''),
                ('tab', ''),
                ('width', ''),
                ('height', ''),
                ('face', '0'),
                ('istype', '2'),
                ('qc', ''),
                ('nc', '1'),
                ('expermode', ''),
                ('nojc', ''),
                ('isAsync', ''),
                ('pn', f'{page}'),
                ('rn', '30'),
                ('gsm', '1e'),
                ('1647838001666', ''),
            )
            # 关闭SSL验证
            response = requests.get('https://image.baidu.com/search/acjson', headers=headers, params=params, cookies=cookies,verify=False)
            if response.status_code==200:
                try:
                    json_data=response.json().get("data")

                    if json_data:
                        for x in json_data:
                            type=x.get('type')
                            if type not in ['gif']:
                                img=x.get('thumbURL')
                                fromPageTitleEnc=x.get('fromPageTitleEnc')
                                try:
                                    resp=requests.get(url=img,verify=False)
                                    time.sleep(1)

                                    # 保存到文件夹中
                                    file_save_path=f'dataset/{keyword}/{num}.{type}'
                                    with open(file_save_path,'wb') as f:
                                        f.write(resp.content)
                                        f.flush()
                                        num+=1
                                        # 更新进度条
                                        pbar.update(1)
                                    
                                    if num>DOWNLOAD_NUM:
                                        FLAG=False
                                        print('{}张图像爬取完毕'.format(num))
                                        break
                                except Exception:
                                    pass
                except Exception:
                    pass
            else:
                break
            count+=1



# 调用测试
# class_list=['黄瓜','南瓜']
# for each in class_list:
#     craw_single_data(each,DOWNLOAD_NUM=5)

测试结果:\dataset\
在这里插入图片描述
在这里插入图片描述

删除多余的文件

因为模型训练时是直接导入\dataset\下所有的子文件的,所以如果\dataset\以及它的子文件下有其他非法的文件,会导致训练出错,所以可以检查下 dataset 下有没有其他非法文件已经其子文件内有无非法文件。原作者有给出相应的代码,博主这里不太需要就省略了。

fruit81水果数据图像分类数据集下载

其他数据集请查看github。

点击下载

统计图像的尺寸和比例分布代码

这里可视化图像尺寸分布使用了散点密度图(主要用来计算样本点的出现次数,即密度)。

# coding:gbk
from ctypes.wintypes import PLARGE_INTEGER
from logging import warning
import os
import numpy as np
import cv2
from tqdm import tqdm
import pandas as pd
from scipy.stats import gaussian_kde
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
# %matplotlib inline # plt.show()

def visual_dimensions_and_proportions(dataset_path):
    """
    dataset_path:参数,表示原始数据的路径;
    返回值:无;
    本函数会生成原始数据集的图像宽和高的分布。
    可能遇到一个警告“ibpng warning: iCCP: known incorrect sRGB profile”
    """
    os.chdir(dataset_path)
    list=os.listdir()
    # 输出数据集下的子文件
    print(list)

    # 遍历图像
    df=pd.DataFrame()
    # 遍历每个类别
    # tqdm是快速可扩展的python进度条,可以在python长循环中添加一个进度提示信息,用户只需要封装任意的迭代器
    for fruit in tqdm(list):
        os.chdir(fruit)
        # 遍历每一张图像
        for file in os.listdir():
            try:
                img=cv2.imread(file)
                df=df.append({'类别':fruit,'文件名':file,'图像宽':img.shape[1],'图像高':img.shape[0]},ignore_index=True)
                
            except:
                print(os.path.join(fruit,file),'读取错误')
        os.chdir('../')
    os.chdir('../')
    # 打印信息
    print(df)

    # 可视化图像尺寸分布:散点密度图(主要用来计算样本点的出现次数,即密度)
    x=df['图像宽']
    y=df['图像高']
    # 一组一组(x,y)
    xy=np.vstack([x,y])
    # 通过gaussian_kde估计x和y的联合分布,得到数据点的密度z
    z=gaussian_kde(xy)(xy)
    # 按照密度排序,将密度最大的点排到最后
    idx=z.argsort()
    x,y,z=x[idx],y[idx],z[idx]

    plt.figure(figsize=(10,10))
    # s绘制点的大小,c绘制点的颜色,默认是蓝色,marker表示标记的样式,camp是浮点数数组时候使用
    plt.scatter(x,y,c=z,s=5,cmap='Spectral_r')
    plt.tick_params(labelsize=15)

    xy_max=max(max(df['图像宽']),max(df['图像高']))
    plt.xlim(xmin=0,xmax=xy_max)
    plt.ylim(ymin=0,ymax=xy_max)

    plt.xlabel('width',fontsize=25)
    plt.ylabel('height',fontsize=25)

    # 保存
    plt.savefig('图像尺寸分布.pdf',dpi=120,bbox_inches='tight')
    # 显示
    plt.show()

# 测试用例
# 指定数据集路径
# dataset_path='fruit81_full'
# visual_dimensions_and_proportions(dataset_path)

运行结果:
在这里插入图片描述

划分训练集和数据集代码

# coding:gbk
from email.mime import image
from multiprocessing.spawn import old_main_modules
import os
import shutil
import random
import pandas as pd


def split_data(dataset_path,dataset_name='dataset',test_frac=0.2):
    """
    dataset_path:需要划分训练集和测试集的原始数据路径;
    dataset_name:划分数据后的数据集命名;
    test_frac:测试集的比例;
    返回值:无;
    本函数用来分割训练集和测试集。
    """
    # 随机种子,便于复现
    random.seed(123)

    # 创建各类目录
    li=os.listdir(dataset_path)
    # 创建train目录
    os.mkdir(os.path.join(dataset_path,'train'))
    # 常见test目录
    os.mkdir(os.path.join(dataset_path,'test'))
    # 在训练集和测试集中分别创建各类的子文件夹
    for fruit in li:
        os.mkdir(os.path.join(dataset_path,'train',fruit))
        os.mkdir(os.path.join(dataset_path,'test',fruit))

    # 划分数据集
    df=pd.DataFrame()
    print('{:^18} {:^18} {:^18}'.format('类别','训练集数据个数','测试集数据个数'))
    for fruit in li:
        # 读取该类别下的所有图像文件名
        old_dir=os.path.join(dataset_path,fruit)
        images_filename=os.listdir(old_dir)
        # 随机打乱
        random.shuffle(images_filename)
        # 划分训练集和测试集
        # 测试集的数量
        testset_numer=int(len(images_filename)*test_frac)
        # 需要移动到测试集的图像名字
        testset_images=images_filename[:testset_numer]
        # 需要移动到测试集的图像名字
        trainset_images=images_filename[testset_numer:]

        # 移动数据集到test
        for image in testset_images:
            old_img_path=os.path.join(dataset_path,fruit,image)
            new_img_path=os.path.join(dataset_path,'test',fruit,image)
            shutil.move(old_img_path,new_img_path)

        # 移动数据集到train
        for image in trainset_images:
            old_img_path=os.path.join(dataset_path,fruit,image)
            new_img_path=os.path.join(dataset_path,'train',fruit,image)
            shutil.move(old_img_path,new_img_path)

        # 删除旧目录
        # 确保旧文件夹中的所有图像均被移走
        assert len(os.listdir(old_dir))==0
        # 删除文件夹
        shutil.rmtree(old_dir)

        # 输出每一个类别中数据的个数
        print('{:^18} {:^18} {:^18}'.format(fruit,len(trainset_images),len(testset_images)))
        # 保存到表格中
        df=df.append({'class':fruit,'trainset':len(trainset_images),'testset':len(testset_images)},ignore_index=True)

    # 重命名
    shutil.move(dataset_path,dataset_name)
    df['total']=df['trainset']+df['testset']
    df.to_csv('数据量统计.csv',index=False)

# 简单测试
# 指定数据集路径
dataset_path='fruit81_full'
dataset_name='dataset'
# 测试集的比例
test_frac=0.2
split_data(dataset_path,dataset_name,test_frac)
    

运行结果:
在这里插入图片描述

可视化图像代码

# coding:gbk
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
import math
import os
import cv2


def visual_dataset(path,N=36):
    """
    path:需要可视化的文件路径;
    N:可视化的图像的数量;
    本函数用来对图像进行可视化。
    """
    # n行n列
    n=math.floor(np.sqrt(N))

    # 读取文件夹中所有图像
    images=[]
    for img in os.listdir(path)[:N]:
        img_path=os.path.join(path,img)
        # 解决中文路径的问题
        img_bgr = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8),-1)
        # cv2读取的顺序为bgr,需要转换为正常的rgb
        img_rgb=cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
        images.append(img_rgb)

    # 画图
    fig=plt.figure(figsize=(10,10))
    grid=ImageGrid(fig,111,# 绘制子图
                    nrows_ncols=(n,n),# 创建n行n列的axes网格
                    axes_pad=0.02,# 网格间距
                    share_all=True)

    # 遍历每张图像
    for ax,im in zip(grid,images):
        ax.imshow(im)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# 简单测试
path='dataset/train'
li=os.listdir(path)
for item in li:
    file_path=os.path.join(path,item)
    l=len(os.listdir(file_path))
    N=36 if 36<l else l
    visual_dataset(file_path,N)

运行结果(诸如此类):
在这里插入图片描述

统计各类别的数量

# coding:gbk
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt


def visual_quantity(path):
    """
    path:需要显示数据分布的表;
    本函数用来显示数据的分布。
    """
    # 中文字体
    matplotlib.rc('font',family='SimHei')
    # 用来正常显示负号
    plt.rcParams['axes.unicode_minus']=False

    # 导入表格
    df=pd.read_csv(path)
    print(df.shape)
    # 指定可视化的特征
    # feature='total'
    # df=df.sort_values(by=feature,ascending=False)

    plt.figure(figsize=(22,7))
    x=df['class']
    y1=df['testset']
    y2=df['trainset']
    # 柱状图宽度
    width=0.55
    # 横轴文字旋转
    plt.xticks(rotation=90)

    plt.bar(x,y1,width,label='测试集')
    plt.bar(x,y2,width,label='训练集',bottom=y1)

    plt.xlabel('类别',fontsize=20)
    plt.ylabel('图像数量',fontsize=20)
    # 设置坐标文字大小
    plt.tick_params(labelsize=13)
    # 图例
    plt.legend(fontsize=16)
    # 保存
    plt.savefig('各类别图像数量.pdf',dpi=120,bbox_inches='tight')
    plt.show()

# 简单测试
path='数据量统计.csv'
visual_quantity(path)

运行结果:
在这里插入图片描述

这里的代码与github上可能有一丢丢不同,大家想看全面点还是要看github上原版的代码以及原作者的视频讲解,博主这里只是博主按需学习的一些经验哦!
声明:这里只是做一个图像分类任务,不作为商业用途,不涉及任何利益交易,甚至不作为毕业设计。侵权可以删哦!

  • 14
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值