python装饰器应用之keras数据生成器

这里仅仅是装饰器的一个简单应用,与平常唯一不同的地方就是我把装饰器写在了类的外部,而且装饰器内部还调用了类中的方法。只要在wrapper参数里加上self即可。还是直接看代码吧,这是我用于ocr训练的数据生成器。

def generate(func):
    def wrapper(self, *args, **kwargs):
        index_all, batch_size = func(self, *args, **kwargs)
        i, n = 0, len(index_all)
        while True:
            if i + batch_size >= n:
                np.random.shuffle(index_all)
                i = 0
                continue
            batch_x, batch_y = [], []
            batch_input_length = np.ones(batch_size) * (max_img_weigth // 8)
            batch_label_length = []
            for j in range(i, i + batch_size):
                x, y = self.get_img_data(index_all[j])
                batch_x.append(x)
                batch_y.append(y)
                batch_label_length.append(self.label_length[j])
            i += batch_size
            yield [np.array(batch_x),
                   np.array(batch_y),
                   batch_input_length,
                   np.array(batch_label_length)], np.ones(batch_size)
    return wrapper


class ChineseDataset(object):
    def __init__(self):
        mat_annotation = loadmat(label_mat)
        self.img_dir = img_dir
        self.filenames = mat_annotation['img']
        self.labels = mat_annotation['label']
        self.label_length = mat_annotation['label_length'][0]

    def get_train_num(self):
        return int(len(self.filenames) * 0.8)

    def get_valid_num(self):
        return len(self.filenames) - int(len(self.filenames) * 0.8)

    def get_img_data(self, index):
        img = cv2.imread(os.path.join(self.img_dir, self.filenames[index]))
        img = cv2.resize(img, (max_img_weigth, max_img_height)) / 255.
        label = one_hot(self.labels[index])
        return img, label

    @generate
    def gen_train(self, batch_size):
        index_all = list(range(int(len(self.filenames) * 0.8)))
        return index_all, batch_size

    @generate
    def gen_valid(self, batch_size):
        index_all = list(range(int(len(self.filenames) * 0.8), len(self.filenames)))
        return index_all, batch_size
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值