最近一直在跟沐神学习MXNet轮子。论坛的小伙伴很不错,gluon的特效也很简单实用(听说效率和显存的利用率都比其他的要高)。
无聊在知乎看到有人用微信可以监管TF的训练结果——利用微信监管你的TF训练。国庆既然没得地方玩,就试着仿照作者做了个MXNet的微信监管。
功能主要有:
- 设置参数,主要有learning_rate、training_iters、batch_size
- 开始停止程序,反馈结果
利用的是mnist数据集的cnn程序,代码也一分为二:cnn的主要内容写在nn_train函数里;对于微信,利用了itchat包(所以先要pip install itchat),写在了itchat的handler里。
主要代码如下:
# -*- coding: utf-8 -*-
from mxnet.gluon import nn
import itchat
import threading
import utils
# Lock guarding the shared `running` flag, which is read/written by both the
# itchat message-handler thread and the background training thread.
lock = threading.Lock()
# True while a training thread is active; only ever toggled under `lock`.
running = False
from mxnet import autograd
from mxnet import gluon
from mxnet import nd
# Default hyper-parameters; adjustable at runtime via WeChat messages
# ('bs N', 'lr X', 'ti N' -- see chat_trigger).
batch_size = 256
learning_rate = 0.5
training_iters = 2
def nn_train(wechat_name, param):
    """Train a small CNN on Fashion-MNIST, reporting per-epoch metrics to a
    WeChat contact via itchat.

    Runs in a background thread started by chat_trigger.  The shared
    `running` flag (guarded by `lock`) lets the handler stop the loop
    between epochs.

    Parameters
    ----------
    wechat_name : str
        itchat user name the progress messages are sent to.
    param : tuple
        (learning_rate, training_iters, batch_size).
    """
    global lock, running
    # Mark the trainer as active so chat_trigger refuses to start a second one.
    with lock:
        running = True
    learning_rate, training_iters, batch_size = param
    train_data, test_data = utils.load_data_fashion_mnist(batch_size)
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

    # Two conv/pool stages followed by a small MLP head (10 classes).
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
        net.add(nn.MaxPool2D(pool_size=2, strides=2))
        net.add(nn.Conv2D(channels=50, kernel_size=3, activation='relu'))
        net.add(nn.MaxPool2D(pool_size=2, strides=2))
        net.add(nn.Flatten())
        net.add(nn.Dense(128, activation="relu"))
        net.add(nn.Dense(10))

    ctx = utils.try_gpu()
    net.initialize(ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': learning_rate})

    print('wait for lock')
    with lock:
        run_state = running
    print('start')
    epoch = 1
    # BUG FIX: the original condition was `epoch < training_iters`, which
    # trained one epoch fewer than requested (the default ti=2 ran a single
    # epoch).  `<=` makes `training_iters` mean "number of epochs".
    while run_state and epoch <= training_iters:
        train_loss = 0
        train_acc = 0
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = softmax_cross_entropy(output, label)
            loss.backward()
            trainer.step(batch_size)
            train_acc += utils.accuracy(output, label)
            train_loss += nd.mean(loss).asscalar()
        test_acc = utils.evaluate_accuracy(test_data, net, ctx)
        itchat.send("Epoch %d.\nLoss: %f\nTrain acc %f\nTest acc %f" % (
            epoch, train_loss / len(train_data),
            train_acc / len(train_data), test_acc), wechat_name)
        print("Epoch %d. Loss: %f, Train acc %f, Test acc %f\n" % (
            epoch, train_loss / len(train_data),
            train_acc / len(train_data), test_acc))
        epoch += 1
        # Re-read the stop flag that chat_trigger('停止') may have cleared.
        with lock:
            run_state = running
    print('op is finished!')
    itchat.send('op is finished!', wechat_name)
    with lock:
        running = False
@itchat.msg_register([itchat.content.TEXT])
def chat_trigger(msg):
    """Dispatch incoming WeChat text messages as training commands.

    Commands (message text):
      '开始'  -- start nn_train in a background thread (if not already running)
      '停止'  -- ask the training loop to stop after the current epoch
      '参数'  -- report the current hyper-parameters back to the sender
      'lr X' / 'ti N' / 'bs N' -- set learning rate / epochs / batch size
    Any other text is silently ignored.
    """
    global lock, running, learning_rate, training_iters, batch_size
    if msg['Text'] == u'开始':
        print('Starting')
        with lock:
            run_state = running
        if not run_state:
            try:
                threading.Thread(
                    target=nn_train,
                    args=(msg['FromUserName'],
                          (learning_rate, training_iters, batch_size))).start()
            # FIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; catch only real runtime errors.
            except Exception:
                msg.reply('Running')
    elif msg['Text'] == u'停止':
        print('Stopping')
        with lock:
            running = False
    elif msg['Text'] == u'参数':
        itchat.send('lr=%f, ti=%d, bs=%d' % (
            learning_rate, training_iters, batch_size), msg['FromUserName'])
    else:
        # Treat any other message as a "key value" hyper-parameter update.
        try:
            param = msg['Text'].split()
            key, value = param
            print(key, value)
            if key == 'lr':
                learning_rate = float(value)
            elif key == 'ti':
                training_iters = int(value)
            elif key == 'bs':
                batch_size = int(value)
        # FIX: was a bare `except:`.  ValueError covers both a wrong token
        # count in the unpack and non-numeric float()/int() conversions.
        except ValueError:
            pass
if __name__ == '__main__':
    # hotReload=True persists the login session to disk so a restart
    # does not require scanning the QR code again.
    itchat.auto_login(hotReload=True)
    # Blocks forever, polling WeChat and dispatching messages to chat_trigger.
    itchat.run()