1、NiN块由一个卷积层和两个1×1卷积层组成。
构建NiN块代码如下
/**
 * Builds one NiN block: a convolution with the given kernel, stride and padding,
 * followed by two 1x1 convolutions, each of the three followed by a ReLU activation.
 *
 * @param numChannels  number of filters used by all three convolutions
 * @param kernelShape  kernel shape of the first convolution
 * @param strideShape  stride of the first convolution
 * @param paddingShape padding of the first convolution
 * @return a sequential block holding the three convolution + ReLU pairs
 */
public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
        Shape strideShape, Shape paddingShape) {
    // Leading convolution, configured entirely by the caller.
    Conv2d leadingConv = Conv2d.builder()
            .setKernelShape(kernelShape)
            .optStride(strideShape)
            .optPadding(paddingShape)
            .setFilters(numChannels)
            .build();

    // Two 1x1 convolutions that keep the channel count unchanged.
    Conv2d firstPointwise = Conv2d.builder()
            .setKernelShape(new Shape(1, 1))
            .setFilters(numChannels)
            .build();
    Conv2d secondPointwise = Conv2d.builder()
            .setKernelShape(new Shape(1, 1))
            .setFilters(numChannels)
            .build();

    SequentialBlock ninBlock = new SequentialBlock();
    ninBlock.add(leadingConv);
    ninBlock.add(Activation::relu);
    ninBlock.add(firstPointwise);
    ninBlock.add(Activation::relu);
    ninBlock.add(secondPointwise);
    ninBlock.add(Activation::relu);
    return ninBlock;
}
2、构建模型网络代码如下
// Point the DJL cache at a local directory.
System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

// Assemble the NiN network stage by stage.
SequentialBlock block = new SequentialBlock();

// Three NiN blocks, each followed by 3x3 max pooling with stride 2.
block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)));
block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));
block.add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)));
block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));
block.add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)));
block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));

// Dropout layer with rate 0.5.
block.add(Dropout.builder().optRate(0.5f).build());

// Final NiN block: 10 output channels, one per label class.
block.add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)));

// Global average pooling: the window covers the full height and width of its input.
block.add(Pool.globalAvgPool2dBlock());

// Flatten the four-dimensional output to two dimensions, shape (batch size, 10).
block.add(Blocks.batchFlattenBlock());

// Learning rate
float lr = 0.1f;

Model model = Model.newInstance("cnn");
model.setBlock(block);
后面就是准备数据和训练模型。
完整代码如下:
package com.example.demo.djl;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import com.example.demo.djl.covid19.Covid19Models;
import com.example.demo.djl.covid19.Covid19Training;
import org.apache.commons.lang3.ArrayUtils;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
/**
 * Trains a Network-in-Network (NiN) image classifier with DJL on an image-folder
 * dataset of COVID-19 X-ray pictures, prints the final metrics, and saves the model
 * together with its label synset.
 */
public class NiNTest {

    /**
     * Builds the NiN model, prints the output shape of every top-level layer, trains the
     * model, reports loss/accuracy/throughput and saves the result under {@code nin}.
     *
     * @param args unused
     * @throws IOException        if reading the dataset or saving the model fails
     * @throws ModelException     declared for model-related failures
     * @throws TranslateException if the dataset pipeline fails during training
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // Directory the trained model is written to
        Path modelDir = Paths.get("nin");
        // Learning rate
        float lr = 0.05f;

        Loss loss = Loss.softmaxCrossEntropyLoss();
        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
        DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd) // Optimizer
                .addEvaluator(new Accuracy()) // Model accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        try (NDManager manager = NDManager.newBaseManager();
                Model model = getNiNModel();
                Trainer trainer = model.newTrainer(config)) {
            Block block = model.getBlock();

            // Dummy input sampled uniformly from [0, 1) (lower bound inclusive, upper bound
            // exclusive) with shape (1, 3, 224, 224); used to initialize the trainer and to
            // trace the output shape of each layer.
            NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 3, 224, 224));
            trainer.initialize(X.getShape());

            Shape currentShape = X.getShape();
            for (int i = 0; i < block.getChildren().size(); i++) {
                Shape[] newShape = block.getChildren().get(i).getValue()
                        .getOutputShapes(manager, new Shape[]{currentShape});
                currentShape = newShape[0];
                // Print the output shape of each top-level layer
                System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
            }

            // Mini-batch size (only referenced by the commented-out FashionMnist setup below)
            int batchSize = 128;
            // Number of training epochs
            int numEpochs = 30;

            // Epoch numbers 1..numEpochs, used as the x axis of the metrics table
            double[] epochCount = new double[numEpochs];
            for (int i = 0; i < epochCount.length; i++) {
                epochCount[i] = i + 1;
            }

            // Alternative dataset: FashionMnist images are 28*28 grayscale with a single
            // channel, so when using it the dummy input X created above must be changed to
            // NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 224, 224));
            // FashionMnist trainIter = FashionMnist.builder()
            //         .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
            //         .optUsage(Dataset.Usage.TRAIN)
            //         .setSampling(batchSize, true)
            //         .build();
            //
            // FashionMnist testIter = FashionMnist.builder()
            //         .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
            //         .optUsage(Dataset.Usage.TEST)
            //         .setSampling(batchSize, true)
            //         .build();
            //
            // trainIter.prepare();
            // testIter.prepare();

            // Use the COVID-19 X-ray images
            ImageFolder dataset = Covid19Training.initDataset("D:\\covid19dataset\\COVID-19 Radiography Database\\");
            // Split into training and validation data with an 8:2 ratio
            RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);
            RandomAccessDataset trainIter = datasets[0];
            RandomAccessDataset testIter = datasets[1];
            trainIter.prepare();
            testIter.prepare();

            Map<String, double[]> evaluatorMetrics = new HashMap<>();
            // Java passes primitives by value, so the average epoch time has to come back
            // as the return value; the last argument is ignored by trainingChapter6.
            double avgTrainTimePerEpoch =
                    trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics, 0);

            // Training loss, training accuracy and validation accuracy, one value per epoch
            double[] trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
            double[] trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
            double[] testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

            System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
            System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
            System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
            // NOTE(review): the division by 10^9 assumes the "epoch" metric is in nanoseconds - confirm
            System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
            System.out.println();

            // One label per metric value, laid out in the same order as the "metrics" column:
            // train loss first, then train accuracy, then test accuracy.
            int trainAccStart = trainLoss.length;
            int testAccStart = trainAccStart + trainAccuracy.length;
            String[] lossLabel = new String[testAccStart + testAccuracy.length];
            Arrays.fill(lossLabel, 0, trainAccStart, "train loss");
            Arrays.fill(lossLabel, trainAccStart, testAccStart, "train acc");
            Arrays.fill(lossLabel, testAccStart, lossLabel.length, "test acc");

            model.save(modelDir, "ninCovid19");
            // Save labels into model directory
            Covid19Models.saveSynset(modelDir, dataset.getSynset());

            Table data = Table.create("Data").addColumns(
                    DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
                    DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
                    StringColumn.create("lossLabel", lossLabel)
            );
            // Plotting of `data` is not included in this listing
        }
    }

    /**
     * Builds one NiN block: a convolution with the given kernel, stride and padding,
     * followed by two 1x1 convolutions, each of the three followed by a ReLU activation.
     *
     * @param numChannels  number of filters used by all three convolutions
     * @param kernelShape  kernel shape of the first convolution
     * @param strideShape  stride of the first convolution
     * @param paddingShape padding of the first convolution
     * @return a sequential block holding the three convolution + ReLU pairs
     */
    public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
            Shape strideShape, Shape paddingShape) {
        SequentialBlock tempBlock = new SequentialBlock();
        tempBlock.add(Conv2d.builder()
                        .setKernelShape(kernelShape)
                        .optStride(strideShape)
                        .optPadding(paddingShape)
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu);
        return tempBlock;
    }

    /**
     * Runs training for {@code numEpochs} epochs and collects per-epoch metrics.
     *
     * <p>For every evaluator of the trainer, the values of the {@code train_epoch_<name>}
     * and {@code validate_epoch_<name>} metrics are stored in {@code evaluatorMetrics}
     * under those same keys.
     *
     * @param trainIter            training dataset
     * @param testIter             validation dataset
     * @param numEpochs            number of epochs to train
     * @param trainer              trainer to fit; its metrics are replaced by a fresh instance
     * @param evaluatorMetrics     output map that receives the per-epoch metric values
     * @param avgTrainTimePerEpoch ignored; kept only so existing callers keep compiling
     *                             (a primitive argument cannot carry a result back to the caller)
     * @return the mean of the trainer's {@code "epoch"} metric
     * @throws IOException        if reading the datasets fails
     * @throws TranslateException if the dataset pipeline fails
     */
    public static double trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
            int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics,
            double avgTrainTimePerEpoch) throws IOException, TranslateException {
        trainer.setMetrics(new Metrics());
        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        Metrics metrics = trainer.getMetrics();
        trainer.getEvaluators().stream()
                .forEach(evaluator -> {
                    String trainKey = "train_epoch_" + evaluator.getName();
                    String validateKey = "validate_epoch_" + evaluator.getName();
                    evaluatorMetrics.put(trainKey, metricValues(metrics, trainKey));
                    evaluatorMetrics.put(validateKey, metricValues(metrics, validateKey));
                });
        return metrics.mean("epoch");
    }

    /** Returns every recorded value of the named metric as a {@code double[]}. */
    private static double[] metricValues(Metrics metrics, String metricName) {
        return metrics.getMetric(metricName).stream()
                .mapToDouble(x -> x.getValue().doubleValue())
                .toArray();
    }

    /**
     * Creates the NiN model: three NiN blocks each followed by 3x3 max pooling with
     * stride 2, a dropout layer, a final NiN block with 10 output channels, global
     * average pooling and a batch flatten.
     *
     * @return a new model named {@code ninCovid19} with the NiN block set
     */
    public static Model getNiNModel() {
        SequentialBlock block = new SequentialBlock();
        // Build the NiN network
        block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // Dropout layer
                .add(Dropout.builder().optRate(0.5f).build())
                // There are 10 label classes
                // NOTE(review): confirm this matches the number of classes in the image folder
                .add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                // The global average pooling layer automatically sets the window shape
                // to the height and width of the input
                .add(Pool.globalAvgPool2dBlock())
                // Transform the four-dimensional output into two-dimensional output
                // with a shape of (batch size, 10)
                .add(Blocks.batchFlattenBlock());
        Model model = Model.newInstance("ninCovid19");
        model.setBlock(block);
        return model;
    }
}
准备训练数据的代码见文章《JAVA深度学习框架DJL之COVID19 x-ray图片分类》。
此外,还需要在pom文件中添加如下依赖:
<!-- Tablesaw plotting module; the NiNTest listing imports tech.tablesaw.api types
     (Table, DoubleColumn, StringColumn). Presumably tablesaw-core is pulled in
     transitively by this artifact - verify against the resolved dependency tree. -->
<dependency>
<groupId>tech.tablesaw</groupId>
<artifactId>tablesaw-jsplot</artifactId>
<version>0.30.4</version>
</dependency>