Monitoring MXNet Training via WeChat

I've recently been working through Mu Li's MXNet tutorials. The forum community is friendly, and Gluon's features are simple and practical (reportedly its speed and GPU memory efficiency beat the alternatives).

While idly browsing Zhihu I came across someone using WeChat to monitor TensorFlow training results — "Monitor your TF training via WeChat". Since I had nowhere to go over the National Day holiday, I tried building the same kind of WeChat monitor for MXNet, following that author's approach.

The main features:

  1. Set hyperparameters: learning_rate, training_iters, and batch_size
  2. Start/stop training and report results back
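The parameter-setting commands are plain "key value" text messages (e.g. lr 0.01). As a standalone sketch of that parsing step — parse_param is a hypothetical helper, not part of the original handler, which does this inline:

```python
# Parse a "key value" WeChat message into a hyperparameter update.
# Unknown keys or malformed messages leave the settings unchanged.
def parse_param(text, params):
    """params: dict with keys 'lr', 'ti', 'bs'. Returns an updated copy."""
    updated = dict(params)
    try:
        key, value = text.split()  # ValueError unless exactly two tokens
        if key == 'lr':
            updated['lr'] = float(value)
        elif key == 'ti':
            updated['ti'] = int(value)
        elif key == 'bs':
            updated['bs'] = int(value)
    except ValueError:
        pass  # not a "key value" pair; ignore it
    return updated

defaults = {'lr': 0.5, 'ti': 2, 'bs': 256}
print(parse_param('lr 0.01', defaults))  # lr becomes 0.01
print(parse_param('hello', defaults))    # unchanged
```

Returning a copy rather than mutating globals keeps the helper testable; the actual handler below writes the module-level variables directly.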

The program trains a CNN on the Fashion-MNIST dataset, and the code splits in two: the CNN itself lives in the nn_train function, while the WeChat side uses the itchat package (so run pip install itchat first) and lives in itchat's message handler.

The main code is as follows:

# -*- coding: utf-8 -*-

import threading

import itchat
from mxnet import autograd, gluon, nd
from mxnet.gluon import nn

import utils

# flag shared between the WeChat handler thread and the training thread
lock = threading.Lock()
running = False

# default hyperparameters; adjustable over WeChat via "lr/ti/bs <value>"
batch_size = 256
learning_rate = 0.5
training_iters = 2

def nn_train(wechat_name, param):
    global lock, running
    with lock:
        running = True

    learning_rate, training_iters, batch_size = param
    train_data, test_data = utils.load_data_fashion_mnist(batch_size)
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(nn.MaxPool2D(pool_size=2, strides=2))
        net.add(nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
        net.add(nn.MaxPool2D(pool_size=2, strides=2))
        net.add(nn.Flatten())
        net.add(nn.Dense(128, activation="relu"))
        net.add(nn.Dense(10))
    ctx = utils.try_gpu()
    net.initialize(ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learning_rate})

    print('wait for lock')
    with lock:
        run_state = running
    print('start')
    epoch = 1
    # `<` would train one epoch fewer than requested, so use `<=`
    while run_state and epoch <= training_iters:
        train_loss = 0
        train_acc = 0
        for data,label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = softmax_cross_entropy(output,label)
            loss.backward()
            trainer.step(batch_size)

            train_acc += utils.accuracy(output,label)
            train_loss += nd.mean(loss).asscalar()
        test_acc = utils.evaluate_accuracy(test_data, net, ctx)
        itchat.send("Epoch %d.\nLoss: %f\nTrain acc %f\nTest acc %f" % (epoch, train_loss/len(train_data),train_acc/len(train_data), test_acc), wechat_name)
        print("Epoch %d. Loss: %f, Train acc %f, Test acc %f\n" % (epoch, train_loss/len(train_data),train_acc/len(train_data), test_acc))
        epoch += 1
        with lock:
            run_state = running
    print('op is finished!')
    itchat.send('op is finished!',wechat_name)

    with lock:
        running = False

@itchat.msg_register([itchat.content.TEXT])
def chat_trigger(msg):
    global lock, running, learning_rate, training_iters, batch_size
    if msg['Text'] == u'开始':  # "start"
        print('Starting')
        with lock:
            run_state = running
        if not run_state:
            try:
                threading.Thread(target=nn_train, args=(msg['FromUserName'], (learning_rate, training_iters, batch_size))).start()
            except:
                msg.reply('Running')
    elif msg['Text'] == u'停止':  # "stop"
        print('Stopping')
        with lock:
            running = False

    elif msg['Text'] == u'参数':  # "show parameters"
        itchat.send('lr=%f, ti=%d, bs=%d' % (learning_rate, training_iters, batch_size), msg['FromUserName'])

    else:
        # interpret any other message as a "key value" hyperparameter update
        try:
            key, value = msg['Text'].split()
            print(key, value)
            if key == 'lr':
                learning_rate = float(value)
            elif key == 'ti':
                training_iters = int(value)
            elif key == 'bs':
                batch_size = int(value)
        except ValueError:
            pass  # not a "key value" pair; ignore it

if __name__ == '__main__':
    itchat.auto_login(hotReload=True)
    itchat.run()
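Start/stop works through the lock-protected running flag shared between the itchat handler thread and the training thread: the trainer re-reads the flag after every epoch, so a stop request takes effect at the next epoch boundary. Stripped of MXNet and itchat, the pattern can be sketched like this (a fake work loop stands in for training):

```python
import threading
import time

lock = threading.Lock()
running = False
epochs_done = []

def worker(max_iters):
    """Stand-in for nn_train: loop until stopped or iterations are exhausted."""
    global running
    with lock:
        running = True
        run_state = running
    epoch = 1
    while run_state and epoch <= max_iters:
        epochs_done.append(epoch)  # pretend to train one epoch
        time.sleep(0.01)
        epoch += 1
        with lock:
            run_state = running    # re-check the flag between epochs
    with lock:
        running = False            # allow a new "start" message to spawn a thread

t = threading.Thread(target=worker, args=(3,))
t.start()
t.join()
print(epochs_done)  # → [1, 2, 3]
```

Setting running = False from the handler thread (the "停止" branch) is all it takes to stop the loop; the worker never blocks on the lock for long, so the WeChat handler stays responsive.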
