背景
按理讲,夏天是一个比较令人激动和向往的季节,比方说暑假、啤酒、绿茵、绿荫还有大街上的……但这个夏天gemfield最轻松的时候还是使用pytorch的片刻。现在pytorch 1.0正式版就要发布了,这将是一个新的里程碑。而在这之前,gemfield抓紧时间感受了下旧时代的最后一个版本:pytorch 0.4.1,以纪念又一个悄悄过去的夏天。
环境准备
1,使用pytorch github仓库中的Dockerfile build一个pytorch的image;
2,克隆下面的仓库:
这个仓库实现了各种经典的分类网络(在pytorch>=0.4.1的版本上,需要merge一下gemfield的PR),其中就包括经典的resnet50;
3,安装上述仓库中必备的python package,最好是写在Dockerfile里;
4,准备train数据集和val数据集:
比如gemfield的数据集存放路径是/bigdata/gemfield/github/data/,则在这个目录下有如下结构:
train/class1/*.jpg
train/class2/*.jpg
...
train/classN/*.jpg
val/class1/*.jpg
val/class2/*.jpg
...
val/classN/*.jpg
不止支持jpg格式哈。
接下来的训练、测试、服务均在容器中进行。
训练
训练的话使用如下脚本和参数:
/opt/conda/bin/python tars_train.py -idl /bigdata/gemfield/github/data/ -sl checkpoint/ -mo “resnet50” -ep 5 -b 8 -fi
分别指定了数据集路径、模型存放路径、网络的名字、epoch数量、batch size等。
root@gemfield:/bigdata/gemfield/github/pytorch_classifiers# /opt/conda/bin/python tars_train.py -idl /bigdata/gemfield/github/data/ -sl checkpoint/ -