pytorch模型调参、训练相关内容

调整学习率

PyTorch已经在torch.optim.lr_scheduler为我们封装好了一些动态调整学习率的方法。调用方法如下:

# 选择一种优化器
optimizer = torch.optim.Adam(...) 
# 选择上面提到的一种或多种动态调整学习率的方法
scheduler1 = torch.optim.lr_scheduler.... 
scheduler2 = torch.optim.lr_scheduler....
...
schedulern = torch.optim.lr_scheduler....
# 进行训练
for epoch in range(100):
    train(...)
    validate(...)
    optimizer.step()
    # 需要在优化器参数更新之后再动态调整学习率
	scheduler1.step() 
	...
    schedulern.step()

也可通过自定义函数来定义学习率变化。

微调

修改指定层,其余参数不变

# 冻结原预训练模型的参数
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

import torchvision.models as models
# 冻结参数的梯度
feature_extract = True
# 加载预训练好的模型
model = models.resnet18(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
# 输出部分的全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=4, bias=True)

第三方库timm

timm 是Ross Wightman创建torchvision的扩充库,提供了许多计算机视觉的SOTA模型。可以通过以下命令获取预训练好的模型清单:

import timm
avail_pretrained_models = timm.list_models(pretrained=True)#还支持模糊查询
  • 模型微调代码
import timm
import torch
# 将1000类改为10类输出
model = timm.create_model('resnet34',num_classes=10,pretrained=True)
# 改变输入通道数
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1)
  • 模型存储、加载
torch.save(model.state_dict(),'./checkpoint/timm_model.pth')
model.load_state_dict(torch.load('./checkpoint/timm_model.pth'))

半精度

  • 定义:PyTorch浮点数存储方式从torch.float32改为torch.float16称为半精度
  • 目的:在实际应用过程中,保证数据精度需求的前提下,减少显存占用。
  • 应用场合:数据本身的size比较大,如3D图像、视频等。
  • 设置
from torch.cuda.amp import autocast

# 用autocast装饰模型中的forward函数
@autocast()   
def forward(self, x):
    ...
    return x
# 训练
for x in train_loader:
    x = x.cuda()
    #在将数据输入模型及其之后的部分放入with autocast()
    with autocast():
    output = model(x)
    ...

数据增强

图片数据可以使用imgaug库以及Albumentations库来进行数据增强。

调参

  • 传参过程
import argparse #python内置,无需安装

# 创建ArgumentParser()对象
parser = argparse.ArgumentParser()

# 添加参数
parser.add_argument('-o', '--output', action='store_true', 
    help="shows output")
# action = `store_true` 会将output参数记录为True
parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3') 
# type 规定了参数的格式
# default 规定了默认值
parser.add_argument('--batch_size', type=int, required=True, help='input batch size')  
# required=True 意为必选参数

# 使用parse_args()解析函数
args = parser.parse_args()

if args.output:
    print("This is some output")
    print(f"learning rate:{args.lr} ")
  • 超参数的操作配置文件config.py
import argparse  
  
def get_options(parser=argparse.ArgumentParser()):  
  
    parser.add_argument('--workers', type=int, default=0,  
                        help='number of data loading workers, you had better put it '  
                              '4 times of your gpu')  
  
    parser.add_argument('--batch_size', type=int, default=4, help='input batch size, default=64')  
  
    parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for, default=10')  
  
    parser.add_argument('--lr', type=float, default=3e-5, help='select the learning rate, default=1e-3')  
  
    parser.add_argument('--seed', type=int, default=118, help="random seed")  
  
    parser.add_argument('--cuda', action='store_true', default=True, help='enables cuda')  
    parser.add_argument('--checkpoint_path',type=str,default='',  
                        help='Path to load a previous trained model if not empty (default empty)')  
    parser.add_argument('--output',action='store_true',default=True,help="shows output")  
  
    opt = parser.parse_args()  
  
    if opt.output:  
        print(f'num_workers: {opt.workers}')  
        print(f'batch_size: {opt.batch_size}')  
        print(f'epochs (niters) : {opt.niter}')  
        print(f'learning rate : {opt.lr}')  
        print(f'manual_seed: {opt.seed}')  
        print(f'cuda enable: {opt.cuda}')  
        print(f'checkpoint_path: {opt.checkpoint_path}')  
  
    return opt  
  
if __name__ == '__main__':  
    opt = get_options()
  • 调用
import config

opt = config.get_options()

manual_seed = opt.seed
num_workers = opt.workers
batch_size = opt.batch_size
lr = opt.lr
niters = opt.niters
checkpoint_path = opt.checkpoint_path

# 随机数的设置,保证复现结果
def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

...


if __name__ == '__main__':
	set_seed(manual_seed)
	for epoch in range(niters):
		train(model,lr,batch_size,num_workers,checkpoint_path)
		val(model,lr,batch_size,num_workers,checkpoint_path)

参考

datawhale 深入浅出pytorch

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值