Freezing layers during training.

References:

yolov5 training strategies - overfit.cn

A common choice is to freeze the first few stages and keep them out of training, on the grounds that they are already well trained and that the early layers mostly extract generic, shared features. The implementation is simple and only takes a few lines of code, as the sketch below shows.
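
A minimal sketch of the idea (ResNet-18 from torchvision is used purely for illustration; picking layer1/layer2 as the "already well trained" stages is an assumption, not something prescribed here): iterate over named_parameters() and switch off requires_grad for every parameter that belongs to a frozen stage.

import torchvision

net = torchvision.models.resnet18(num_classes=10)
frozen_stages = ['layer1', 'layer2']          # stages treated as already well trained
for name, param in net.named_parameters():
    param.requires_grad = True                # train everything by default...
    if name.split('.')[0] in frozen_stages:   # e.g. 'layer1.0.conv1.weight' -> 'layer1'
        param.requires_grad = False           # ...except the frozen stages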

Full DDP example (MNIST with a partially frozen ResNet-18):

import argparse
import os

import torch.distributed as dist

import torch
import torchvision
from torch import nn
from torchvision import transforms

LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
RANK = int(os.getenv('RANK', -1))
WORLD_SIZE = int(os.getenv('WORLD_SIZE', 2))


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
    parser.add_argument('--freeze', nargs='+', type=int, default=[2], help='Freeze layers: backbone=10, first3=0 1 2')
    args = parser.parse_args()
    print(args.local_rank)

    return args


def main(args):
    getpid_X = os.getpid()
    print(f'current process id: {getpid_X}')
    print(f'local_rank: {LOCAL_RANK}, rank: {RANK}, world_size: {WORLD_SIZE}')
    if LOCAL_RANK != -1:
        print(f'local_rank:{LOCAL_RANK}')
        torch.cuda.set_device(LOCAL_RANK)
        device = torch.device('cuda', LOCAL_RANK)
        dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo")
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    data_set = torchvision.datasets.MNIST('~/DATA/', train=True,
                                          transform=trans, target_transform=None, download=True)
    train_sampler = torch.utils.data.distributed.DistributedSampler(data_set)
    data_loader_train = torch.utils.data.DataLoader(dataset=data_set,
                                                    batch_size=256,
                                                    sampler=train_sampler,
                                                    num_workers=2,
                                                    pin_memory=True)
    net = torchvision.models.resnet18(num_classes=10)
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, 10)
    net = net.cuda()
    # Freeze training: fix the parameters of the selected stages
    # freeze = [f'model.{x}.' for x in (args.freeze if len(args.freeze) > 1 else range(args.freeze[0]))]  # layers to freeze
    freeze = ['layer1', 'layer2']  # ResNet stage names whose parameters will be frozen
    for k, v in net.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        # if any(x in k for x in freeze):
        if k.split('.')[0] in freeze:
            print(f'freezing {k}')
            v.requires_grad = False
    # Wrap the model with DistributedDataParallel
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[LOCAL_RANK],
                                                    output_device=LOCAL_RANK)
    criterion = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(1):
        for i, data in enumerate(data_loader_train):
            images, labels = data
            images = images.repeat(1, 3, 1, 1)  # replicate the single MNIST channel to 3 channels for ResNet
            # Move the batch to the GPU assigned to this process
            print(f'before moving to device: rank {LOCAL_RANK}, images: {images.device}, labels: {labels.device}')
            images = images.to(LOCAL_RANK, non_blocking=True)
            labels = labels.to(LOCAL_RANK, non_blocking=True)
            print(f'after moving to device: rank {LOCAL_RANK}, images: {images.device}, labels: {labels.device}')
            opt.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            if i % 10 == 0:
                print("loss: {}".format(loss.item()))


if __name__ == "__main__":
    args = parse_opt()
    main(args)
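
One optional refinement, not part of the original script: parameters with requires_grad = False never receive gradients, so the optimizer can also be built from the trainable parameters only, and a quick count makes it easy to verify that the freeze actually took effect. A sketch, meant to stand in for the opt = ... line in the script above:

trainable = [p for p in net.parameters() if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=0.001)   # optimize only the unfrozen parameters

n_frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)
n_total = sum(p.numel() for p in net.parameters())
print(f'frozen parameters: {n_frozen} / {n_total}')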

The corresponding code in YOLOv5:

    freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]  # layers to freeze
    for k, v in model.named_parameters():
        v.requires_grad = True  # train all layers
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze):
            LOGGER.info(f'freezing {k}')
            v.requires_grad = False
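
For context, an illustration of my own (not from the original post) of how the list comprehension expands the --freeze argument into name prefixes, matching the help text above ('backbone=10, first3=0 1 2'):

freeze = [10]                  # a single value N means "freeze modules 0..N-1", i.e. the backbone
prefixes = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]
print(prefixes)                # ['model.0.', 'model.1.', ..., 'model.9.']

freeze = [0, 1, 2]             # several values mean "freeze exactly these modules"
prefixes = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))]
print(prefixes)                # ['model.0.', 'model.1.', 'model.2.']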

Running on Kaggle:

!python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="localhost" --master_port=12355  ../input/hellow11/dist_test1.py
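
A possible alternative, assuming a PyTorch version recent enough to ship torchrun (roughly 1.10+): torch.distributed.launch is deprecated in newer releases, and torchrun sets the LOCAL_RANK, RANK and WORLD_SIZE environment variables that the script reads, so the equivalent invocation would be:

!torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 --master_addr="localhost" --master_port=12355 ../input/hellow11/dist_test1.py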

 
