基于pytorch,numpy和opencv对训练图像数据集进行划分,分为训练集和验证集
前言
今晚又是在工位熬夜的一天,没办法,我实在是太菜了,只能熬夜学习了,说起来都是以前自己过得太轻松了,导致我现在不得不使劲补基础。
好了回归正题,为什么我要写这篇博客,最近我在github上找到了一个指纹识别的代码,不过这个代码里面用到了训练集(train),测试集(eval)和验证集(val)。一开始我想测试一下这个代码能不能跑,所以测试集和验证集我用的是一个数据集,结果可以跑啊,不过网络效果就是很不错,毕竟用的测试集就是验证集。所以我现在想以8:2的形式把训练集分为训练集和测试集,没有找到相关的代码,所以今晚就熬夜研究了一下,研究出来了,记录一下这个过程,顺便心疼一下我的头发。
参考
本文感谢以下参考博客:
以及我之前写过的一个博客:
3. 基于python和Opencv将多张图片结合为一张图片的办法
思路分析
先介绍一下我的情况啊,我用的是livdet数据集,这个数据集下面的分类是:传感器-真/假(假的话下面还会分一个材料),大概是这样的啊:
LivDet-LivDet_2009-Training-Biometrika-Alive
然后这个文件夹下面都是图片了,本次测试文件夹为tif文件。
看了参考-1大佬的代码,我也用了这个函数,不过我测试的时候,我发现好像这个东西只能分割数字本身,而不能分割图片,当然也有可能我的代码写的比较垃圾,没试出来,这里只介绍我个人思路。
- 首先读入一个文件夹中的图片,保存在一个numpy数组里面
- 读取数组的长度,并保存其索引
- 将索引以8:2的比例分割一下,并打乱其顺序
- 根据打乱的索引找到对应的图片
- 将图片保存到对应的路径里面
代码分析
需要用到的第三方模块
opencv
numpy
torch
torchvision
glob
导入模块
import torch
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
import glob
import numpy as np
import cv2
定义打开文件夹图片函数
关于这一点,可以看一下参考-3当中的用法
def open_image(path1):
img_path = glob.glob(path1)
return np.array([cv2.imread(true_path,0) for true_path in img_path])
定义根据索引读取图片函数
这里参数的意义是:
images:上一个函数打开过的图片数组,每张图片都在这个大数组里面
list1:训练集的索引,我将其保存在了这个列表里面
list2:验证集的索引,同上
代码相当简单,相信看一眼就明白了,毕竟老夫也不是什么厉害人,写不出漂亮代码
def generate_image(images,list1,list2):
a = len(list1)
b = len(list2)
res1 = []
res2 = []
for i in range(a):
res1.append(images[i])
for j in range(b):
res2.append(images[j])
return res1,res2
主程序内容
关于random_split的使用可以看一下参考-1,这个大佬说的很明白,我一开始看人家的,我以为十分制分割呢,结果我一开始就写的lengths=[8:2],然后疯狂报错,我才发现是写你要分割的具体数量,这里我一共有520张图片,分一下就是416张训练集,104张验证集。
all_data = open_image('LivDet/LivDet_2009/Traning/Biometrika/Alive/*') # 保存图片
num = range(len(all_data)) # 获得图片长度
train_data_num, val_data_num = random_split(dataset=num,
lengths=[416,104]) # 得到打乱过后的索引
print(list(train_data_num)) # 一会给你们看看打印出来的效果
print(list(val_data_num))
output_dir = '/dataset/livdet2009/train/Biometrika/live/' # 设置你要保存的路径
output_dirr = '/dataset/livdet2009/val/Biometrika/live/'
train_data,test_data = generate_image(all_data,list(train_data_num),list(val_data_num)) # 给图片的过程
for i,img in enumerate(train_data):
cv2.imwrite(output_dir+str(i)+'.tif',img) # 保存训练集
for j,imgg in enumerate(test_data):
cv2.imwrite(output_dirr+str(j)+'.tif',imgg) # 保存验证集
给你们看看print那两句话得到的效果:
可以看到啊,这里成功把索引打乱了,可以确保得到的图片具有随机性,使得网络效果更好
完整版代码,方便复制粘贴
import torch
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
import glob
import numpy as np
import cv2
def open_image(path1):
img_path = glob.glob(path1)
return np.array([cv2.imread(true_path,0) for true_path in img_path])
def generate_image(images,list1,list2):
a = len(list1)
b = len(list2)
res1 = []
res2 = []
for i in range(a):
res1.append(images[i])
for j in range(b):
res2.append(images[j])
return res1,res2
all_data = open_image('LivDet/LivDet_2009/Traning/Biometrika/Alive/*') # 保存图片
num = range(len(all_data)) # 获得图片长度
train_data_num, val_data_num = random_split(dataset=num,
lengths=[416,104]) # 得到打乱过后的索引
print(list(train_data_num)) # 一会给你们看看打印出来的效果
print(list(val_data_num))
output_dir = '/dataset/livdet2009/train/Biometrika/live/' # 设置你要保存的路径
output_dirr = '/dataset/livdet2009/val/Biometrika/live/'
train_data,test_data = generate_image(all_data,list(train_data_num),list(val_data_num)) # 给图片的过程
for i,img in enumerate(train_data):
cv2.imwrite(output_dir+str(i)+'.tif',img) # 保存训练集
for j,imgg in enumerate(test_data):
cv2.imwrite(output_dirr+str(j)+'.tif',imgg) # 保存验证集
总结
代码很好懂啊,随便看看就明白了,实在不明白私信我,我给你讲
防火防盗防诈骗