Loading the pretrained model

class EMA():
    """Exponential moving average (EMA) of model parameters, kept as CPU shadow copies."""
    def __init__(self, decay=0.999):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        # Seed the shadow copy with the parameter's current value.
        self.shadow[name] = val.cpu().detach()

    def get(self, name):
        return self.shadow[name]

    def update(self, name, x):
        # shadow = decay * shadow + (1 - decay) * x
        assert name in self.shadow
        new_average = (1.0 - self.decay) * x.cpu().detach() + self.decay * self.shadow[name]
        self.shadow[name] = new_average.clone()




# Register every trainable parameter with the EMA tracker.
ema = EMA(args.ema_decay)
for name, param in network.named_parameters():
    if param.requires_grad:
        ema.register(name, param.data)
def save_checkpoint(state, epoch, dst, is_best):
    # Checkpoints are named by the absolute epoch: <dst>/<start_epoch + epoch>.pth.tar
    filename = os.path.join(dst, str(args.start_epoch + epoch)) + '.pth.tar'
    torch.save(state, filename)
    if is_best:
        # Note: <dst>/model_best must already exist, or shutil.copyfile will fail.
        dst_best = os.path.join(dst, 'model_best', str(epoch)) + '.pth.tar'
        shutil.copyfile(filename, dst_best)
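A quick worked example of the naming scheme (all values assumed): the regular checkpoint is named by the absolute epoch, while the best-model copy is named by the loop-relative epoch.

```python
import os

start_epoch, epoch, dst = 10, 5, 'checkpoints'                   # assumed values
print(os.path.join(dst, str(start_epoch + epoch)) + '.pth.tar')  # checkpoints/15.pth.tar
print(os.path.join(dst, 'model_best', str(epoch)) + '.pth.tar')  # checkpoints/model_best/5.pth.tar
```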

# args.num_epoches is the total epoch count; args.start_epoch is how many epochs were already pretrained.
for epoch in range(args.num_epoches - args.start_epoch):


    # Save network, network_ema, optimizer, W, and epoch to checkpoint_dir.
    state = {'network': network.state_dict(),
             'network_ema': ema.shadow,
             'optimizer': optimizer.state_dict(),
             'W': compute_loss.W,
             'epoch': args.start_epoch + epoch}
    save_checkpoint(state, epoch, args.checkpoint_dir, False)

    logging.info('Epoch: [{}|{}], train_time: {:.3f}, train_loss: {:.3f}'.format(
        args.start_epoch + epoch, args.num_epoches, train_time, train_loss))

    logging.info('image_precision: {:.3f}, text_precision: {:.3f}'.format(image_precision, text_precision))

    adjust_lr(optimizer, args.start_epoch + epoch, args)

    scheduler.step()
    # Print the learning rate of the first parameter group.
    for param_group in optimizer.param_groups:
        print('lr:{}'.format(param_group['lr']))
        break
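Neither `adjust_lr` nor `scheduler` is defined in this section (and calling both each epoch means two schedules act on the same optimizer). Purely as an assumption about the intent, a step-decay `adjust_lr` consistent with this call signature could look like:

```python
# Hypothetical step-decay helper; the real adjust_lr lives elsewhere in the
# repo and may differ. Assumes args.lr and an args.epoches_decay string
# such as '20_40' listing the epochs at which the rate drops by 10x.
def adjust_lr(optimizer, epoch, args):
    milestones = sorted(int(e) for e in args.epoches_decay.split('_'))
    factor = 0.1 ** sum(epoch >= m for m in milestones)
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.lr * factor
```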

logging.info('Train done')

logging.info(args.checkpoint_dir)

logging.info(args.log_dir)


# Initialize the network, build the optimizer, and optionally load pretrained or resumed weights.
def network_config(args, split='train', param=None, resume=False, model_path=None, ema=False):
    network = Model(args)
    # Wrap for multi-GPU training.
    network = nn.DataParallel(network).cuda()
    # Let cuDNN benchmark convolution algorithms for a speedup with fixed-size inputs.
    cudnn.benchmark = True

    args.start_epoch = 0
1. resume: 'whether or not to restore the pretrained whole model'
Usage: !python train.py --resume
When resume is true, the .pth.tar file at model_path is loaded: args.start_epoch becomes the saved epoch + 1, and network_dict is the checkpoint's 'network' state dict.
If ema is true, network_dict is first updated with network_ema; the network then loads network_dict.

    if resume:
        # args.model_path from train_config is passed in as model_path.
        directory.check_file(model_path, 'model_file')

        checkpoint = torch.load(model_path)
        args.start_epoch = checkpoint['epoch'] + 1

        network_dict = checkpoint['network']
        # Optionally restore the exponentially averaged (EMA) weights instead.
        if ema:
            logging.info('==> EMA Loading')
            network_dict.update(checkpoint['network_ema'])
        network.load_state_dict(network_dict, strict=False)
        print('==> Loading checkpoint "{}"'.format(model_path))

When resume is false, network_dict is the freshly initialized network.state_dict(); the pretrained CNN keys are remapped (dropping the first `start` characters) and merged into it.


    else:
        # Load only the pretrained image CNN weights.
        if model_path is not None:
            print('==> Loading from pretrained models')
            network_dict = network.state_dict()
            if args.image_model == 'mobilenet_v1':
                cnn_pretrained = torch.load(model_path)['state_dict']
                start = 7  # strip the leading 'module.' from the checkpoint keys
            else:
                cnn_pretrained = torch.load(model_path)
                start = 0
            # Remap the pretrained keys under the image_model submodule.
            prefix = 'module.image_model.'
            pretrained_dict = {prefix + k[start:]: v for k, v in cnn_pretrained.items()}
            # Drop keys from pretrained_dict that do not exist in network_dict.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in network_dict}
            network_dict.update(pretrained_dict)
            network.load_state_dict(network_dict)
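To make the `start = 7` slicing concrete, here is a tiny standalone demonstration of the key remapping (the checkpoint keys are made up; 7 is exactly `len('module.')`):

```python
# Toy pretrained keys, as a mobilenet_v1 checkpoint saved under DataParallel would have.
cnn_pretrained = {'module.conv1.weight': 0, 'module.bn1.bias': 1}
start, prefix = 7, 'module.image_model.'
remapped = {prefix + k[start:]: v for k, v in cnn_pretrained.items()}
print(list(remapped))
# ['module.image_model.conv1.weight', 'module.image_model.bn1.bias']
```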

    # Process optimizer params.
    if split == 'test':
        optimizer = None
    else:
        # Two parameter groups: the image CNN gets weight decay, everything else does not.
        cnn_param_ids = list(map(id, network.module.image_model.parameters()))
        other_params = [p for p in network.parameters() if id(p) not in cnn_param_ids]
        if param is not None:
            other_params.extend(list(param))
        param_groups = [{'params': other_params},
                        {'params': network.module.image_model.parameters(), 'weight_decay': args.wd}]
        optimizer = torch.optim.Adam(
            param_groups,
            lr=args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon)
        if resume:
            optimizer.load_state_dict(checkpoint['optimizer'])

    print('Total params: %.2fM' % (sum(p.numel() for p in network.parameters()) / 1000000.0))
    # Seed all RNGs with a fresh random seed.
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    torch.cuda.manual_seed_all(manualSeed)

    return network, optimizer
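For reference, hypothetical calls for the two paths this function supports (the paths and flag values are assumptions, not from the original post):

```python
# Fresh training, initializing only the image CNN from a pretrained checkpoint.
network, optimizer = network_config(
    args, split='train', model_path='pretrained/mobilenet_v1.pth.tar')

# Resuming a full run from a saved checkpoint, preferring the EMA weights.
network, optimizer = network_config(
    args, split='train', resume=True, model_path=args.model_path, ema=True)
```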
