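The training set is too large to feed into the network all at once given the device's memory limits, so the training data has to be read in batches.
Here train_x holds the file paths of the training images, train_y the training labels, val_X the validation data, and val_y the validation labels.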
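As a rough illustration of the batching idea, here is a minimal sketch of a generator that loads only one batch of images at a time. The names `batch_generator` and `load_fn` are hypothetical and not part of the original code; `load_fn` stands in for whatever maps a file path to an array (for example an image reader). Unlike the author's function, this sketch lets the final batch be shorter rather than padding it to full size.

```python
import math

import numpy as np


def batch_generator(paths, labels, batch_size, load_fn):
    """Yield (X, y) batches forever, loading one batch into memory at a time.

    load_fn is a hypothetical callable that turns one path into a NumPy array.
    """
    n_batches = math.ceil(len(paths) / batch_size)
    while True:  # loop over the data indefinitely, one epoch per pass
        for j in range(n_batches):
            start = j * batch_size
            batch_paths = paths[start:start + batch_size]
            batch_labels = labels[start:start + batch_size]
            # Only the images of the current batch are held in memory.
            X = np.stack([load_fn(p) for p in batch_paths])
            yield X, np.asarray(batch_labels)
```

One pass over the data takes `math.ceil(len(paths) / batch_size)` calls to `next()`, so a training loop that consumes this generator would use that value as its number of steps per epoch.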
The batch-reading function is as follows:
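import math

import cv2
import numpy as np


def dataset_split(images, labels, batch_size):
    # images: list of image file paths; labels: the matching labels
    while True:  # loop forever so the generator can serve many epochs
        i = 0
        n = math.ceil(len(images) / batch_size)  # batches per epoch
        for j in range(n):
            if j != n - 1:
                x = images[i : i + batch_size]
                y = labels[i : i + batch_size]
                i = i + batch_size
                X = []
                for m in range(len(x)):
                    a = cv2.imread(x[m])  # read one image from disk
                    X.append(a)
                X = np.array(X)  # stack the images into one batch array
                yield X, y
            else:
                # last batch: take the final batch_size items
                x = images[len(images)-batch_size: ]
                y = labels[len(labels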