Step into get_model() and look at its definition:
```python
# models (the repo's model package) and set_model_prune_rate come from the
# repository's own modules; their import lines are not shown here.
def get_model(args):
    if args.first_layer_dense:
        args.first_layer_type = "DenseConv"

    print("=> Creating model '{}'".format(args.arch))

    # Choose the number of output classes from the dataset name.
    if args.set == 'ImageNet' or args.set == 'TinyImageNet':
        num_classes = 1000  # Tiny ImageNet itself has 200 classes; the code uses 1000 here
    elif args.set == 'CIFAR100':
        num_classes = 100
    else:
        num_classes = 10

    # Look up the model constructor by its name, e.g. args.arch == "cResNet18".
    model = models.__dict__[args.arch](num_classes=num_classes)

    # applying sparsity to the network
    if (
        args.conv_type != "DenseConv"
        and args.conv_type != "SampleSubnetConv"
        and args.conv_type != "ContinuousSparseConv"
    ):
        if args.prune_rate < 0:
            raise ValueError("Need to set a positive prune rate")

        set_model_prune_rate(model, prune_rate=args.prune_rate)
        print(
            f"=> Rough estimate model params {sum(int(p.numel() * (1-args.prune_rate)) for n, p in model.named_parameters() if not n.endswith('scores'))}"
        )

    # freezing the weights if we are only doing subnet training
    # if args.freeze_weights:
    #     freeze_model_weights(model)

    return model
```
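The "rough estimate" printed by get_model() keeps a fraction (1 - prune_rate) of every parameter tensor whose name does not end in 'scores', truncating each tensor's count to an integer before summing. The following is a minimal pure-Python sketch of that arithmetic. It assumes a hypothetical dict mapping parameter names to element counts in place of a real model.named_parameters(), and the parameter names and shapes are made up for illustration.

```python
def rough_param_estimate(param_numels, prune_rate):
    """Mimic the parameter estimate printed in get_model().

    param_numels: hypothetical mapping of parameter name -> number of elements,
    standing in for the (n, p.numel()) pairs from model.named_parameters().
    """
    if prune_rate < 0:
        raise ValueError("Need to set a positive prune rate")
    return sum(
        int(numel * (1 - prune_rate))  # truncate per tensor, as the original does
        for name, numel in param_numels.items()
        if not name.endswith("scores")  # score tensors are not counted as weights
    )


# Hypothetical parameters: a 3x3 conv from 3 to 64 channels, its scores, and a classifier.
params = {
    "conv1.weight": 64 * 3 * 3 * 3,  # 1728 elements
    "conv1.scores": 64 * 3 * 3 * 3,  # skipped by the estimate
    "fc.weight": 10 * 512,           # 5120 elements
}
print(rough_param_estimate(params, prune_rate=0.5))  # → 3424
```

Because int() is applied to each tensor separately, the estimate can come out slightly below (1 - prune_rate) times the total weight count. That is one reason it is only "rough".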
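Take the "To search an RST from a randomly initialized PreActResNet18 on CIFAR-10" section of README.md as an example. Opening the config file it uses, config_rst/resnet18-usc-unsigned-cifar.yaml, shows that args.arch = cResNet18 and args.set = CIFAR10. CIFAR10 matches neither the ImageNet/TinyImageNet branch nor the CIFAR100 branch, so num_classes falls through to the else branch and is 10. Therefore model = models.__dict__[args.arch](num_classes=num_classes) is effectively model = models.__dict__["cResNet18"](num_classes=10).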
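Indexing models.__dict__ with a string works because a module's __dict__ is its namespace. Every name imported into the models package, such as cResNet18, is an entry keyed by that name. The minimal self-contained sketch below shows this name-based dispatch. It builds a stand-in module with types.ModuleType and uses a hypothetical placeholder cResNet18 that only returns a string; the real function builds a network.

```python
import types


def select_num_classes(dataset):
    """Same branching as get_model(): ImageNet/TinyImageNet -> 1000, CIFAR100 -> 100, else 10."""
    if dataset in ("ImageNet", "TinyImageNet"):
        return 1000
    elif dataset == "CIFAR100":
        return 100
    return 10


# Stand-in for the repo's models package, created at runtime for illustration.
models = types.ModuleType("models")


def cResNet18(num_classes=10):
    # Hypothetical placeholder: the real cResNet18 returns a ResNet instance.
    return f"cResNet18 with {num_classes} classes"


models.cResNet18 = cResNet18  # now available as models.__dict__["cResNet18"]

arch, dataset = "cResNet18", "CIFAR10"  # the values read from the YAML config
model = models.__dict__[arch](num_classes=select_num_classes(dataset))
print(model)  # → cResNet18 with 10 classes
```

getattr(models, arch) performs the same lookup. The __dict__ form simply makes it explicit that the architecture name from the config is used as a dictionary key.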
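Right-click models.__dict__ and open its definition, as shown in the figure below. cResNet18 appears on line 3, so right-click models.resnet_cifar and open its definition, where we find the cResNet18 function: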
```python
def cResNet18(num_classes=10):
    return ResNet(get_builder(), BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
```
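The list [2, 2, 2, 2] gives the number of BasicBlocks in each of the four stages, and each BasicBlock holds two 3x3 convolutions. Adding the stem convolution and the final fully connected layer gives 1 + 2 × (2 + 2 + 2 + 2) + 1 = 18 weighted layers, hence the "18". The small arithmetic sketch below assumes the standard ResNet counting convention, in which 1x1 shortcut projections are not counted. It is not derived from this repository's ResNet class.

```python
def resnet_depth(blocks_per_stage, convs_per_block=2):
    """Count weighted layers using the usual ResNet naming convention.

    Assumes one stem conv, convs_per_block convs per residual block
    (2 for BasicBlock) and one final fully connected layer; 1x1 shortcut
    projections are not counted.
    """
    stem_conv = 1
    final_fc = 1
    return stem_conv + convs_per_block * sum(blocks_per_stage) + final_fc


print(resnet_depth([2, 2, 2, 2]))  # → 18, the configuration passed by cResNet18
print(resnet_depth([3, 4, 6, 3]))  # → 34, the usual ResNet34 stage layout
```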