Based on the RCAN codebase.
Add the following to utils.py:
import torch
import math
import os
from functools import reduce
import numpy as np
import imageio as misc
import time
import datetime
import torch.optim as optim
# NOTE: the imports above were copied wholesale from RCAN; not all of them are necessarily used here.
class checkpoint():
    """Minimal training logger (modeled on RCAN's checkpoint utility).

    On construction it creates the save directory (if missing) and opens a
    text log file; ``write_log`` echoes a line to stdout and appends it to
    that file, and ``add_log`` accumulates numeric entries in a tensor.
    """

    def __init__(self, save_dir=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict'):
        # save_dir: directory that receives the log file. Parameterized
        # (with the original hard-coded path as default) so the logger is
        # reusable outside this one machine.
        self.ok = True
        self.log = torch.Tensor()  # numeric log entries accumulated via add_log
        self.dir = save_dir
        os.makedirs(self.dir, exist_ok=True)
        # Keep the log path in ONE place so write_log's refresh reopens the
        # SAME file. (Bug fix: the original opened 'log_614.txt' here but
        # reopened 'log.txt' on refresh, silently splitting the log.)
        self.log_path = os.path.join(self.dir, 'log_614.txt')
        open_type = 'a' if os.path.exists(self.log_path) else 'w'
        self.log_file = open(self.log_path, open_type)

    def add_log(self, log):
        # Append a tensor of metrics to the in-memory log.
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print `log` and append it to the log file.

        If `refresh` is True, close and reopen the file so the content is
        flushed to disk immediately.
        """
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.log_path, 'a')
The following changes go into train.py and main.py.
main.py:
from __future__ import print_function
import argparse
from math import log10
import os
from typing import Any
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
#from dataset import MyDataset,dataset_split
from utils import *
from train import *
#from data import get_training_set
import pdb
import socket
import time
# Training settings.
# Parsed once at import time; `opt` and `checkpoint` are used as module-level
# globals by print_network()/main() below.
parser = argparse.ArgumentParser(description='PyTorch Super Res Example')
parser.add_argument('--batchSize', type=int, default=640, help='training batch size')
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--snapshots', type=int, default=50, help='Snapshots')
# NOTE(review): help text says "Default=0.0001" but the actual default is 1e-2 — confirm intended value.
parser.add_argument('--lr', type=float, default=1e-2, help='Learning Rate. Default=0.0001')
# NOTE(review): original comment said "changed from 1 to 0", yet the default here is 5 — verify.
parser.add_argument('--threads', type=int, default=5, help='number of threads for data loader to use')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123')
parser.add_argument('--pretrained_sr', default='MIX2K_LR_aug_x4dl10DBPNITERtpami_epoch_399.pth', help='sr pretrained base model')
# NOTE(review): type=bool on argparse does not parse "False" as falsy; any non-empty string is True.
parser.add_argument('--pretrained', type=bool, default=False)
parser.add_argument('--model_type', type=str, default='MyCNN')
#parser.add_argument('--data_root', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\dataset', help='all dataset Location ')
#parser.add_argument('--train_path', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\dataset\train_small', help='dataset Location')
parser.add_argument('--train_path', default=r'C:\train_set', help='dataset Location')
#parser.add_argument('--model_save_path', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict\test_73.pth', help='model_save_path')
parser.add_argument('--model_save', default=r'D:\LY\DLmodify\3Deblur_cnn\Deblur_cvpr\document\3test\modelDict', help='model_save')
parser.add_argument('--start_iter', type=int, default=1, help='Starting Epoch')
# Added: module-level logger instance. NOTE: this rebinds the name `checkpoint`
# from the class (imported via `from utils import *`) to an instance of it.
checkpoint = checkpoint()  # added
opt = parser.parse_args()
#gpus_list = range(opt.gpus)
hostname = str(socket.gethostname())
cudnn.benchmark = True  # let cuDNN pick the fastest conv algorithms for fixed input sizes
print(opt)
def print_network(net):
    """Print the model architecture and log its total parameter count.

    The count is routed through the module-level `checkpoint` logger so it
    also lands in the on-disk training log.
    """
    total_params = sum(p.numel() for p in net.parameters())
    print(net)
    checkpoint.write_log('Total number of parameters: %d' % total_params)
def main():
    """Build the datasets and model, then run the training/validation loop.

    Relies on module-level globals: `opt` (parsed CLI args) and `checkpoint`
    (logger). Uses MyDataset/dataset_split/MyCNN/train/validate imported
    elsewhere in this file.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    checkpoint.write_log('===> Loading datasets')
    train_ds = MyDataset(opt.train_path)
    # 80/20 split into training and validation subsets.
    new_train_ds, validate_ds = dataset_split(train_ds, 0.8)
    # Fix: the original also built a loader over the FULL dataset that was
    # never used; only the split loaders are needed.
    new_train_loader = torch.utils.data.DataLoader(
        new_train_ds, batch_size=opt.batchSize, shuffle=True,
        pin_memory=True, num_workers=3)
    validate_loader = torch.utils.data.DataLoader(
        validate_ds, batch_size=opt.batchSize, shuffle=True,
        pin_memory=True, num_workers=3)

    print('===> Building model ', opt.model_type)
    if opt.model_type == 'MyCNN':
        net = MyCNN()

    criterion = torch.nn.CrossEntropyLoss()
    # criterion = torch.nn.NLLLoss()

    print('---------- Networks architecture -------------')
    print_network(net)
    print('----------------------------------------------')

    optimizer = optim.SGD(net.parameters(), lr=opt.lr, momentum=0.8)
    # optimizer = optim.Adam(net.parameters(), lr=opt.lr)

    for epoch in range(opt.start_iter, opt.nEpochs + 1):
        train(epoch, new_train_loader, device, net, criterion, optimizer, checkpoint)

        # Fix: validate ONCE per epoch. The original called validate() twice
        # (once discarding the result, once for logging), doubling the
        # validation cost for the same number.
        val_acc = validate(validate_loader, device, net, criterion)
        checkpoint.write_log("validate acc:{}".format(val_acc))

        # Step-decay: divide the learning rate by 5 every 70 epochs.
        if (epoch + 1) % 70 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 5.0
            checkpoint.write_log('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

        # Snapshot the whole model every `snapshots` epochs.
        if (epoch + 1) % (opt.snapshots) == 0:
            model_save_path = opt.model_save + r'\classify73_{}.pth'.format(epoch)
            torch.save(net, model_save_path)


if __name__ == '__main__':
    main()
In train.py:
from dataset import MyDataset,dataset_split
#from config import config as C
from model import MyCNN
import torch.optim as optim
from utils import *
import time
# Change 1: add a `ckp` (checkpoint logger) parameter to train()
def train(epoch, train_loader, device, model, criterion, optimizer, ckp):
    """Run one training epoch and log per-batch loss/accuracy through `ckp`.

    Args:
        epoch: current epoch number (used only for logging).
        train_loader: DataLoader yielding (inputs, labels) batches.
        device: torch device to move the model/criterion/data to.
        model: the network to train (set to train mode here).
        criterion: loss function.
        optimizer: optimizer whose param_groups hold the current lr.
        ckp: checkpoint-style logger exposing write_log().
    """
    model = model.to(device)
    criterion = criterion.to(device)
    model.train()

    top1 = AvgrageMeter()  # running top-1 accuracy (helper from utils)
    train_loss = 0.0
    batch_start = time.time()

    ckp.write_log('Learning rate : lr={}'.format(optimizer.param_groups[0]['lr']))

    for i, data in enumerate(train_loader):
        # Time spent fetching this batch (data loading + previous logging).
        load_done = time.time()
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()  # clear gradients from the previous batch

        fwd_start = time.time()
        outputs = torch.squeeze(model(inputs))
        loss = criterion(outputs, labels)
        fwd_end = time.time()

        loss.backward()
        optimizer.step()

        # Track accuracy and mean loss across the epoch so far.
        prec1, prec2 = accuracy(outputs, labels, topk=(1, 2))
        n = inputs.size(0)
        top1.update(prec1.item(), n)
        # Fix: use .item() instead of deprecated .data so train_loss is a
        # plain float detached from the autograd graph.
        train_loss += loss.item()

        ckp.write_log(
            "===> Epoch[{}]({}/{}): Loss: {:.4f} ||train_acc:{:.4f}%||train_loader Timer: {:.4f} sec || Timer: {:.4f} sec.".format(
                epoch, i, len(train_loader), train_loss / (i + 1), top1.avg,
                (load_done - batch_start), (fwd_end - fwd_start)))
        batch_start = time.time()

    ckp.write_log('Finished Training')