# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11
#
from prepare_data import *
from LLNet import *
# Number of training / testing samples in the prepared dataset.
# NOTE(review): both counts are identical here — presumably the same
# prepared set is reused for both splits; confirm against prepare_data.
TRAIN_NUM_SAMPLES = 14584
TEST_NUM_SAMPLES = 14584
def read_batch_data(batch_size, root_dir, split="training"):
    """Yield mini-batches of (data, label) arrays from an LLNet_Data reader.

    Parameters
    ----------
    batch_size : int
        Number of valid samples per yielded batch.
    root_dir : str
        Dataset root directory passed to ``LLNet_Data``.
    split : str
        Dataset split name ("training" or "testing"); also selects which
        sample-count limit bounds the read loop.

    Yields
    ------
    tuple of np.ndarray
        ``(batch_data, batch_label)`` as float32 arrays. The final batch
        may contain fewer than ``batch_size`` samples if the dataset ends.
    """
    # Bug fix: the original always used TRAIN_NUM_SAMPLES and never bounded
    # the inner fill loop, so at end-of-dataset it could call
    # read_interface past the last index forever (None -> continue).
    num_samples = TRAIN_NUM_SAMPLES if split == "training" else TEST_NUM_SAMPLES
    reader = LLNet_Data(root_dir, split)
    start_idx = 0
    while start_idx < num_samples:
        batch_data = []
        batch_label = []
        # Fill the batch, skipping unreadable samples (None data/label),
        # but never read beyond the dataset's sample count.
        while len(batch_data) < batch_size and start_idx < num_samples:
            data, label = reader.read_interface(start_idx)
            start_idx += 1
            if (data is None) or (label is None):
                continue
            batch_data.append(data)
            batch_label.append(label)
        # Skip an empty trailing batch (all remaining samples unreadable).
        if batch_data:
            yield np.array(batch_data, dtype=np.float32), np.array(batch_label, dtype=np.float32)
def train_pretrain(batch_size, root_dir, beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs = 1001):
''' train pre-train '''
model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289,847,578, 289), sparseCoef = 0.05)
model.build_graph_pretrain()
for epoch in range(epochs):
avg_loss = 0
idx = 1
for (batch_data, batch_label) in read_batch_data(batch_size, root_dir, split):
pretrain_loss = model.run_fitting_pretrain(batch_data, batch_label)
# print("pretrai
# LLNet model implementation — model training (final part)
# (Originally published as a blog article; latest recommended article dated 2025-05-24 08:52:14)