跟着论文代码学习编码第一天:main.py

根据ESRT和LBNet的代码学习编码。首先看main.py。

1.  args模块  B站小侯学府的args讲解

需要三步,创建argparse.ArgumentParser解释器,添加add_argument参数,解析参数parse_args:

# 创建argparse.ArgumentParser解析器
parser = argparse.ArgumentParser(description='LBNet')

# 添加参数add_argument('添加的变量', type=变量类型, default=默认值, help='使用-h时告知的作用')
parser.add_argument('--n_threads', type=int, default=12,
                    help='number of threads for data loading')
parser.add_argument('--cpu', action='store_true',
                    help='use cpu only')


# 解析参数parse_args
args = parser.parse_args()      # 使用时 args.参数名 即可调用

2.  if '__name__' == '__main__':   CSDN农村詹姆斯的讲解

一个.py文件有两种运行方式,一种是直接运行,另一种是模块重用,被jimport调用执行。而有了这串代码,if'__name__' == '__main__':中的语句只能在直接运行这个文件时才可执行;当文件被import调用后,if'__name__' == '__main__':中的语句不会被执行了。

3.  seed  CSDN博主zyrlia的详细介绍

seed是随机种子的意思,使用seed是为了使我们复现论文作者代码时,与作者产生的随机数相同,使得结果偏差最小。

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)     # 设置CPU生成随机数的种子
    if torch.cuda.device_count() == 1:      # GPU数量是否为1
        torch.cuda.manual_seed(seed)
    else:
        torch.cuda.manual_seed_all(seed)    # 为所有的GPU设置种子

有些论文不会单独定义set_seed函数,但是都会设置随机种子。

4.  checkpoint

用来存储模型的,段哥说会用就行,我就先跳了,有空再分析啦~~

5.  loading datasets  B站刘二大人(我吹爆,适合零基础)

print("===> Loading datasets")

trainset = DIV2K.div2k(args)
testset = Set5_val.DatasetFromFolderVal("Test_Datasets/Set5/",
                                       "Test_Datasets/Set5_LR/x{}/".format(args.scale),
                                       args.scale)
training_data_loader = DataLoader(dataset=trainset, num_workers=args.threads, batch_size=args.batch_size, shuffle=True, pin_memory=True, drop_last=True)
testing_data_loader = DataLoader(dataset=testset, num_workers=args.threads, batch_size=args.testBatchSize,
                                 shuffle=False)

一般都会用上面给出的格式,载入dataset,然后实例化dataloader。

6.  building models  刘二大人的视频建议从头看

print("===> Building models")
args.is_train = True


model = esrt.ESRT(upscale = args.scale)#architecture.IMDN(upscale=args.scale)    # 主要

将model实例化。

7.  loss and opitimizer  

l1_criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)

优化器其实就是训练采用的方法。

8.  training cycle

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criierion(y_pred, y_data)
    print(epoch, loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

在LBNet代码中,将7,8封装到了Trainer中:

t = Trainer(args, loader, model, loss, checkpoint)
    while not t.terminate():
        t.train()
        t.test()
    checkpoint.done()

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值