LLNet模型实现——模型训练(完结)

本文详述了LLNet模型的训练过程,包括使用稀疏自动编码和图像增强技术的实现细节。同时,针对Ubuntu环境下保存checkpoint时遇到的问题,作者提供了修正方案,并分享了修改前后的代码供读者参考。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# Ref: LLNet: Deep Autoencoders for Low-light Image Enhancement
#
# Author: HSW
# Date: 2018-05-11 
#

from prepare_data import * 
from LLNet        import * 

# 训练样本/测试样本的个数
TRAIN_NUM_SAMPLES = 14584
TEST_NUM_SAMPLES  = 14584

def read_batch_data(batch_size, root_dir, split="training"):
    ''' read batch data '''
    train_startIdx = 0
    test_startIdx = 0;

    readObj = LLNet_Data(root_dir, split)
    
    while train_startIdx < TRAIN_NUM_SAMPLES:
        batch_data = []
        batch_label = []

        idx = 0

        while idx < batch_size:
        
            data, label = readObj.read_interface(train_startIdx)
            
            # print("data = {}".format(data))
            
            # print("label = {}".format(label))
            
            train_startIdx += 1
            
            if (data is None) or (label is None):
            	continue
            else:
            	batch_data.append(data)
            	batch_label.append(label)
            	idx += 1
		
        yield np.array(batch_data, dtype = np.float32), np.array(batch_label, dtype=np.float32) 
    
    


def train_pretrain(batch_size, root_dir,  beta_pretrain, lambda_pretrain, lambda_finetune, split="training", epochs = 1001):
    ''' train pre-train '''
    model = LLNet_Model(beta_pretrain, lambda_pretrain, lambda_finetune, transfer_function=tf.nn.sigmoid, LLnet_Shape=(289,847,578, 289), sparseCoef = 0.05)
    model.build_graph_pretrain()
    
    for epoch in range(epochs):
        avg_loss = 0
        idx = 1
        for (batch_data, batch_label) in  read_batch_data(batch_size, root_dir, split):

            pretrain_loss = model.run_fitting_pretrain(batch_data, batch_label)

            # print("pretrai
评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值