这里只是装饰器的一个简单应用，与平常唯一不同的地方在于：我把装饰器写在了类的外部，而且装饰器内部还调用了类中的方法。做法很简单，只要在 wrapper 的参数里加上 self 即可。还是直接看代码吧，下面是我用于 OCR 训练的数据生成器。
def generate(func):
    """Turn a dataset method into an endless batch generator.

    The decorated method must return ``(index_all, batch_size)``: the sample
    indices it owns and the batch size. The returned ``wrapper`` is a
    generator that walks ``index_all`` in chunks of ``batch_size`` and yields
    ``([images, labels, input_lengths, label_lengths], ones(batch_size))``
    forever. The first pass uses ``index_all`` in the order returned; each time
    fewer than ``batch_size`` indices remain, the list is reshuffled with
    ``np.random.shuffle`` and iteration restarts (that pass's incomplete tail
    is dropped).

    The wrapper takes ``self`` explicitly, so it can call
    ``self.get_img_data(index)`` and read ``self.label_length`` even though
    the decorator is defined outside the class.

    NOTE(review): the yielded structure looks like the input layout of a
    CTC-loss model (images, labels, input lengths, label lengths) with a dummy
    target -- confirm against the model definition.

    Raises:
        ValueError: on the first ``next()``, if ``batch_size`` is less than 1
            or larger than ``len(index_all)``. Without this check the loop
            would never yield (too large) or never advance (zero).
    """
    import functools

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        index_all, batch_size = func(self, *args, **kwargs)
        # Copy so the in-place shuffle never mutates a list owned by the caller.
        index_all = list(index_all)
        n = len(index_all)
        if batch_size < 1 or batch_size > n:
            raise ValueError(
                "batch_size must be between 1 and {} (number of samples), "
                "got {}".format(n, batch_size)
            )
        i = 0
        while True:
            # A full batch uses positions i .. i + batch_size - 1, so it still
            # fits while i + batch_size <= n; reshuffle only when it does not.
            if i + batch_size > n:
                np.random.shuffle(index_all)
                i = 0
            batch_idx = index_all[i:i + batch_size]
            i += batch_size
            batch_x, batch_y = [], []
            for idx in batch_idx:
                x, y = self.get_img_data(idx)
                batch_x.append(x)
                batch_y.append(y)
            # Index label lengths by the *sample* index (not the position in
            # the shuffled list) so they stay aligned with batch_y.
            batch_label_length = [self.label_length[idx] for idx in batch_idx]
            # NOTE(review): presumably the model's output sequence length,
            # i.e. the image width downsampled by 8 -- confirm against the net.
            batch_input_length = np.ones(batch_size) * (max_img_weigth // 8)
            yield [np.array(batch_x),
                   np.array(batch_y),
                   batch_input_length,
                   np.array(batch_label_length)], np.ones(batch_size)
    return wrapper
class ChineseDataset(object):
    """OCR dataset whose annotations come from a ``.mat`` file (via ``loadmat``).

    Samples are split by file order: the first ``int(N * train_ratio)``
    samples form the training split and the rest form the validation split.
    """

    def __init__(self, label_mat_path=None, image_dir=None, train_ratio=0.8):
        """Load the annotation file and remember where images live.

        Args:
            label_mat_path: path passed to ``loadmat``; defaults to the
                module-level ``label_mat``.
            image_dir: directory joined with each filename; defaults to the
                module-level ``img_dir``.
            train_ratio: fraction of samples (in file order) used for
                training; the remainder is the validation split.

        Raises:
            ValueError: if ``train_ratio`` is outside ``[0, 1]``.
        """
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                "train_ratio must be in [0, 1], got {}".format(train_ratio))
        mat_annotation = loadmat(
            label_mat if label_mat_path is None else label_mat_path)
        self.img_dir = img_dir if image_dir is None else image_dir
        self.train_ratio = train_ratio
        # Entries read from the annotation file: 'img', 'label', 'label_length'.
        self.filenames = mat_annotation['img']
        self.labels = mat_annotation['label']
        # NOTE(review): the [0] suggests 'label_length' is stored as a 1xN
        # row in the .mat file -- confirm.
        self.label_length = mat_annotation['label_length'][0]

    def _num_train(self):
        """Return the number of training samples (single source of the split)."""
        return int(len(self.filenames) * self.train_ratio)

    def get_train_num(self):
        """Return the number of samples in the training split."""
        return self._num_train()

    def get_valid_num(self):
        """Return the number of samples in the validation split."""
        return len(self.filenames) - self._num_train()

    def get_img_data(self, index):
        """Load, resize and scale one image and one-hot encode its label.

        Returns:
            ``(img, label)``: the image resized to
            ``(max_img_weigth, max_img_height)`` (width, height) and divided
            by 255., and ``one_hot(self.labels[index])``.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path = os.path.join(self.img_dir, self.filenames[index])
        img = cv2.imread(path)
        # cv2.imread signals failure (missing/unreadable file) by returning
        # None instead of raising; fail here with a clear message.
        if img is None:
            raise OSError("could not read image: {}".format(path))
        # cv2.resize takes dsize as (width, height).
        img = cv2.resize(img, (max_img_weigth, max_img_height)) / 255.
        label = one_hot(self.labels[index])
        return img, label

    @generate
    def gen_train(self, batch_size):
        """Endless training-batch generator over the training split."""
        index_all = list(range(self._num_train()))
        return index_all, batch_size

    @generate
    def gen_valid(self, batch_size):
        """Endless validation-batch generator over the validation split."""
        index_all = list(range(self._num_train(), len(self.filenames)))
        return index_all, batch_size