MXNet学习笔记——3 Module

写在前面

本系列博客记录了作者上手MXNet的全过程。作者在接触MXNet之前主要使用keras,和一点tensorflow,因此在上手MXNet之前有一点deep learning的项目基础。主要参考资料为MXNet官方教程,也阅读了一些有价值的博客。

博客结构为:先列出作者对于该阶段的期望目标,以及各目标完成过程中的笔记(仅记下个人认为重要的),再附上学习过程中自己的提问(solved & unsolved,天马行空的提问,欢迎讨论)。


本阶段目标

学习MXNet中module模块(用来神经网络训练和预测),这部分较简单,并且和keras的指令相似。

具体笔记

构建模型

mod = mx.mod.Module(symbol=net, #计算图 
                    context=mx.cpu(), #设备
                    data_names=['data'], #输入变量名称
                    label_names=['softmax_label'])  #输出变量名称

中间层次接口

使用forward\backward\更新参数等进行每一次迭代中的计算

# 将train_iter的信息提供给module,可根据shape分配内存的大小
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
# 参数初始化
mod.init_params(initializer=mx.init.Uniform(scale=.1))
# lr=0.1的sgd
mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1), ))
# accuracy作为评估指标
metric = mx.metric.create('acc')

for epoch in range(5):
    train_iter.reset()
    metric.reset()
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # 前向计算
        mod.update_metric(metric, batch.label)  # 更新评估指标
        mod.backward()                          # 反向计算
        mod.update()                            # 更新参数
    print('Epoch %d, Training %s' % (epoch, metric.get()))

高层次接口

  • 训练:mod.fit
# 重设迭代器
train_iter.reset()

# 创建模型
mod = mx.mod.Module(symbol=net,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['softmax_label'])

mod.fit(train_iter,
        eval_data=val_iter,
        optimizer='sgd',
        optimizer_params={'learning_rate':0.1},
        eval_metric='acc',
        num_epoch=8)

mod.fit()可以加上模型的参数,这样就告诉代码从这些参数开始训练而不是先初始化

# 存储下目前网络的结构和参数信息
sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3) #(prefix,epoc)

mod = mx.mod.Module(symbol=sym)
mod.fit(train_iter,
        num_epoch=21,
        arg_params=arg_params, # 参数
        aux_params=aux_params,
        begin_epoch=3)
  • 预测:mod.predict()
  • 评估:mod.score()
# 在训练结束后还可以使用score()对验证集进行最后的评估
score = mod.score(val_iter, ['acc'])
print("Accuracy score is %f" % (score[0][1]))
assert score[0][1] > 0.77, "Achieved accuracy (%f) is less than expected (0.77)" % score[0][1]

提出问题

AboutQuestionAnswer
pythonassert的作用?

在错误出现之前,当错误条件出现的时候就报出异常。

assert expression

等价于

if not epression:
    
    raise AssertionError

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值