2020-10-18

这是在github上下载的filter-pruning-geometric-median-master的代码,对pruning_cifar10.py的论述,基本上都是搜索的。(二)

model_names = sorted(name for name in models.__dict__ if name.islower() and not     
              name.startswith("__") and callable (models.__dict__[name]))
  1. sorted()函数对所有可迭代的对象进行排序操作。https://www.runoob.com/python/python-func-sorted.html
  2. models.__dict__:      这里,我是通过运行  import models       modelsname=sorted(name for name in models.__dict__)       print(modelsname)    查看输出的结果来了解models.__dict__的作用。输出结果是:['CifarResNet', 'ResNetBasicblock', 'VGG', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', 'caffe_cifar', 'densenet', 'densenet100_12', 'imagenet_resnet', 'imagenet_resnet_small', 'preresnet', 'preresnet110', 'preresnet20', 'preresnet32', 'preresnet44', 'preresnet56', 'res_utils', 'resnet', 'resnet101', 'resnet101_small', 'resnet110', 'resnet110_small', 'resnet152', 'resnet152_small', 'resnet18', 'resnet18_small', 'resnet20', 'resnet20_small', 'resnet32', 'resnet32_small', 'resnet34', 'resnet34_small', 'resnet44', 'resnet44_small', 'resnet50', 'resnet50_small', 'resnet56', 'resnet56_small', 'resnet_mod', 'resnet_mod110', 'resnet_mod20', 'resnet_mod32', 'resnet_mod44', 'resnet_mod56', 'resnet_small', 'resnext', 'resnext29_16_64', 'resnext29_8_64', 'vgg', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vgg_cifar10'],这些都是网络模型,除了以'__'开头的。
  3. name.islower(): islower()方法检测字符串是否由小写字母组成。如果都是小写,则返回True, 否则返回False.把'CIfarResNet', 'ResNetBasicblock', 'VGG' 排除掉。
  4. name.startwith("__"): 查看字符串是否以"__"开头,是返回True, 否返回False。把'__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__'排除掉。
  5. callable(): 用于检查一个对象是否是可调用的。如果返回True,对象仍然可能调用失败。但是如果返回False,则对象必然调用失败。
  6. model_names的值:['caffe_cifar', 'densenet100_12', 'preresnet110', 'preresnet20', 'preresnet32', 'preresnet44', 'preresnet56', 'resnet101', 'resnet101_small', 'resnet110', 'resnet110_small', 'resnet152', 'resnet152_small', 'resnet18', 'resnet18_small', 'resnet20', 'resnet20_small', 'resnet32', 'resnet32_small', 'resnet34', 'resnet34_small', 'resnet44', 'resnet44_small', 'resnet50', 'resnet50_small', 'resnet56', 'resnet56_small', 'resnet_mod110', 'resnet_mod20', 'resnet_mod32', 'resnet_mod44', 'resnet_mod56', 'resnext29_16_64', 'resnext29_8_64', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn']。
parser = argparse.ArgumentParser(description='Trains ResneXt on CIFAR or ImageNet',formatter_class=argparse.ArgumentDefaultsHelpFormatter)

argparse.ArgumentParser()创建一个解析器,ArgumentParser对象包含将命令行解析成Python数据类型所需要的全部信息。 

description参数简要描述这个程序做什么以及怎么做。此处说的是在CIFAR数据集或者ImageNet上训练ResneXt网络。

formatter_class参数:定义帮助文档格式。

parser.add_argument('data_path', type=str, help='Path to dataset')
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100', 'imagenet', 'svhn', 'stl10'], help='Choose between CIfar10/100 and ImageNet.')
parser.add_argument('--arch', metavar='ARCH', default='resnet18', choices=model_names, help='model architecture:' + '|'.join(model_names) + '(default: resnext29_8_64)')
  1. parser.add_argument():添加程序参数。type:命令行参数应当被转换成的类型;help: 对添加的程序的参数作用做一个简单的描述; choices: 可用的参数; metavar: 帮助信息中显示的参数名称,使用 print(parser.print_help()) 结果会输出 --arch ARCH 。
  2. 'data_path' : 数据集的存储路径。
  3. '--dataset' : 数据集名称。
  4. 'arch' : 使用哪个网络。
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
parser.add_argument('--learning_rate', type=float, default=0.1, help='The Learning Rate.')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=0.0005, help='weight decay(L2 penalty).')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], help='Decrease learning rate at these epochs.')
parser.add_argument('--gamms', type=float, nargs='+', default=[0.1,0.1], help='LR is multiplied by gamma on schedule, number of gammas should be equal to schedule')
  1. '--epochs' : 训练几轮。
  2. '--batch_size': 一个批次输入128张图片。
  3. 'learning_rate': 学习率大小,初始为0.1.
  4. 'momentum': 动量是为了加速学习的过程,用在权值更新的时候,为了防止网络陷入局部最小值,达不到全局最优解。Momentum不仅会使用当前的梯度,还会积累之前的梯度以确定走向。未引入momentum权值更新公式:w = w - lr \ast dw  ;引入momentum权值更新公式:v = momentum \ast v - lr \ast dw              w = w + v   ,其中v是上一次迭代的梯度。
  5. 'decay' : 权重衰减,这里用的L2 正则化。L2正则化的目的是为了让权重衰减到更小的值,在一定程度上减少模型过拟合的问题。公式是:loss = (y - yo) ^{2} + \alpha \left \| W \right \|^{2}  ,其中loss表示损失,y表示真实值, yo表示网络的预测值,\alpha是权值衰减系数,这里定义的是0.0005。
  6. 'schedule': 在第几个epoch改变学习率,改变方法:lr = lr \ast gamms
  7. 'nargs='+'' : 因为'schedule' 和 'gamms'情况特殊,默认都各有两个值,shedule是[150, 225], gamms是[0.1, 0.1]。需要从命令行读取不止一个参数。
parser.add_argument('--print_freq', default=200, type=int, metavar='N', help='print frequency(default:200)')
parser.add_argument('--save_path', type=str, default='./', help='Folder to save checkpoints and log.')
parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default:none)')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
  1. 'print_freq' : 
  2. 'save_path': 存储检查点和日志的路径。
  3. 'resume': 
  4. 'start_epoch': 第几个epoch开始,用于继续训练,例如,一共10个epoch, 运行到第6个停止了,可以继续从第6个开始训练。
  5. 'evaluate':  是训练模型还是测试/验证模型。 
  6. action='store_true':运行时给‘--evalute’传参数,就将'--evalute'设为True。
  7. dest:
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers(default:2)')
  1. '--ngpu':gpu的标号。
  2. '--workers': 线程数。
parser.add_argument('--manualSeed', type=int, help='manual seed')

 手动控制生成随机数。

parser.add_argument('--rate_norm', type=float, default=0.9, help='the remaining ratio of pruning based on Norm')
parser.add_argument('--rate_dist', type=float, default=0.1, help='the reducing ratio of pruning based on Distance')
parser.add_argument('--layer_begin', type=int, default=1, help='compress layer of model')
parser.add_argument('--layer_end', type=int, default=1. help='compress layer of model')
parser.add_argument('layer_inter', type=int, default=1, help='compress layer of model')
parser.add_argument('--epoch_prune', type=int, default=1, help='compress layer of model')
parser.add_argument('--use_state_dict', dest='use_state_dict', action='store_true', help='use state dict or not')
  1.  '--rate_norm':使用范数作为标准的裁剪率,0.9表示使用范数作为标准进行裁剪的比率为0.1(1-0.9)。
  2. '--rate_dist' :使用距离作为标准的裁剪率,默认是0.1。
  3. '--epoch_prune':在第几个epoch对网络进行剪枝。
  4. '--use_state_dict':是否使用状态字典。模型的state_dict存储的是权值参数,网络的每一层的名字。优化器对象Optimizer也有一个state_dict, 存储优化器的状态以及被使用的超参数,例如 lr , momentum, weight_decay等。
parser.add_argument('--use_pretrain', dest='use_pretrain', action='store_true', help='use pre-trained model or not')
parser.add_argument('--pretrain_path',default='', type=str, help='..path of pre-trained model')
parser.add_argument('--dist_type', default='l2', type=str, choices=['l2', 'l1', 'cos'], help='distance type of GM')
  1. '--use__pretrain':是否使用预训练过的模型。
  2. '--pretrain_path':预训练过的模型存储路径。
  3. '--dist_type': 距离的计算方法。 
拓展:Pytorch保存模型等相关参数,利用torch.save,以及读取保存之后的文件:

首先定义: model(模型)      optimizer(优化器)   epoch(保存模型时,运行到第几个epoch)

state_dict = {'net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}

torch.save(state_dict, path)  # path是存储路径


恢复训练:

 checkpoint = torch.load(dir)
 model.load_state_dict(checkpoint['net']) 
 optimizer.load_state_dict(checkpoint['optimizer'])
 start_eopch = checkpoint['epoch'] + 1




torhc.save()有两种保存方式:

    1. 只保存神经网络的训练模型的参数,torch.save(model.state_dict, path) 
    2. 既保存整个神经网络的模型结构又保存模型参数, torch.save(model, path/'net.pth'):
            model = torch.load(path/'net.pth')
         

 

 

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值