深度学习框架很多,我讲的是deeplearning4j,因为它能和spark结合,代码是java,虽然我java也很烂
数据源是4种动物的照片,有熊,鸭,鹿,龟
示例的分类结果不是很理想,建议我们通过以下方式提高:
1.增加照片数量
2.进行更多的数据预处理
3.增加训练次数,所有数据都训练完了才叫一次训练
4.调整模型配置
5.调整学习率,更新器,激活函数,损失函数,正则化参数等
这也是实战中深度学习问题需要面对的问题
贴出代码,dl4j几乎没什么注释,我只能根据相关资料来分析,有不对的地方还望指出
public class AnimalsClassification { protected static final Logger log = LoggerFactory.getLogger(AnimalsClassification.class);//通过反射获取日志名 protected static int height = 100;//照片是100*100 protected static int width = 100; protected static int channels = 3;//过滤器数量,就是输入层和几个过滤器连接,每个过滤器都按不同规则对输入层进行处理 protected static int numExamples = 80;//80个样本 protected static int numLabels = 4;//4个类别 protected static int batchSize = 20;//每次处理20个样本,这80个样本分4批训练完,参数会更新4次,80个样本训练完了才是一步训练 protected static long seed = 42; protected static Random rng = new Random(seed);//随机数生成器 protected static int listenerFreq = 1;//参数更新一次,就打印一次score protected static int iterations = 1;//每步训练的迭代次数,正常来讲4批是一次,但是我们可以增加迭代次数让每步训练迭代更多次 protected static int epochs = 50;//训练步数,够多的,时间应该很长 protected static double splitTrainTest = 0.8;//80%训练,20%测试 protected static int nCores = 2;//装载数据的队列数 protected static boolean save = false;//不存储 protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out//使用AlexNet网络 public void run(String[] args) throws Exception { log.info("Load data...."); /**cd * Data Setup -> organize and limit data file paths: * - mainPath = path to image files//图片路径 * - fileSplit = define basic dataset split with limits on format//定义数据划分 * - pathFilter = define additional file load filter to limit size and balance batch content//定义额外的文件加载过滤器用来限制大小平衡批内容 **/ ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();//按文件名产生标签0,1,2,3 File mainPath = new File(System.getProperty("user.dir"), "dl4j-examples/src/main/resources/animals/");//图片主路径 FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);//把所有图片弄成一个经过shuffle的数组 BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, batchSize);//平衡每个batch中label的数量 /** * Data Setup -> train test split//划分训练和测试集 * - inputSplit = define train and test split **/ InputSplit[] inputSplit = fileSplit.sample(pathFilter, numExamples * (1 + splitTrainTest), numExamples * (1 - splitTrainTest));//第二个参数是144,第三个是16,这里不是普通的80%,20%,内部有自己的策略,所以最终训练测试条数也跟预想的不一致 InputSplit trainData = inputSplit[0];//训练数据68条 InputSplit testData = inputSplit[1];//测试数据8条