目前网上有很多公布的代码,但是,我们如何使用自己的数据对其训练呢,可以使用以下的代码试试改程序的输入。如果有错误,欢迎批评指正,谢谢。
from glob import glob
import random
import numpy as np
import os
batch_size = 8
step = 100 #程序运行的总代数
def get_batch(batch_filename):
#eg .对于二分类程序,标签的形式是[0,1],[1,0]的形式
# 对于分割程序,标签的形式和图片的形式相同,本程序就是针对这部分的
batch_array = []
batch_label = []
for npy in batch_filename:
try:
arr = np.load(npy)
arr=arr[:,:,:,np.newaxis]
arr_mask=np.load(npy.replace('img', 'mask'))
arr_mask = arr_mask[:, :, :, np.newaxis] #根据自己的需要更改就行
batch_array.append(arr)
batch_label.append(arr_mask)
except Exception as e:
print("file not exists! %s"%npy)
batch_array.append(batch_array[-1])
return np.array(batch_array),np.array(batch_label)
#加载数据
root = '数据的路径'
every_file = glob(os.path.join(root, "*_img.npy"))
times = int(len(every_file)/batch_size)
for i in range(1, step):
random.shuffle(every_file)
for t in range(times): # 每代中batchsize运行的次数
batch_files = every_file[t * batch_size:(t + 1) * batch_size]
img, mask = get_batch(batch_files)