保存网络模型和优化器使得下次训练从断开的epoch地方开始继续训练

# learn theta -> grid -> img_out:5x5
# loss 为 MSE
t_loss = np.inf
learning_rate = 5*1e-4
loss_fn   = torch.nn.MSELoss(reduction='mean') 

log = "./loss_log.txt"
# model = STN(0.5,6,img_size)
# model = STN(0.76,18,img_size,1,16*16*20,img_size*img_size)
model = Net_9x9_theta_learn_grid()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
pre_train = False
if pre_train:
    model.load_state_dict(torch.load('./min_9x9_model_img_out.pkl'))
    optimizer.load_state_dict(torch.load("./latest_optimizer.pth"))
    with open(log,"r") as file:
        for line in file:
            pass
        last_line = line
    start_epoch =int(last_line.split(' ')[0]) +1
    t_loss = float(last_line.split(' ')[-2])
    v_loss = float(last_line.split(' ')[-1])
    print('model loaded continue train with epoch %d and min loss %f'%(start_epoch,t_loss))
else:
    start_epoch = 1
    t_loss = np.inf
    v_loss = np.inf
    # print('Pre_trained model loaded')

start_epoch = 0
end_epoch = 10000
learning_rate = 5*1e-4
loss_fn   = torch.nn.MSELoss(reduction='mean') 


train_loss = []
axis_x = []

for epoch_i in range(start_epoch, end_epoch+1):
    start = time()
    model.train()

    output1 = model(input,8,9)
    output1 = output1.reshape([1,1,img_size,img_size])
    loss = loss_fn(output1, target)
    train_loss.append(loss.item())
    axis_x.append(epoch_i)
    end = time()
    time_cost = end - start
    print('epoch: %d, train loss: %4f, input loss:%4f, time cost:%4f'%(epoch_i,loss.item(),loss_fn(input, target).item(),time_cost))
    # ssim_12 = ssim(output, target, window_size=11, window=None, size_average=True, full=False, val_range=None)
    # loss = 1 - ssim_12
    # train_loss.append(loss.item())
    # axis_x.append(epoch_i)
    # print('epoch: %d, train loss: %4f, input loss:%4f, time cost:%4f'%(epoch_i,loss.item(),ssim_input.item(),time_cost))

    loss.backward()
    optimizer.step()      
    optimizer.zero_grad()
    torch.save(model.state_dict(),'./lateat_9x9_model_img_out.pkl')
    torch.save(optimizer.state_dict(),"./latest_optimizer_9x9.pth")

    if np.mean(train_loss) < t_loss:
        t_loss = np.mean(train_loss)
        torch.save(model.state_dict(),"./min_9x9_model_img_out.pkl") 
    with open(log,"a") as file:
        file.write(str(epoch_i) + " " + str(np.mean(train_loss)))

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值