Training MXNet on your own data
1 Data preparation
Convert the data to REC format by following http://blog.csdn.net/a350203223/article/details/50263737.
Note: make_list.py can generate the train and val .lst files automatically. Set the split with the --train_ratio=XXX parameter.
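make_list.py itself is not reproduced here. The snippet below is a rough sketch of what a list generator like it does. It assumes the images are stored as one sub-folder per class. The helpers collect_images, split_items and write_lst are hypothetical names used only for illustration; they are not MXNet APIs. The sketch gives each folder a label, shuffles the images, splits them by a train ratio, and writes the tab-separated "index, label, relative path" lines that im2rec reads. The real make_list.py may order or shuffle differently.

```python
import os
import random

# Extensions to pick up; an assumption, adjust to your own data.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".tif")


def collect_images(root):
    """Walk root/<class_name>/<image> and return (items, classes).

    items is a list of (relative_path, label) pairs; class folders are
    sorted so the same folder name always maps to the same label.
    """
    classes = sorted(d for d in os.listdir(root)
                     if os.path.isdir(os.path.join(root, d)))
    items = []
    for label, cls in enumerate(classes):
        for name in sorted(os.listdir(os.path.join(root, cls))):
            if name.lower().endswith(IMAGE_EXTS):
                items.append((os.path.join(cls, name), label))
    return items, classes


def split_items(items, train_ratio, seed=0):
    """Shuffle a copy of items and split it into (train, val) by train_ratio."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_train = int(len(items) * train_ratio)
    return items[:n_train], items[n_train:]


def write_lst(path, items):
    """Write one 'index<TAB>label<TAB>relative_path' line per image."""
    with open(path, "w") as f:
        for idx, (rel_path, label) in enumerate(items):
            f.write("%d\t%d\t%s\n" % (idx, label, rel_path))
```

For example, collect_images("images") followed by split_items(items, 0.8) and one write_lst call per split produces train.lst and val.lst. Here "images" is a placeholder path and 0.8 plays the role of --train_ratio. The relative paths are resolved against the image root you later pass to im2rec.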
2 Training
Use train_cifar10.py and symbol_inception-bn-28-small.py in mxnet/example/image-classification as references.
The symbol file defines the network structure.
A simple 3-layer CNN:
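When you adapt a symbol file to your own images, the length of the flattened vector that enters the first fully connected layer depends on the input resolution. The plain-Python sketch below needs no MXNet. It copies the kernel and stride settings used in symbol_UCM.py and applies MXNet's default output-size rule: convolutions have stride 1 and no padding, and pooling uses the default 'valid' convention, floor((x - k) / s) + 1. The input side length is an assumption; set it to match the data_shape you train with.

```python
# Layer settings copied from symbol_UCM.py: (name, kernel, stride).
# Convolutions rely on MXNet's defaults (stride=1, pad=0); pooling uses the
# default pooling_convention='valid', i.e. floor((x - k) / s) + 1.
LAYERS = [
    ("conv1", 3, 1),
    ("pool1", 5, 3),
    ("conv2", 3, 1),
    ("pool2", 3, 2),
    ("conv3", 3, 1),
    ("pool3", 2, 2),
]


def out_size(x, kernel, stride, pad=0):
    """Output side length of a 'valid' convolution or pooling layer."""
    return (x + 2 * pad - kernel) // stride + 1


def trace_shapes(input_size):
    """Return [(layer_name, side_length)] for a square input image."""
    sizes = []
    x = input_size
    for name, kernel, stride in LAYERS:
        x = out_size(x, kernel, stride)
        if x <= 0:
            raise ValueError("input too small for layer %s" % name)
        sizes.append((name, x))
    return sizes


def flatten_dim(input_size, last_filters=196):
    """Length of the vector fed to fc1: conv3's num_filter * H * W."""
    side = trace_shapes(input_size)[-1][1]
    return last_filters * side * side


if __name__ == "__main__":
    size = 256  # assumed input side length; replace with your data_shape
    for name, side in trace_shapes(size):
        print(name, side)
    dim = flatten_dim(size)
    print("flatten:", dim)
    # fc1 has num_hidden=420: dim * 420 weights plus 420 biases
    print("fc1 parameters:", dim * 420 + 420)
```

Assume a 256×256 input (UC Merced images are commonly 256×256, but check against your own data). The feature map after final_pool is then 19×19, so the flattened vector has 196 × 19 × 19 = 70756 elements. fc1 alone therefore holds roughly 29.7 million weights, which is where most of this network's parameters sit.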
symbol_UCM.py
import find_mxnet  # helper script from mxnet/example/image-classification that locates the mxnet package
import mxnet as mx


def get_symbol(num_classes=21):
    data = mx.symbol.Variable('data')
    # first conv block: conv -> BN -> ReLU -> max pool
    conv1 = mx.symbol.Convolution(data=data, kernel=(3,3), num_filter=128)
    bn1 = mx.symbol.BatchNorm(data=conv1)
    relu1 = mx.symbol.Activation(data=bn1, act_type="relu")
    pool1 = mx.symbol.Pooling(data=relu1, pool_type="max",
                              kernel=(5,5), stride=(3,3))
    # second conv block
    conv2 = mx.symbol.Convolution(data=pool1, kernel=(3,3), num_filter=196)
    bn2 = mx.symbol.BatchNorm(data=conv2)
    relu2 = mx.symbol.Activation(data=bn2, act_type="relu")
    pool2 = mx.symbol.Pooling(data=relu2, pool_type="max",
                              kernel=(3,3), stride=(2,2))
    # third conv block
    conv3 = mx.symbol.Convolution(data=pool2, kernel=(3,3), num_filter=196)
    bn3 = mx.symbol.BatchNorm(data=conv3)
    relu3 = mx.symbol.Activation(data=bn3, act_type="relu")
    pool3 = mx.symbol.Pooling(data=relu3, pool_type="max",
                              kernel=(2,2), stride=(2,2), name="final_pool")
    # first fully connected layer
    flatten = mx.symbol.Flatten(data=pool3)
    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=420)
    relu4 = mx.symbol.Activation(data=fc1,