根据ESRT和LBNet的代码学习编码。首先看main.py。
1. args模块 B站小侯学府的args讲解
需要三步,创建argparse.ArgumentParser解释器,添加add_argument参数,解析参数parse_args:
# 创建argparse.ArgumentParser解析器
parser = argparse.ArgumentParser(description='LBNet')
# 添加参数add_argument('添加的变量', type=变量类型, default=默认值, help='使用-h时告知的作用')
parser.add_argument('--n_threads', type=int, default=12,
help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
help='use cpu only')
# 解析参数parse_args
args = parser.parse_args() # 使用时 args.参数名 即可调用
2. if '__name__' == '__main__': CSDN农村詹姆斯的讲解
一个.py文件有两种运行方式,一种是直接运行,另一种是模块重用,被jimport调用执行。而有了这串代码,if'__name__' == '__main__':中的语句只能在直接运行这个文件时才可执行;当文件被import调用后,if'__name__' == '__main__':中的语句不会被执行了。
3. seed CSDN博主zyrlia的详细介绍
seed是随机种子的意思,使用seed是为了使我们复现论文作者代码时,与作者产生的随机数相同,使得结果偏差最小。
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) # 设置CPU生成随机数的种子
if torch.cuda.device_count() == 1: # GPU数量是否为1
torch.cuda.manual_seed(seed)
else:
torch.cuda.manual_seed_all(seed) # 为所有的GPU设置种子
有些论文不会单独定义set_seed函数,但是都会设置随机种子。
4. checkpoint
用来存储模型的,段哥说会用就行,我就先跳了,有空再分析啦~~
5. loading datasets B站刘二大人(我吹爆,适合零基础)
print("===> Loading datasets")
trainset = DIV2K.div2k(args)
testset = Set5_val.DatasetFromFolderVal("Test_Datasets/Set5/",
"Test_Datasets/Set5_LR/x{}/".format(args.scale),
args.scale)
training_data_loader = DataLoader(dataset=trainset, num_workers=args.threads, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)
testing_data_loader = DataLoader(dataset=testset, num_workers=args.threads, batch_size=args.testBatchSize,
shuffle=False)
一般都会用上面给出的格式,载入dataset,然后实例化dataloader。
6. building models 刘二大人的视频建议从头看
print("===> Building models")
args.is_train = True
model = esrt.ESRT(upscale = args.scale)#architecture.IMDN(upscale=args.scale) # 主要
将model实例化。
7. loss and opitimizer
l1_criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
优化器其实就是训练采用的方法。
8. training cycle
for epoch in range(1000):
y_pred = model(x_data)
loss = criierion(y_pred, y_data)
print(epoch, loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
在LBNet代码中,将7,8封装到了Trainer中:
t = Trainer(args, loader, model, loss, checkpoint)
while not t.terminate():
t.train()
t.test()
checkpoint.done()