Example of Using Sequence in Keras

While training a TF model I found that my machine did not have enough RAM to hold ten thousand 720p images (a VOC-format dataset) at once, so I switched to a Sequence for the training iteration, which effectively reduces the memory requirement.
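For background, a Sequence subclass only has to implement __len__ (batches per epoch) and __getitem__ (one batch by index); Keras then pulls batches one at a time instead of materializing the whole dataset in memory. A minimal sketch of that contract (the class name here is just an illustration):

import numpy as np
from tensorflow.keras.utils import Sequence

class MinimalSequence(Sequence):
    '''Bare-bones Sequence: Keras only ever asks for len() and one batch at a time.'''
    def __init__(self, x, y, batch_size=32):
        self.x, self.y, self.batch_size = x, y, batch_size
    def __len__(self):
        # number of batches per epoch
        return int(np.ceil(len(self.x) / self.batch_size))
    def __getitem__(self, idx):
        # slice out batch idx; only this slice needs to be in memory
        s = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        return self.x[s], self.y[s]

The full generator below builds on this contract: it reads VOC XML annotations, loads and letterboxes the matching images, and one-hot encodes the labels. The code is as follows: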

from tensorflow.python.keras.utils.data_utils import Sequence
# defined in C:\ProgramData\Anaconda3\envs\tf\Lib\site-packages\tensorflow_core\python\keras\utils\data_utils.py
import random, os, gc, cv2
import numpy as np
from xml.dom.minidom import parse
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import img_to_array, load_img
seed = 295
random.seed(seed)

class SequenceData(Sequence):
    '''
    xmlPath, imgPath, cutSize=(0,0.7), batch_size=32, size=(720,1280)
    XML directory, image directory, (start, end) split ratios for this subset, batch size, image size
    '''
    def resize_img_keep_ratio(self, img_name,target_size):
        '''
        1. Resize the image: compute the resize ratio from the longest side, then resize by that ratio.
        2. Compute the padding width needed on each of the four sides, then pad.
        '''
        img = cv2.imread(img_name)
        old_size = img.shape[0:2]
        ratio = min(float(target_size[i])/(old_size[i]) for i in range(len(old_size)))
        new_size = tuple([int(i*ratio) for i in old_size])
        img = cv2.resize(img,(new_size[1], new_size[0]),interpolation=cv2.INTER_CUBIC)  # note the interpolation algorithm
        pad_w = target_size[1] - new_size[1]
        pad_h = target_size[0] - new_size[0]
        top,bottom = pad_h//2, pad_h-(pad_h//2)
        left,right = pad_w//2, pad_w -(pad_w//2)
        img_new = cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,None,(0,0,0))
        return cv2.cvtColor(img_new, cv2.COLOR_BGR2RGB)
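    # Worked example (hypothetical input sizes) for target_size=(720,1280):
    #   a 1080x1920 frame: ratio = min(720/1080, 1280/1920) = 2/3,
    #     new_size = (720,1280), no padding needed;
    #   a 1080x1440 frame: ratio = min(720/1080, 1280/1440) = 2/3,
    #     new_size = (720,960), pad_w = 320, so 160 px of black on the left and right.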
    def resize_img(self, img_name,target_size):
        img = cv2.imread(img_name)
        img = cv2.resize(img,(target_size[1], target_size[0]),interpolation=cv2.INTER_CUBIC)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def __init__(self, xmlPath, imgPath, cutSize=(0,0.7), batch_size=32, size=(720,1280)):
        self.xmlPath = xmlPath
        self.imgPath = imgPath
        self.batch_size = batch_size
        self.cutSize = cutSize
        self.datas = os.listdir(xmlPath) # read the annotation file list
        random.shuffle(self.datas)
        self.L = len(self.datas)
        self.datas = self.datas[int(np.ceil(self.L*cutSize[0])):int(np.ceil(self.L*cutSize[1]))] # slice this subset out by the split ratios
        self.L = len(self.datas)
        self.size = size
        self.index = random.sample(range(self.L), self.L) # shuffled permutation of sample indices
        self.names = ['asleep', 'side asleep', 'quilt kicked', 'awake',
                        'on stomach', 'crying', 'face covered']  # 7 classes
    # number of batches per epoch; invoked via len(<your instance>)
    def __len__(self):
        return int(np.ceil(len(self.datas) / self.batch_size))
    # fetch batch idx, i.e. a[0], a[1], ... (called by Keras during training)
    def __getitem__(self, idx):
        # idx is a batch index, so slice self.index in steps of batch_size;
        # slicing self.index[idx:(idx+self.batch_size)] would yield overlapping batches
        batch_indexs = self.index[idx*self.batch_size:(idx+1)*self.batch_size]
        batch_datas = [self.datas[k] for k in batch_indexs]
        images,labels = self.data_generation(batch_datas)
        return images,labels

    def data_generation(self, batch_datas):
        # preprocessing
        data=[] # class labels
        img=[]  # images, in batch order
        for file in batch_datas:
            DOMTree = parse(os.path.join(self.xmlPath,file)) # parse the XML annotation
            imgName = file.replace('xml','jpg') # the <filename> tags are unreliable, so derive the image name from the XML file name
            if os.path.exists(os.path.join(self.imgPath,imgName)): # if the image file exists, load and resize it
                img.append(self.resize_img_keep_ratio(os.path.join(self.imgPath,imgName),(self.size[0],self.size[1])))
                name = []
                # for obj in DOMTree.documentElement.getElementsByTagName("object"): # one entry per <object> tag (made the loss explode on a multi-label dataset)
                #     name.append(obj.getElementsByTagName("name")[0].childNodes[0].data)
                name.append(DOMTree.documentElement.getElementsByTagName("object")[0].getElementsByTagName("name")[0].childNodes[0].data) # take only the first <object> label
                data.append(name)
            else:
                img.append(np.zeros((self.size[0],self.size[1],3)).tolist())
                data.append([])
        return np.asarray(img)/255, np.asarray(MultiLabelBinarizer(classes=self.names).fit_transform(data)) # multi-class one-hot encoding
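    # Worked example (hypothetical data) of the one-hot encoding above:
    # with classes=self.names, MultiLabelBinarizer maps
    #   [['asleep'], ['crying'], []]
    # to
    #   [[1,0,0,0,0,0,0], [0,0,0,0,0,1,0], [0,0,0,0,0,0,0]]
    # i.e. one column per class name; samples with a missing image get an all-zero row.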
    def get_label(self):
        # preprocessing
        data=[] # class labels
        for file in [self.datas[k] for k in self.index]:
            DOMTree = parse(os.path.join(self.xmlPath,file)) # parse the XML annotation
            imgName = file.replace('xml','jpg') # the <filename> tags are unreliable, so derive the image name from the XML file name
            if os.path.exists(os.path.join(self.imgPath,imgName)): # if the matching image file exists, keep its label
                name = []
                # for obj in DOMTree.documentElement.getElementsByTagName("object"): # one entry per <object> tag (made the loss explode on a multi-label dataset)
                #     name.append(obj.getElementsByTagName("name")[0].childNodes[0].data)
                name.append(DOMTree.documentElement.getElementsByTagName("object")[0].getElementsByTagName("name")[0].childNodes[0].data) # take only the first <object> label
                data.append(name)
            else:
                data.append([])
        return np.asarray(MultiLabelBinarizer(classes=self.names).fit_transform(data)) # multi-class one-hot encoding
    def showImg(self,i):
        from tensorflow.keras.preprocessing.image import array_to_img
        images,labels = self.data_generation([self.datas[i]])
        for ii,xx in enumerate(labels[0]):
            if xx > 0:
                print(self.names[ii],end=',')
        print()
        array_to_img(images[0]*255.0).show()
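Before wiring the generator into training, it can be worth sanity-checking one batch; a minimal sketch (the directory paths below are placeholders for your own VOC dataset):

# placeholders: point these at your Annotations/ and JPEGImages/ directories
seq = SequenceData('data/Annotations', 'data/JPEGImages',
                   cutSize=(0, 0.7), batch_size=4, size=(720, 1280))
print(len(seq))                     # number of batches per epoch
images, labels = seq[0]             # fetch the first batch
print(images.shape, labels.shape)   # expect (4, 720, 1280, 3) and (4, 7)
seq.showImg(0)                      # print sample 0's class names and display it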
Usage
config = {"batch":2, "epochs":10, "imageResize":(720,1280), "lr":1e-5, "cut_size":(0,0.7,0.85,1),}
def trainModelBySequence(xmlPath, imgPath):
    import DataGenSequence  # the SequenceData class above, saved as DataGenSequence.py
    DGS_train = DataGenSequence.SequenceData(xmlPath, imgPath, 
                                             cutSize=(config["cut_size"][0],config["cut_size"][1]), 
                                             batch_size=config["batch"], size=config['imageResize'])
    DGS_val = DataGenSequence.SequenceData(xmlPath, imgPath, 
                                           cutSize=(config["cut_size"][1],config["cut_size"][2]), 
                                           batch_size=config["batch"], size=config['imageResize'])
    
    DGS_test = DataGenSequence.SequenceData(xmlPath, imgPath,  # held out for evaluation after training
                                            cutSize=(config["cut_size"][2],config["cut_size"][3]), 
                                            batch_size=config["batch"], size=config['imageResize'])
    from tensorflow.keras.callbacks import EarlyStopping
    early_stop = EarlyStopping(monitor='val_loss',patience=max(1, config["epochs"]//10),verbose=1,mode='auto')  # patience must be an integer
    # `model` and the `metrics` callback are assumed to be defined elsewhere in the project
    hist = model.fit_generator(generator=DGS_train,steps_per_epoch=len(DGS_train),
                               validation_data=DGS_val,validation_steps=len(DGS_val),
                               workers=20,use_multiprocessing=False,verbose=1,
                               epochs=config["epochs"],callbacks=[early_stop,metrics])
    
    return hist
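Note that fit_generator is deprecated in TensorFlow 2.x: model.fit accepts a Sequence directly and infers the step counts from len(). A hedged equivalent of the call above:

# TF 2.x style: model.fit consumes the Sequence directly; steps are inferred from len()
hist = model.fit(DGS_train,
                 validation_data=DGS_val,
                 epochs=config["epochs"],
                 workers=20, use_multiprocessing=False,  # removed in newer Keras versions
                 verbose=1, callbacks=[early_stop, metrics])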