训练DPT:由测试test到训练train图像的一个epochs的optimize.zero_grad() loss.backward() optimizer.step()

不知道大家有没有这样的感受,很多研究型论文通常会给出他们的test.py文件,但是其train.py文件往往是空白的,这时候感觉文章的test确实很nice,就想去探究其更原始、最优参数,训练出的参数过程。那么这里就不得不开始研究如何从test进一步得到train呢?

我想的是,绝对不止我一个人想这么干,前人必然都有不同的经验分享,这里仅是我初步的探索test到构建train:

1.研究现有的train.py文件:

  • 我们完全可以从现有研究所携带的train文件去学习他是如何进行训练的搭建,这里还是很安利adabin的训练文件,它整体架构罗列详细,同时结合了在线可视化wandb工具。

2.对比研究可以发现test和train的普遍区别:

  • test:简单从inpu读取图像,经过模型model处理,直接输出output,图像张数不限制。
  • train:调用数据集(很多张图像),训练方式(CPU/GPU/GPUs),损失函数(loss),优化器(Optimizer),调度器(Scheduler),训练流程(train loop),验证流程(validation loop)。

其中,很多训练都是大同小异,将test进行修改时,除了添加缺失的模块,还要修改test的流程,不再是简单的输入输出,而是进一步由数据集进行,同时加入损失,此时,test也就被修改成了train。

3.实践

以DPT为例实践,学习adabins的train文件,将其test.py进行修改成train,主要是图像的预处理涉及正确读取,以及train中的遍历epochs(详细举例可参考)

  • optimize.zero_grad():
    • 函数会遍历模型的所有参数,通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。
  • loss.backward():
    • 损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。
      • PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。
  • optimizer.step():
    • step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。
  • model.train():

    • 使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。

  • model.eval():

    • PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。

  • with torch.no_grad():
    • PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。

DPT的train loop 

for epoch in range(args.epoch, epochs):
################################# Train loop ####################################
        if should_log: wandb.log({"Epoch": epoch}, step=step)
        for i, batch in tqdm(enumerate(train_loader), desc=f"Epoch: {epoch + 1}/{epochs}. Loop: Train",
                             total=len(train_loader)) if is_rank_zero(
                args) else enumerate(train_loader):

            optimizer.zero_grad()

            #完成tensor数据向GPU复制粘贴
            img = batch['image'].to(device)
            depth = batch['depth'].to(device)

            pred = model.forward(img)
            #print(pred)
            pred =pred.unsqueeze(1)
            #print(pred)
            
            mask = depth
            #loss = criterion_ueff(pred, depth, mask=mask.to(torch.bool), interpolate=True)
            loss = ss(pred, depth, mask=mask.to(torch.bool),interpolate=True)
            
            loss.backward()
            
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # optional
            optimizer.step()
            if should_log and step % 5 == 0:

                wandb.log({f"Train/{ss.name}": loss.item()}, step=step)
               # wandb.log({f"Train/{criterion_bins.name}": l_chamfer.item()}, step=step)

            step += 1
            scheduler.step()

            #########################################################################

    return model

  • 4
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
帮我翻译以下代码 <select id="getStatisticalInfoDetailCount" resultType="java.lang.Integer"> select count()from(select psi.station_id,cda.cdc as count_device_code,cda.appkey from (select count(dci.device_code) as cdc,dci.appkey from iledproduction.device_code_info dci where CONV(left(dci.device_code,2),16,10) in <foreach collection="manufacturerIdList" item="item" index="index" open="(" close=")" separator=","> #{item} </foreach> <if test="propMap.size>0 and propMap!=null"> <foreach collection="propMap" index="index" item="item"> <if test='index.toString=="device_type"'> and substr(dci.device_code,3,4) = "${item[1]}" </if> <if test='index.toString=="produce_date"'> and STR_TO_DATE(substr(dci.device_code,7,6),'%y%m%d') BETWEEN "${item[1]}" AND DATE_ADD("${item[2]}",INTERVAL 1 DAY) </if> </foreach> </if> and dci.cert_import_status = 1 group by appkey) as cda left join (select s1.station_id as station_id,s1.old_appkey as appkey from iledproduction.appkey_log s1 union (select s2.station_id as station_id,s2.new_appkey as appkey from iledproduction.appkey_log s2) union (select s3.station_id as station_id,s3.appkey as appkey from iledproduction.production_station_info s3) ) psi on cda.appkey = psi.appkey) a left join (select p.,dpt.manufacturer_id,dpt.department_name from iledproduction.production_station_info p ,iledproduction.department dpt where p.department_key=dpt.department_key) ps on a.station_id = ps.station_id <if test="propMap.size>0 and propMap!=null"> <foreach collection="propMap" index="index" item="item"> <if test='index.toString=="manufacturer_id"'> where ps.manufacturer_id = ${item[1]} </if> </foreach> </if> </select>
07-14
这段代码是一个SQL查询语句,用于统计满足特定条件的数据行数。以下是代码的翻译: ```sql <select id="getStatisticalInfoDetailCount" resultType="java.lang.Integer"> SELECT count(*) FROM ( SELECT psi.station_id, cda.cdc AS count_device_code, cda.appkey FROM ( SELECT count(dci.device_code) AS cdc, dci.appkey FROM iledproduction.device_code_info dci WHERE CONV(left(dci.device_code,2),16,10) IN <foreach collection="manufacturerIdList" item="item" index="index" open="(" close=")" separator=","> #{item} </foreach> <if test="propMap.size>0 and propMap!=null"> <foreach collection="propMap" index="index" item="item"> <if test='index.toString=="device_type"'> AND substr(dci.device_code,3,4) = "${item[1]}" </if> <if test='index.toString=="produce_date"'> AND STR_TO_DATE(substr(dci.device_code,7,6),'%y%m%d') BETWEEN "${item[1]}" AND DATE_ADD("${item[2]}",INTERVAL 1 DAY) </if> </foreach> </if> AND dci.cert_import_status = 1 GROUP BY appkey ) AS cda LEFT JOIN ( SELECT s1.station_id AS station_id, s1.old_appkey AS appkey FROM iledproduction.appkey_log s1 UNION SELECT s2.station_id AS station_id, s2.new_appkey AS appkey FROM iledproduction.appkey_log s2 UNION SELECT s3.station_id AS station_id, s3.appkey AS appkey FROM iledproduction.production_station_info s3 ) psi ON cda.appkey = psi.appkey ) a LEFT JOIN ( SELECT p.*, dpt.manufacturer_id, dpt.department_name FROM iledproduction.production_station_info p, iledproduction.department dpt WHERE p.department_key = dpt.department_key ) ps ON a.station_id = ps.station_id <if test="propMap.size>0 and propMap!=null"> <foreach collection="propMap" index="index" item="item"> <if test='index.toString=="manufacturer_id"'> WHERE ps.manufacturer_id = ${item[1]} </if> </foreach> </if> </select> ``` 希望能对你有所帮助。如果你有任何其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MengYa_DreamZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值