把我自己对这篇论文《Pruning filters for efficient convnets》的学习过程发出来,当时什么都不会,只能一点一点进行注释hhh
train.py
def train_network(args, network=None, data_set=None):
# 将代码分配到设备
device = torch.device("cuda" if args.gpu_no >= 0 else "cpu")
if network is None:
# 根据输入看使用那个vgg模型和数据集
network = VGG(args.vgg, args.data_set)
network = network.to(device)
if data_set is None:
data_set = get_data_set(args, train_flag=True)
# 定义loss类
loss_calculator = Loss_Calculator()
# 神经网络优化器,主要是为了优化我们的神经网络,冲量、正则、SGD(不需要每次全部读入数据,可以分批读入)等,https://github.com/cen6667/Pruning_filters_for_efficient_convnets
optimizer, scheduler = get_optimizer(network, args)
# 恢复训练的标志
if args.resume_flag:
# 加载训练好的模型
check_point = torch.load(args.load_path)
# 将加载好的模型加载到网络中
network.load_state_dict(check_point['