# 1 np.random.seed 与 np.random.permutation
# 参考：np.random.seed()函数 - 知乎
# https://zhuanlan.zhihu.com/p/266472620
# 用相同的种子重新 seed 之后，只要后续调用的顺序和参数都一样，每次得到的结果也一样。
import gdal
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
import time
# np.random.seed(0) resets NumPy's global generator to a fixed starting state,
# and every later np.random.* call consumes the next values from that single
# stream. Results therefore repeat only when the sequence of calls made after
# seeding repeats as well: a permutation(10) drawn right after seeding always
# matches, while one drawn after a permutation(4) generally does not.
# Each run below re-seeds, prints one permutation per size (in order), and then
# prints its label (the label marks the END of the run; the last run has none).
_SEED_DEMO_RUNS = (
    ((10, 4), "第一次"),
    ((4, 10, 4), "第二次"),
    ((10, 4), "第三次"),
    ((10, 4), None),
)
for _sizes, _label in _SEED_DEMO_RUNS:
    np.random.seed(0)
    for _size in _sizes:
        print(np.random.permutation(_size))
    if _label is not None:
        print(_label)
# 2 划分训练集和验证集
def get_voc_datagen(train_img_path, train_label_path, val_img_path=None,
                    val_label_path=None, num_class=None, batch_size=None,
                    target_size=None, seed=0, val_ratio=0.2):
    """Collect, shuffle and split paired image/label file paths.

    Images and labels are paired by file stem: ``a.tif`` in an image folder
    must have a label in the matching label folder whose name differs only in
    its extension (e.g. ``a.png``).

    :param train_img_path: folder containing the training images.
    :param train_label_path: folder containing the training labels.
    :param val_img_path: folder containing the validation images. If both
        ``val_img_path`` and ``val_label_path`` are None, the last
        ``val_ratio`` fraction (20% by default) of the shuffled training set
        is used as the validation set instead.
    :param val_label_path: folder containing the validation labels.
    :param num_class: accepted for interface compatibility; not used here.
    :param batch_size: accepted for interface compatibility; not used here.
    :param target_size: accepted for interface compatibility; not used here.
    :param seed: seed for the training-set shuffle (default 0). A private
        ``np.random.RandomState`` is used, so the global NumPy RNG is left
        untouched; the permutation equals the one produced by
        ``np.random.seed(seed); np.random.permutation(n)``.
    :param val_ratio: fraction of the training set held out for validation
        when no validation folders are given (default 0.2).
    :return: ``(train_image_paths, train_label_paths, val_image_paths,
        val_label_paths)``, each a 1-D NumPy array of path strings. Training
        pairs are shuffled; explicitly supplied validation pairs are returned
        in stem-sorted order.
    :raises ValueError: if only one of the two validation folders is given,
        or if an image folder and its label folder do not pair up one-to-one.
    """
    import os

    def _list_dir(folder):
        # os.listdir order is arbitrary; sort by stem so that images and
        # labels with different extensions still end up at the same index.
        names = sorted(os.listdir(folder),
                       key=lambda name: (os.path.splitext(name)[0], name))
        return [os.path.join(folder, name) for name in names]

    def _check_pairs(image_paths, label_paths, split):
        # Every image must have exactly one label with the same file stem.
        if len(image_paths) != len(label_paths):
            raise ValueError(
                f"{split}: {len(image_paths)} images but {len(label_paths)} labels")
        for img, label in zip(image_paths, label_paths):
            img_stem = os.path.splitext(os.path.basename(img))[0]
            label_stem = os.path.splitext(os.path.basename(label))[0]
            if img_stem != label_stem:
                raise ValueError(
                    f"{split}: image {img!r} does not match label {label!r}")

    # Passing only one validation folder used to fall through to
    # os.listdir(None), which silently lists the current directory.
    if (val_img_path is None) != (val_label_path is None):
        raise ValueError(
            "val_img_path and val_label_path must both be given or both be None")

    train_image_paths = np.array(_list_dir(train_img_path))
    train_label_paths = np.array(_list_dir(train_label_path))
    # Validate before shuffling: the same permutation is applied to both
    # arrays, so alignment checked here still holds afterwards.
    _check_pairs(train_image_paths, train_label_paths, "train")

    index = np.random.RandomState(seed).permutation(len(train_image_paths))
    train_image_paths = train_image_paths[index]
    train_label_paths = train_label_paths[index]

    if val_img_path is None:
        # No validation folders: hold out the tail of the shuffled train set.
        train_number = int(len(train_image_paths) * (1 - val_ratio))
        val_image_paths = train_image_paths[train_number:]
        val_label_paths = train_label_paths[train_number:]
        train_image_paths = train_image_paths[:train_number]
        train_label_paths = train_label_paths[:train_number]
    else:
        val_image_paths = np.array(_list_dir(val_img_path))
        val_label_paths = np.array(_list_dir(val_label_path))
        _check_pairs(val_image_paths, val_label_paths, "val")

    return train_image_paths, train_label_paths, val_image_paths, val_label_paths