package com.example.demo.djl;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import ai.djl.util.PairList;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
public class DenseBlock extends AbstractBlock {
private static final byte VERSION = 2;
public SequentialBlock net = new SequentialBlock();
public static void main(String[] args) throws IOException, ModelException, TranslateException {
System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
int batchSize = 256;
float lr = 0.1f;
int numEpochs = 10;
double[] trainLoss;
double[] testAccuracy;
double[] epochCount;
double[] trainAccuracy;
epochCount = new double[numEpochs];
for (int i = 0; i < epochCount.length; i++) {
epochCount[i] = (i + 1);
}
Loss loss = Loss.softmaxCrossEntropyLoss();
Tracker lrt = Tracker.fixed(lr);
Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // Optimizer (loss function)
.addEvaluator(new Accuracy()) // Model Accuracy
.addTrainingListeners(TrainingListener.Defaults.logging()); // Logging
try (NDManager manager = NDManager.newBaseManager(); Model model = DenseBlock.getModel(); Trainer trainer = model.newTrainer(config);) {
///
//
// SequentialBlock block = new SequentialBlock()
// .add(new DenseBlock(2, 10));
// // new DenseBlock(x, y) new Shape(a, b, c, d)
// NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 3, 8, 8));
//
// block.setInitializer(new XavierInitializer());
// block.initialize(manager, DataType.FLOAT32, X.getShape());
//
// ParameterStore parameterStore = new ParameterStore(manager, true);
//
// Shape currentShape = X.getShape();
//
// for (int i = 0; i < block.getChildren().size(); i++) {
//
// Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(manager, new Shape[]{X.getShape()});
// currentShape = newShape[0];
// // currentShape = (a,x*y+b,c,d)
// System.out.println(currentShape);
// }
/
FashionMnist trainIter =
FashionMnist.builder()
.optPipeline(new Pipeline().add(new Resize(96)).add(new ToTensor()))
.optUsage(Dataset.Usage.TRAIN)
.setSampling(batchSize, true)
.build();
FashionMnist testIter =
FashionMnist.builder()
.optPipeline(new Pipeline().add(new Resize(96)).add(new ToTensor()))
.optUsage(Dataset.Usage.TEST)
.setSampling(batchSize, true)
.build();
trainIter.prepare();
testIter.prepare();
trainer.initialize(new Shape(1, 1, 96, 96));
Map<String, double[]> evaluatorMetrics = new HashMap<>();
double avgTrainTimePerEpoch = 0;
trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics, avgTrainTimePerEpoch);
trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");
System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
System.out.println();
}
}
public DenseBlock(int numConvs, int numChannels) {
super(VERSION);
for (int i = 0; i < numConvs; i++) {
this.net.add(
addChildBlock("denseBlock" + i, convBlock(numChannels))
);
}
}
@Override
public String toString() {
return "DenseBlock()";
}
@Override
public NDList forward(ParameterStore parameterStore, NDList X, boolean training, PairList<String, Object> pairList) {
NDArray Y;
for (Block block : this.net.getChildren().values()) {
Y = block.forward(parameterStore, X, training).singletonOrThrow();
X = new NDList(NDArrays.concat(new NDList(X.singletonOrThrow(), Y), 1));
}
return X;
}
@Override
public Shape[] getOutputShapes(NDManager ndManager, Shape[] inputs) {
Shape[] shapesX = inputs;
for (Block block : this.net.getChildren().values()) {
Shape[] shapesY = block.getOutputShapes(ndManager, shapesX);
shapesX[0] = new Shape(
shapesX[0].get(0),
shapesY[0].get(1) + shapesX[0].get(1),
shapesX[0].get(2),
shapesX[0].get(3)
);
}
return shapesX;
}
@Override
public void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {
Shape shapesX = inputShapes[0];
for (Block block : this.net.getChildren().values()) {
Shape[] shapesY = block.initialize(manager, DataType.FLOAT32, shapesX);
shapesX = new Shape(
shapesX.get(0),
shapesY[0].get(1) + shapesX.get(1),
shapesX.get(2),
shapesX.get(3)
);
}
}
public static SequentialBlock convBlock(int numChannels) {
SequentialBlock block = new SequentialBlock()
.add(BatchNorm.builder().build())
.add(Activation::relu)
.add(Conv2d.builder()
.setFilters(numChannels)
.setKernelShape(new Shape(3, 3))
.optPadding(new Shape(1, 1))
.optStride(new Shape(1, 1))
.build()
);
return block;
}
public static SequentialBlock transitionBlock(int numChannels) {
SequentialBlock blk = new SequentialBlock()
.add(BatchNorm.builder().build())
.add(Activation::relu)
.add(
Conv2d.builder()
.setFilters(numChannels)
.setKernelShape(new Shape(1, 1))
.optStride(new Shape(1, 1))
.build()
)
.add(Pool.avgPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
return blk;
}
public static Model getModel() {
Model model = Model.newInstance("DenseNet");
SequentialBlock net = new SequentialBlock()
.add(Conv2d.builder()
.setFilters(64)
.setKernelShape(new Shape(7, 7))
.optStride(new Shape(2, 2))
.optPadding(new Shape(3, 3))
.build())
.add(BatchNorm.builder().build())
.add(Activation::relu)
.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2), new Shape(1, 1)));
int numChannels = 64;
int growthRate = 32;
int[] numConvsInDenseBlocks = new int[]{4, 4, 4, 4};
for (int index = 0; index < numConvsInDenseBlocks.length; index++) {
int numConvs = numConvsInDenseBlocks[index];
net.add(new DenseBlock(numConvs, growthRate));
numChannels += (numConvs * growthRate);
if (index != (numConvsInDenseBlocks.length - 1)) {
numChannels = (numChannels / 2);
net.add(transitionBlock(numChannels));
}
}
net.add(BatchNorm.builder().build())
.add(Activation::relu)
.add(Pool.globalAvgPool2dBlock())
//10 图片分类的 类别数
.add(Linear.builder().setUnits(10).build());
System.out.println(net);
model.setBlock(net);
return model;
}
public static void trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics, double avgTrainTimePerEpoch) throws IOException, TranslateException {
trainer.setMetrics(new Metrics());
EasyTrain.fit(trainer, numEpochs, trainIter, testIter);
Metrics metrics = trainer.getMetrics();
trainer.getEvaluators().stream()
.forEach(evaluator -> {
evaluatorMetrics.put("train_epoch_" + evaluator.getName(), metrics.getMetric("train_epoch_" + evaluator.getName()).stream()
.mapToDouble(x -> x.getValue().doubleValue()).toArray());
evaluatorMetrics.put("validate_epoch_" + evaluator.getName(), metrics.getMetric("validate_epoch_" + evaluator.getName()).stream()
.mapToDouble(x -> x.getValue().doubleValue()).toArray());
});
avgTrainTimePerEpoch = metrics.mean("epoch");
}
}
/*
 * java深度学习之DJL创建DenseNet网络
 * 于 2021-02-06 20:41:43 首次发布
 * 该博客展示了如何使用DJL库实现DenseBlock网络结构，并进行训练。代码中详细解释了DenseBlock的构建过程，包括卷积层、批量归一化和激活函数的使用。此外，还提供了训练配置、数据预处理、训练循环以及模型评估的步骤，以FashionMnist数据集为例进行模型训练和验证。
 * 摘要由CSDN通过智能技术生成
 */