# --- Training setup -------------------------------------------------------
# Best (minimum) mean training loss seen so far; starts at +inf so the first
# epoch always checkpoints a "min" model.
t_loss = np.inf
learning_rate = 5e-4  # Adam step size (plain literal instead of 5*1e-4)
# Pixel-wise mean-squared error between network output and target image.
loss_fn = torch.nn.MSELoss(reduction='mean')
# Text log used both for plotting and for resuming: one line per epoch.
log = "./loss_log.txt"
model = Net_9x9_theta_learn_grid()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Set True to resume from the checkpoints + log written by a previous run.
pre_train = False
if pre_train:
    # Resume a previous run.
    # NOTE(review): model weights come from the *best* checkpoint while the
    # optimizer state comes from the *latest* one — these are two different
    # points in training; confirm this mix is intentional.
    model.load_state_dict(torch.load('./min_9x9_model_img_out.pkl'))
    # BUGFIX: the training loop saves the optimizer to
    # "./latest_optimizer_9x9.pth"; the old path "./latest_optimizer.pth"
    # is never written for this model.
    optimizer.load_state_dict(torch.load("./latest_optimizer_9x9.pth"))
    # Recover where the previous run stopped from the last line of the log.
    last_line = None
    with open(log, "r") as file:
        for line in file:
            last_line = line
    if last_line is None:
        # BUGFIX: the original raised a bare NameError on an empty log.
        raise RuntimeError("cannot resume: log file %r is empty" % log)
    fields = last_line.split(' ')
    start_epoch = int(fields[0]) + 1
    # NOTE(review): the writer appends "epoch mean_loss" (two fields), so
    # fields[-2] is the epoch number, not a loss — verify the log format
    # before trusting these two reads.
    t_loss = float(fields[-2])
    v_loss = float(fields[-1])
    print('model loaded continue train with epoch %d and min loss %f' % (start_epoch, t_loss))
else:
    # Fresh run: no best loss yet, count epochs from 0.
    # (The original assigned start_epoch twice more — 1, then 0; the final
    # effective value, 0, is kept.)
    start_epoch = 0
    t_loss = np.inf
    v_loss = np.inf
# Final epoch index (the training loop's range upper bound is inclusive).
# NOTE: duplicate re-assignments of learning_rate and loss_fn were removed —
# they repeated the values already set above, and re-assigning learning_rate
# could not affect the already-constructed optimizer anyway.
end_epoch = 10000
# Per-epoch training-loss history and matching epoch indices (for plotting).
train_loss = []
axis_x = []
# --- Main training loop ---------------------------------------------------
for epoch_i in range(start_epoch, end_epoch + 1):
    start = time()
    model.train()
    # Forward pass.  NOTE(review): `input` (shadows the builtin) and `target`
    # are defined earlier in the file; presumably 8 and 9 are grid/kernel
    # sizes for the 9x9 net — confirm against the model definition.
    output1 = model(input, 8, 9)
    output1 = output1.reshape([1, 1, img_size, img_size])
    loss = loss_fn(output1, target)
    train_loss.append(loss.item())
    axis_x.append(epoch_i)
    end = time()
    time_cost = end - start
    print('epoch: %d, train loss: %4f, input loss:%4f, time cost:%4f' % (epoch_i, loss.item(), loss_fn(input, target).item(), time_cost))
    # Adam step; zeroing after step is equivalent to zeroing before backward
    # here, since grads start unset on the first iteration.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Checkpoint every epoch.  BUGFIX: filename typo "lateat" -> "latest".
    torch.save(model.state_dict(), './latest_9x9_model_img_out.pkl')
    torch.save(optimizer.state_dict(), "./latest_optimizer_9x9.pth")
    # Mean loss over the whole run so far, computed once (the original
    # recomputed it up to three times per epoch — O(n^2) over the run).
    # NOTE(review): tracking the running mean rather than the current-epoch
    # loss makes the "best model" criterion very sluggish — confirm intent.
    mean_loss = np.mean(train_loss)
    if mean_loss < t_loss:
        t_loss = mean_loss
        torch.save(model.state_dict(), "./min_9x9_model_img_out.pkl")
    # BUGFIX: append a newline so each epoch gets its own log line; the
    # resume code reads this file line-by-line and previously the whole log
    # collapsed into a single line.
    with open(log, "a") as file:
        file.write(str(epoch_i) + " " + str(mean_loss) + "\n")