原地址:https://github.com/Sarasra/models/tree/master/research/capsules
运行必须要有GPU
下载两个数据包,第一个是 MNIST tfrecords,第二个是已经训练好的模型文件。
- Download and extract MNIST tfrecords to $DATA_DIR/ from:https://storage.googleapis.com/capsule_toronto/mnist_data.tar.gz
- Download and extract MNIST model checkpoint to $CKPT_DIR from:https://storage.googleapis.com/capsule_toronto/mnist_checkpoints.tar.gz
直接放入对应的文件夹下,结构如图所示:
首先利用训练好的模型进行测试(由于Python版本原因,可能需要把几个py文件中的xrange全部修改为range)在文件夹下打开命令行,输入:
python experiment.py --data_dir=./mnist_data/ --train=false --summary_dir=./tmp/ --checkpoint=./mnist_checkpoint/model.ckpt-1
结果如图:
准确率达到了100%-0.24%=99.76%
然后自己训练网络,输入:
python experiment.py --data_dir=./mnist_data/ --max_steps=300000 --summary_dir=./tmp/attempt0/
等待结果,为了节省时间,我把max_steps设置成3000,代码设置的是每1500次自动保存一次模型
训练完成后在.\capsules\tmp\attempt0\train文件夹下会生成相应的模型文件
model.ckpt-3000.data-00000-of-00001 改名 model.ckpt-1.data-00000-of-00001
model.ckpt-3000.index 改名 model.ckpt-1.index
model.ckpt-3000.meta 改名 model.ckpt-1.meta
然后放到mnist_checkpoint文件夹下,替换原文件,再次输入
python experiment.py --data_dir=./mnist_data/ --train=false --summary_dir=./tmp/ --checkpoint=./mnist_checkpoint/model.ckpt-1
测试结果:
准确率达到了100%-0.94%=99.06%,迭代3000次就有这么高的准确率,可见胶囊神经网络确实很强大。